diff --git a/homeassistant/components/stream/__init__.py b/homeassistant/components/stream/__init__.py index 731a610abf9..58b0dd00bc9 100644 --- a/homeassistant/components/stream/__init__.py +++ b/homeassistant/components/stream/__init__.py @@ -204,6 +204,7 @@ class Stream: self._thread_quit = threading.Event() self._outputs: dict[str, StreamOutput] = {} self._fast_restart_once = False + self._available = True def endpoint_url(self, fmt: str) -> str: """Start the stream and returns a url for the output format.""" @@ -254,6 +255,11 @@ class Stream: if all(p.idle for p in self._outputs.values()): self.access_token = None + @property + def available(self) -> bool: + """Return False if the stream is started and known to be unavailable.""" + return self._available + def start(self) -> None: """Start a stream.""" if self._thread is None or not self._thread.is_alive(): @@ -280,18 +286,25 @@ class Stream: """Handle consuming streams and restart keepalive streams.""" # Keep import here so that we can import stream integration without installing reqs # pylint: disable=import-outside-toplevel - from .worker import SegmentBuffer, stream_worker + from .worker import SegmentBuffer, StreamWorkerError, stream_worker segment_buffer = SegmentBuffer(self.hass, self.outputs) wait_timeout = 0 while not self._thread_quit.wait(timeout=wait_timeout): start_time = time.time() - stream_worker( - self.source, - self.options, - segment_buffer, - self._thread_quit, - ) + + self._available = True + try: + stream_worker( + self.source, + self.options, + segment_buffer, + self._thread_quit, + ) + except StreamWorkerError as err: + _LOGGER.error("Error from stream worker: %s", str(err)) + self._available = False + segment_buffer.discontinuity() if not self.keepalive or self._thread_quit.is_set(): if self._fast_restart_once: @@ -300,6 +313,7 @@ class Stream: self._thread_quit.clear() continue break + # To avoid excessive restarts, wait before restarting # As the required recovery time may be different for different setups, start # with trying a short wait_timeout and increase it on each reconnection attempt. diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index a0ab48290f5..5176b93dedf 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections import defaultdict, deque from collections.abc import Callable, Generator, Iterator, Mapping +import contextlib import datetime from io import BytesIO import logging @@ -31,6 +32,14 @@ from .hls import HlsStreamOutput _LOGGER = logging.getLogger(__name__) +class StreamWorkerError(Exception): + """An exception thrown while processing a stream.""" + + +class StreamEndedError(StreamWorkerError): + """Raised when the stream is complete, exposed for facilitating testing.""" + + class SegmentBuffer: """Buffer for writing a sequence of packets to the output as a segment.""" @@ -356,7 +365,7 @@ class TimestampValidator: # Discard packets missing DTS. Terminate if too many are missing. if packet.dts is None: if self._missing_dts >= MAX_MISSING_DTS: - raise StopIteration( + raise StreamWorkerError( f"No dts in {MAX_MISSING_DTS+1} consecutive packets" ) self._missing_dts += 1 @@ -367,7 +376,7 @@ class TimestampValidator: if packet.dts <= prev_dts: gap = packet.time_base * (prev_dts - packet.dts) if gap > MAX_TIMESTAMP_GAP: - raise StopIteration( + raise StreamWorkerError( f"Timestamp overflow detected: last dts = {prev_dts}, dts = {packet.dts}" ) return False @@ -410,15 +419,14 @@ def stream_worker( try: container = av.open(source, options=options, timeout=SOURCE_TIMEOUT) - except av.AVError: - _LOGGER.error("Error opening stream %s", redact_credentials(str(source))) - return + except av.AVError as err: + raise StreamWorkerError( + "Error opening stream %s" % redact_credentials(str(source)) + ) from err try: video_stream = container.streams.video[0] - except (KeyError, IndexError): - _LOGGER.error("Stream has no video") - container.close() - return + except (KeyError, IndexError) as ex: + raise StreamWorkerError("Stream has no video") from ex try: audio_stream = container.streams.audio[0] except (KeyError, IndexError): @@ -469,10 +477,17 @@ def stream_worker( # dts. Use "or 1" to deal with this. start_dts = next_video_packet.dts - (next_video_packet.duration or 1) first_keyframe.dts = first_keyframe.pts = start_dts - except (av.AVError, StopIteration) as ex: - _LOGGER.error("Error demuxing stream while finding first packet: %s", str(ex)) + except StreamWorkerError as ex: container.close() - return + raise ex + except StopIteration as ex: + container.close() + raise StreamEndedError("Stream ended; no additional packets") from ex + except av.AVError as ex: + container.close() + raise StreamWorkerError( + "Error demuxing stream while finding first packet: %s" % str(ex) + ) from ex segment_buffer.set_streams(video_stream, audio_stream) segment_buffer.reset(start_dts) @@ -480,14 +495,15 @@ def stream_worker( # Mux the first keyframe, then proceed through the rest of the packets segment_buffer.mux_packet(first_keyframe) - while not quit_event.is_set(): - try: - packet = next(container_packets) - except (av.AVError, StopIteration) as ex: - _LOGGER.error("Error demuxing stream: %s", str(ex)) - break - segment_buffer.mux_packet(packet) + with contextlib.closing(container), contextlib.closing(segment_buffer): + while not quit_event.is_set(): + try: + packet = next(container_packets) + except StreamWorkerError as ex: + raise ex + except StopIteration as ex: + raise StreamEndedError("Stream ended; no additional packets") from ex + except av.AVError as ex: + raise StreamWorkerError("Error demuxing stream: %s" % str(ex)) from ex - # Close stream - segment_buffer.close() - container.close() + segment_buffer.mux_packet(packet) diff --git a/tests/components/stream/test_hls.py b/tests/components/stream/test_hls.py index 3bff13a936b..9c529d7abe5 100644 --- a/tests/components/stream/test_hls.py +++ b/tests/components/stream/test_hls.py @@ -135,6 +135,7 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync): # Request stream stream.add_provider(HLS_PROVIDER) + assert stream.available stream.start() hls_client = await hls_stream(stream) @@ -161,6 +162,9 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync): stream_worker_sync.resume() + # The stream worker reported end of stream and exited + assert not stream.available + # Stop stream, if it hasn't quit already stream.stop() @@ -181,6 +185,7 @@ async def test_stream_timeout(hass, hass_client, stream_worker_sync): # Request stream stream.add_provider(HLS_PROVIDER) + assert stream.available stream.start() url = stream.endpoint_url(HLS_PROVIDER) @@ -267,6 +272,7 @@ async def test_stream_keepalive(hass): stream._thread.join() stream._thread = None assert av_open.call_count == 2 + assert not stream.available # Stop stream, if it hasn't quit already stream.stop() diff --git a/tests/components/stream/test_worker.py b/tests/components/stream/test_worker.py index 97fe4bd0d37..c65e10d65f3 100644 --- a/tests/components/stream/test_worker.py +++ b/tests/components/stream/test_worker.py @@ -37,7 +37,12 @@ from homeassistant.components.stream.const import ( TARGET_SEGMENT_DURATION_NON_LL_HLS, ) from homeassistant.components.stream.core import StreamSettings -from homeassistant.components.stream.worker import SegmentBuffer, stream_worker +from homeassistant.components.stream.worker import ( + SegmentBuffer, + StreamEndedError, + StreamWorkerError, + stream_worker, +) from homeassistant.setup import async_setup_component from tests.components.stream.common import generate_h264_video, generate_h265_video @@ -264,8 +269,15 @@ async def async_decode_stream(hass, packets, py_av=None): side_effect=py_av.capture_buffer.capture_output_segment, ): segment_buffer = SegmentBuffer(hass, stream.outputs) - stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) - await hass.async_block_till_done() + try: + stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) + except StreamEndedError: + # Tests only use a limited number of packets, then the worker exits as expected. In + # production, stream ending would be unexpected. + pass + finally: + # Wait for all packets to be flushed even when exceptions are thrown + await hass.async_block_till_done() return py_av.capture_buffer @@ -274,7 +286,7 @@ async def test_stream_open_fails(hass): """Test failure on stream open.""" stream = Stream(hass, STREAM_SOURCE, {}) stream.add_provider(HLS_PROVIDER) - with patch("av.open") as av_open: + with patch("av.open") as av_open, pytest.raises(StreamWorkerError): av_open.side_effect = av.error.InvalidDataError(-2, "error") segment_buffer = SegmentBuffer(hass, stream.outputs) stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) @@ -371,7 +383,10 @@ async def test_packet_overflow(hass): # Packet is so far out of order, exceeds max gap and looks like overflow packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000 - decoded_stream = await async_decode_stream(hass, packets) + py_av = MockPyAv() + with pytest.raises(StreamWorkerError, match=r"Timestamp overflow detected"): + await async_decode_stream(hass, packets, py_av=py_av) + decoded_stream = py_av.capture_buffer segments = decoded_stream.segments complete_segments = decoded_stream.complete_segments # Check number of segments @@ -425,7 +440,10 @@ async def test_too_many_initial_bad_packets_fails(hass): for i in range(0, num_bad_packets): packets[i].dts = None - decoded_stream = await async_decode_stream(hass, packets) + py_av = MockPyAv() + with pytest.raises(StreamWorkerError, match=r"No dts"): + await async_decode_stream(hass, packets, py_av=py_av) + decoded_stream = py_av.capture_buffer segments = decoded_stream.segments assert len(segments) == 0 assert len(decoded_stream.video_packets) == 0 @@ -466,7 +484,10 @@ async def test_too_many_bad_packets(hass): for i in range(bad_packet_start, bad_packet_start + num_bad_packets): packets[i].dts = None - decoded_stream = await async_decode_stream(hass, packets) + py_av = MockPyAv() + with pytest.raises(StreamWorkerError, match=r"No dts"): + await async_decode_stream(hass, packets, py_av=py_av) + decoded_stream = py_av.capture_buffer complete_segments = decoded_stream.complete_segments assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET) assert len(decoded_stream.video_packets) == bad_packet_start @@ -477,9 +498,11 @@ async def test_no_video_stream(hass): """Test no video stream in the container means no resulting output.""" py_av = MockPyAv(video=False) - decoded_stream = await async_decode_stream( - hass, PacketSequence(TEST_SEQUENCE_LENGTH), py_av=py_av - ) + with pytest.raises(StreamWorkerError, match=r"Stream has no video"): + await async_decode_stream( + hass, PacketSequence(TEST_SEQUENCE_LENGTH), py_av=py_av + ) + decoded_stream = py_av.capture_buffer # Note: This failure scenario does not output an end of stream segments = decoded_stream.segments assert len(segments) == 0 @@ -616,6 +639,9 @@ async def test_stream_stopped_while_decoding(hass): worker_wake.set() stream.stop() + # Stream is still considered available when the worker was still active and asked to stop + assert stream.available + async def test_update_stream_source(hass): """Tests that the worker is re-invoked when the stream source is updated.""" @@ -646,6 +672,7 @@ async def test_update_stream_source(hass): stream.start() assert worker_open.wait(TIMEOUT) assert last_stream_source == STREAM_SOURCE + assert stream.available # Update the stream source, then the test wakes up the worker and assert # that it re-opens the new stream (the test again waits on thread_started) @@ -655,6 +682,7 @@ async def test_update_stream_source(hass): assert worker_open.wait(TIMEOUT) assert last_stream_source == STREAM_SOURCE + "-updated-source" worker_wake.set() + assert stream.available # Cleanup stream.stop() @@ -664,15 +692,16 @@ async def test_worker_log(hass, caplog): """Test that the worker logs the url without username and password.""" stream = Stream(hass, "https://abcd:efgh@foo.bar", {}) stream.add_provider(HLS_PROVIDER) - with patch("av.open") as av_open: + + with patch("av.open") as av_open, pytest.raises(StreamWorkerError) as err: av_open.side_effect = av.error.InvalidDataError(-2, "error") segment_buffer = SegmentBuffer(hass, stream.outputs) stream_worker( "https://abcd:efgh@foo.bar", {}, segment_buffer, threading.Event() ) await hass.async_block_till_done() + assert str(err.value) == "Error opening stream https://****:****@foo.bar" assert "https://abcd:efgh@foo.bar" not in caplog.text - assert "https://****:****@foo.bar" in caplog.text async def test_durations(hass, record_worker_sync):