Add check for client errors to stream component (#132866)

This commit is contained in:
Steven B. 2024-12-20 17:20:24 +00:00 committed by GitHub
parent 233395c181
commit 6ed345f773
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 190 additions and 1 deletions

View File

@ -20,6 +20,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Mapping
import copy
from enum import IntEnum
import logging
import secrets
import threading
@ -45,6 +46,7 @@ from .const import (
CONF_EXTRA_PART_WAIT_TIME,
CONF_LL_HLS,
CONF_PART_DURATION,
CONF_PREFER_TCP,
CONF_RTSP_TRANSPORT,
CONF_SEGMENT_DURATION,
CONF_USE_WALLCLOCK_AS_TIMESTAMPS,
@ -74,6 +76,8 @@ from .diagnostics import Diagnostics
from .hls import HlsStreamOutput, async_setup_hls
if TYPE_CHECKING:
from av.container import InputContainer, OutputContainer
from homeassistant.components.camera import DynamicStreamSettings
__all__ = [
@ -95,6 +99,113 @@ __all__ = [
_LOGGER = logging.getLogger(__name__)
class StreamClientError(IntEnum):
"""Enum for stream client errors."""
BadRequest = 400
Unauthorized = 401
Forbidden = 403
NotFound = 404
Other = 4
class StreamOpenClientError(HomeAssistantError):
"""Raised when client error received when trying to open a stream.
:param stream_client_error: The type of client error
"""
def __init__(
self, *args: Any, stream_client_error: StreamClientError, **kwargs: Any
) -> None:
self.stream_client_error = stream_client_error
super().__init__(*args, **kwargs)
async def _async_try_open_stream(
hass: HomeAssistant, source: str, pyav_options: dict[str, str] | None = None
) -> InputContainer | OutputContainer:
"""Try to open a stream.
Will raise StreamOpenClientError if an http client error is encountered.
"""
return await hass.loop.run_in_executor(None, _try_open_stream, source, pyav_options)
def _try_open_stream(
source: str, pyav_options: dict[str, str] | None = None
) -> InputContainer | OutputContainer:
"""Try to open a stream.
Will raise StreamOpenClientError if an http client error is encountered.
"""
import av # pylint: disable=import-outside-toplevel
if pyav_options is None:
pyav_options = {}
default_pyav_options = {
"rtsp_flags": CONF_PREFER_TCP,
"timeout": str(SOURCE_TIMEOUT),
}
pyav_options = {
**default_pyav_options,
**pyav_options,
}
try:
container = av.open(source, options=pyav_options, timeout=5)
except av.HTTPBadRequestError as ex:
raise StreamOpenClientError(
stream_client_error=StreamClientError.BadRequest
) from ex
except av.HTTPUnauthorizedError as ex:
raise StreamOpenClientError(
stream_client_error=StreamClientError.Unauthorized
) from ex
except av.HTTPForbiddenError as ex:
raise StreamOpenClientError(
stream_client_error=StreamClientError.Forbidden
) from ex
except av.HTTPNotFoundError as ex:
raise StreamOpenClientError(
stream_client_error=StreamClientError.NotFound
) from ex
except av.HTTPOtherClientError as ex:
raise StreamOpenClientError(stream_client_error=StreamClientError.Other) from ex
else:
return container
async def async_check_stream_client_error(
hass: HomeAssistant, source: str, pyav_options: dict[str, str] | None = None
) -> None:
"""Check if a stream can be successfully opened.
Raise StreamOpenClientError if an http client error is encountered.
"""
await hass.loop.run_in_executor(
None, _check_stream_client_error, source, pyav_options
)
def _check_stream_client_error(
source: str, pyav_options: dict[str, str] | None = None
) -> None:
"""Check if a stream can be successfully opened.
Raise StreamOpenClientError if an http client error is encountered.
"""
_try_open_stream(source, pyav_options).close()
def redact_credentials(url: str) -> str:
"""Redact credentials from string data."""
yurl = URL(url)

View File

@ -1,11 +1,20 @@
"""Test stream init."""
import logging
from unittest.mock import MagicMock, patch
import av
import pytest
from homeassistant.components.stream import __name__ as stream_name
from homeassistant.components.stream import (
CONF_PREFER_TCP,
SOURCE_TIMEOUT,
StreamClientError,
StreamOpenClientError,
__name__ as stream_name,
_async_try_open_stream,
async_check_stream_client_error,
)
from homeassistant.const import EVENT_LOGGING_CHANGED
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
@ -53,3 +62,72 @@ async def test_log_levels(
assert "SHOULD PASS" in caplog.text
assert "SHOULD NOT PASS" not in caplog.text
async def test_check_open_stream_params(hass: HomeAssistant) -> None:
"""Test check open stream params."""
container_mock = MagicMock()
source = "rtsp://foobar"
with patch("av.open", return_value=container_mock) as open_mock:
await async_check_stream_client_error(hass, source)
options = {
"rtsp_flags": CONF_PREFER_TCP,
"timeout": str(SOURCE_TIMEOUT),
}
open_mock.assert_called_once_with(source, options=options, timeout=5)
container_mock.close.assert_called_once()
container_mock.reset_mock()
with patch("av.open", return_value=container_mock) as open_mock:
await async_check_stream_client_error(hass, source, {"foo": "bar"})
options = {
"rtsp_flags": CONF_PREFER_TCP,
"timeout": str(SOURCE_TIMEOUT),
"foo": "bar",
}
open_mock.assert_called_once_with(source, options=options, timeout=5)
container_mock.close.assert_called_once()
@pytest.mark.parametrize(
("error", "enum_result"),
[
pytest.param(
av.HTTPBadRequestError(400, ""),
StreamClientError.BadRequest,
id="BadRequest",
),
pytest.param(
av.HTTPUnauthorizedError(401, ""),
StreamClientError.Unauthorized,
id="Unauthorized",
),
pytest.param(
av.HTTPForbiddenError(403, ""), StreamClientError.Forbidden, id="Forbidden"
),
pytest.param(
av.HTTPNotFoundError(404, ""), StreamClientError.NotFound, id="NotFound"
),
pytest.param(
av.HTTPOtherClientError(408, ""), StreamClientError.Other, id="Other"
),
],
)
async def test_try_open_stream_error(
hass: HomeAssistant, error: av.HTTPClientError, enum_result: StreamClientError
) -> None:
"""Test trying to open a stream."""
oc_error: StreamOpenClientError | None = None
with patch("av.open", side_effect=error):
try:
await _async_try_open_stream(hass, "rtsp://foobar")
except StreamOpenClientError as ex:
oc_error = ex
assert oc_error
assert oc_error.stream_client_error is enum_result