From 6ed345f7732889d5a5d4f5dde246c3069592374d Mon Sep 17 00:00:00 2001 From: "Steven B." <51370195+sdb9696@users.noreply.github.com> Date: Fri, 20 Dec 2024 17:20:24 +0000 Subject: [PATCH] Add check for client errors to stream component (#132866) --- homeassistant/components/stream/__init__.py | 111 ++++++++++++++++++++ tests/components/stream/test_init.py | 80 +++++++++++++- 2 files changed, 190 insertions(+), 1 deletion(-) diff --git a/homeassistant/components/stream/__init__.py b/homeassistant/components/stream/__init__.py index 64c520150c2..1471db890d7 100644 --- a/homeassistant/components/stream/__init__.py +++ b/homeassistant/components/stream/__init__.py @@ -20,6 +20,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable, Mapping import copy +from enum import IntEnum import logging import secrets import threading @@ -45,6 +46,7 @@ from .const import ( CONF_EXTRA_PART_WAIT_TIME, CONF_LL_HLS, CONF_PART_DURATION, + CONF_PREFER_TCP, CONF_RTSP_TRANSPORT, CONF_SEGMENT_DURATION, CONF_USE_WALLCLOCK_AS_TIMESTAMPS, @@ -74,6 +76,8 @@ from .diagnostics import Diagnostics from .hls import HlsStreamOutput, async_setup_hls if TYPE_CHECKING: + from av.container import InputContainer, OutputContainer + from homeassistant.components.camera import DynamicStreamSettings __all__ = [ @@ -95,6 +99,113 @@ __all__ = [ _LOGGER = logging.getLogger(__name__) +class StreamClientError(IntEnum): + """Enum for stream client errors.""" + + BadRequest = 400 + Unauthorized = 401 + Forbidden = 403 + NotFound = 404 + Other = 4 + + +class StreamOpenClientError(HomeAssistantError): + """Raised when client error received when trying to open a stream. + + :param stream_client_error: The type of client error + """ + + def __init__( + self, *args: Any, stream_client_error: StreamClientError, **kwargs: Any + ) -> None: + self.stream_client_error = stream_client_error + super().__init__(*args, **kwargs) + + +async def _async_try_open_stream( + hass: HomeAssistant, source: str, pyav_options: dict[str, str] | None = None +) -> InputContainer | OutputContainer: + """Try to open a stream. + + Will raise StreamOpenClientError if an http client error is encountered. + """ + return await hass.loop.run_in_executor(None, _try_open_stream, source, pyav_options) + + +def _try_open_stream( + source: str, pyav_options: dict[str, str] | None = None +) -> InputContainer | OutputContainer: + """Try to open a stream. + + Will raise StreamOpenClientError if an http client error is encountered. + """ + import av # pylint: disable=import-outside-toplevel + + if pyav_options is None: + pyav_options = {} + + default_pyav_options = { + "rtsp_flags": CONF_PREFER_TCP, + "timeout": str(SOURCE_TIMEOUT), + } + + pyav_options = { + **default_pyav_options, + **pyav_options, + } + + try: + container = av.open(source, options=pyav_options, timeout=5) + + except av.HTTPBadRequestError as ex: + raise StreamOpenClientError( + stream_client_error=StreamClientError.BadRequest + ) from ex + + except av.HTTPUnauthorizedError as ex: + raise StreamOpenClientError( + stream_client_error=StreamClientError.Unauthorized + ) from ex + + except av.HTTPForbiddenError as ex: + raise StreamOpenClientError( + stream_client_error=StreamClientError.Forbidden + ) from ex + + except av.HTTPNotFoundError as ex: + raise StreamOpenClientError( + stream_client_error=StreamClientError.NotFound + ) from ex + + except av.HTTPOtherClientError as ex: + raise StreamOpenClientError(stream_client_error=StreamClientError.Other) from ex + + else: + return container + + +async def async_check_stream_client_error( + hass: HomeAssistant, source: str, pyav_options: dict[str, str] | None = None +) -> None: + """Check if a stream can be successfully opened. + + Raise StreamOpenClientError if an http client error is encountered. + """ + await hass.loop.run_in_executor( + None, _check_stream_client_error, source, pyav_options + ) + + +def _check_stream_client_error( + source: str, pyav_options: dict[str, str] | None = None +) -> None: + """Check if a stream can be successfully opened. + + Raise StreamOpenClientError if an http client error is encountered. + """ + _try_open_stream(source, pyav_options).close() + + def redact_credentials(url: str) -> str: """Redact credentials from string data.""" yurl = URL(url) diff --git a/tests/components/stream/test_init.py b/tests/components/stream/test_init.py index 1ae6f9e8931..5f9d305620d 100644 --- a/tests/components/stream/test_init.py +++ b/tests/components/stream/test_init.py @@ -1,11 +1,20 @@ """Test stream init.""" import logging +from unittest.mock import MagicMock, patch import av import pytest -from homeassistant.components.stream import __name__ as stream_name +from homeassistant.components.stream import ( + CONF_PREFER_TCP, + SOURCE_TIMEOUT, + StreamClientError, + StreamOpenClientError, + __name__ as stream_name, + _async_try_open_stream, + async_check_stream_client_error, +) from homeassistant.const import EVENT_LOGGING_CHANGED from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -53,3 +62,72 @@ async def test_log_levels( assert "SHOULD PASS" in caplog.text assert "SHOULD NOT PASS" not in caplog.text + + +async def test_check_open_stream_params(hass: HomeAssistant) -> None: + """Test check open stream params.""" + + container_mock = MagicMock() + source = "rtsp://foobar" + + with patch("av.open", return_value=container_mock) as open_mock: + await async_check_stream_client_error(hass, source) + + options = { + "rtsp_flags": CONF_PREFER_TCP, + "timeout": str(SOURCE_TIMEOUT), + } + open_mock.assert_called_once_with(source, options=options, timeout=5) + container_mock.close.assert_called_once() + + container_mock.reset_mock() + with patch("av.open", return_value=container_mock) as open_mock: + await async_check_stream_client_error(hass, source, {"foo": "bar"}) + + options = { + "rtsp_flags": CONF_PREFER_TCP, + "timeout": str(SOURCE_TIMEOUT), + "foo": "bar", + } + open_mock.assert_called_once_with(source, options=options, timeout=5) + container_mock.close.assert_called_once() + + +@pytest.mark.parametrize( + ("error", "enum_result"), + [ + pytest.param( + av.HTTPBadRequestError(400, ""), + StreamClientError.BadRequest, + id="BadRequest", + ), + pytest.param( + av.HTTPUnauthorizedError(401, ""), + StreamClientError.Unauthorized, + id="Unauthorized", + ), + pytest.param( + av.HTTPForbiddenError(403, ""), StreamClientError.Forbidden, id="Forbidden" + ), + pytest.param( + av.HTTPNotFoundError(404, ""), StreamClientError.NotFound, id="NotFound" + ), + pytest.param( + av.HTTPOtherClientError(408, ""), StreamClientError.Other, id="Other" + ), + ], +) +async def test_try_open_stream_error( + hass: HomeAssistant, error: av.HTTPClientError, enum_result: StreamClientError +) -> None: + """Test trying to open a stream.""" + oc_error: StreamOpenClientError | None = None + + with patch("av.open", side_effect=error): + try: + await _async_try_open_stream(hass, "rtsp://foobar") + except StreamOpenClientError as ex: + oc_error = ex + + assert oc_error + assert oc_error.stream_client_error is enum_result