From dca6a9389854380a27baea3b694320786fc34c68 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 8 Feb 2021 07:19:41 -0800 Subject: [PATCH] Centralize keepalive logic in Stream class (#45850) * Remove dependencies on keepalive from StremaOutput and stream_worker Pull logic from StreamOutput and stream_worker into the Stream class, unifying keepalive and idle timeout logic. This prepares for future changes to preserve hls state across stream url changes. --- homeassistant/components/stream/__init__.py | 68 ++++++++++++--- homeassistant/components/stream/const.py | 2 + homeassistant/components/stream/core.py | 92 +++++++++++++-------- homeassistant/components/stream/recorder.py | 16 ++-- homeassistant/components/stream/worker.py | 38 --------- tests/components/stream/conftest.py | 28 +++---- tests/components/stream/test_hls.py | 5 +- tests/components/stream/test_init.py | 4 +- tests/components/stream/test_recorder.py | 8 +- tests/components/stream/test_worker.py | 20 +---- 10 files changed, 142 insertions(+), 139 deletions(-) diff --git a/homeassistant/components/stream/__init__.py b/homeassistant/components/stream/__init__.py index c7d1dad4835..6980f7ead8f 100644 --- a/homeassistant/components/stream/__init__.py +++ b/homeassistant/components/stream/__init__.py @@ -2,6 +2,7 @@ import logging import secrets import threading +import time from types import MappingProxyType import voluptuous as vol @@ -20,9 +21,12 @@ from .const import ( CONF_STREAM_SOURCE, DOMAIN, MAX_SEGMENTS, + OUTPUT_IDLE_TIMEOUT, SERVICE_RECORD, + STREAM_RESTART_INCREMENT, + STREAM_RESTART_RESET_TIME, ) -from .core import PROVIDERS +from .core import PROVIDERS, IdleTimer from .hls import async_setup_hls _LOGGER = logging.getLogger(__name__) @@ -142,18 +146,27 @@ class Stream: # without concern about self._outputs being modified from another thread. return MappingProxyType(self._outputs.copy()) - def add_provider(self, fmt): + def add_provider(self, fmt, timeout=OUTPUT_IDLE_TIMEOUT): """Add provider output stream.""" if not self._outputs.get(fmt): - provider = PROVIDERS[fmt](self) + + @callback + def idle_callback(): + if not self.keepalive and fmt in self._outputs: + self.remove_provider(self._outputs[fmt]) + self.check_idle() + + provider = PROVIDERS[fmt]( + self.hass, IdleTimer(self.hass, timeout, idle_callback) + ) self._outputs[fmt] = provider return self._outputs[fmt] def remove_provider(self, provider): """Remove provider output stream.""" if provider.name in self._outputs: + self._outputs[provider.name].cleanup() del self._outputs[provider.name] - self.check_idle() if not self._outputs: self.stop() @@ -165,10 +178,6 @@ class Stream: def start(self): """Start a stream.""" - # Keep import here so that we can import stream integration without installing reqs - # pylint: disable=import-outside-toplevel - from .worker import stream_worker - if self._thread is None or not self._thread.is_alive(): if self._thread is not None: # The thread must have crashed/exited. Join to clean up the @@ -177,12 +186,48 @@ class Stream: self._thread_quit = threading.Event() self._thread = threading.Thread( name="stream_worker", - target=stream_worker, - args=(self.hass, self, self._thread_quit), + target=self._run_worker, ) self._thread.start() _LOGGER.info("Started stream: %s", self.source) + def _run_worker(self): + """Handle consuming streams and restart keepalive streams.""" + # Keep import here so that we can import stream integration without installing reqs + # pylint: disable=import-outside-toplevel + from .worker import stream_worker + + wait_timeout = 0 + while not self._thread_quit.wait(timeout=wait_timeout): + start_time = time.time() + stream_worker(self.hass, self, self._thread_quit) + if not self.keepalive or self._thread_quit.is_set(): + break + + # To avoid excessive restarts, wait before restarting + # As the required recovery time may be different for different setups, start + # with trying a short wait_timeout and increase it on each reconnection attempt. + # Reset the wait_timeout after the worker has been up for several minutes + if time.time() - start_time > STREAM_RESTART_RESET_TIME: + wait_timeout = 0 + wait_timeout += STREAM_RESTART_INCREMENT + _LOGGER.debug( + "Restarting stream worker in %d seconds: %s", + wait_timeout, + self.source, + ) + self._worker_finished() + + def _worker_finished(self): + """Schedule cleanup of all outputs.""" + + @callback + def remove_outputs(): + for provider in self.outputs.values(): + self.remove_provider(provider) + + self.hass.loop.call_soon_threadsafe(remove_outputs) + def stop(self): """Remove outputs and access token.""" self._outputs = {} @@ -223,9 +268,8 @@ async def async_handle_record_service(hass, call): if recorder: raise HomeAssistantError(f"Stream already recording to {recorder.video_path}!") - recorder = stream.add_provider("recorder") + recorder = stream.add_provider("recorder", timeout=duration) recorder.video_path = video_path - recorder.timeout = duration stream.start() diff --git a/homeassistant/components/stream/const.py b/homeassistant/components/stream/const.py index 181808e549e..45fa3d9e76a 100644 --- a/homeassistant/components/stream/const.py +++ b/homeassistant/components/stream/const.py @@ -15,6 +15,8 @@ OUTPUT_FORMATS = ["hls"] FORMAT_CONTENT_TYPE = {"hls": "application/vnd.apple.mpegurl"} +OUTPUT_IDLE_TIMEOUT = 300 # Idle timeout due to inactivity + MAX_SEGMENTS = 3 # Max number of segments to keep around MIN_SEGMENT_DURATION = 1.5 # Each segment is at least this many seconds diff --git a/homeassistant/components/stream/core.py b/homeassistant/components/stream/core.py index 5158ba185b1..5427172a55c 100644 --- a/homeassistant/components/stream/core.py +++ b/homeassistant/components/stream/core.py @@ -8,7 +8,7 @@ from aiohttp import web import attr from homeassistant.components.http import HomeAssistantView -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.event import async_call_later from homeassistant.util.decorator import Registry @@ -36,24 +36,69 @@ class Segment: duration: float = attr.ib() +class IdleTimer: + """Invoke a callback after an inactivity timeout. + + The IdleTimer invokes the callback after some timeout has passed. The awake() method + resets the internal alarm, extending the inactivity time. + """ + + def __init__( + self, hass: HomeAssistant, timeout: int, idle_callback: Callable[[], None] + ): + """Initialize IdleTimer.""" + self._hass = hass + self._timeout = timeout + self._callback = idle_callback + self._unsub = None + self.idle = False + + def start(self): + """Start the idle timer if not already started.""" + self.idle = False + if self._unsub is None: + self._unsub = async_call_later(self._hass, self._timeout, self.fire) + + def awake(self): + """Keep the idle time alive by resetting the timeout.""" + self.idle = False + # Reset idle timeout + self.clear() + self._unsub = async_call_later(self._hass, self._timeout, self.fire) + + def clear(self): + """Clear and disable the timer.""" + if self._unsub is not None: + self._unsub() + + def fire(self, _now=None): + """Invoke the idle timeout callback, called when the alarm fires.""" + self.idle = True + self._unsub = None + self._callback() + + class StreamOutput: """Represents a stream output.""" - def __init__(self, stream, timeout: int = 300) -> None: + def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None: """Initialize a stream output.""" - self.idle = False - self.timeout = timeout - self._stream = stream + self._hass = hass + self._idle_timer = idle_timer self._cursor = None self._event = asyncio.Event() self._segments = deque(maxlen=MAX_SEGMENTS) - self._unsub = None @property def name(self) -> str: """Return provider name.""" return None + @property + def idle(self) -> bool: + """Return True if the output is idle.""" + return self._idle_timer.idle + @property def format(self) -> str: """Return container format.""" @@ -90,11 +135,7 @@ class StreamOutput: def get_segment(self, sequence: int = None) -> Any: """Retrieve a specific segment, or the whole list.""" - self.idle = False - # Reset idle timeout - if self._unsub is not None: - self._unsub() - self._unsub = async_call_later(self._stream.hass, self.timeout, self._timeout) + self._idle_timer.awake() if not sequence: return self._segments @@ -119,43 +160,22 @@ class StreamOutput: def put(self, segment: Segment) -> None: """Store output.""" - self._stream.hass.loop.call_soon_threadsafe(self._async_put, segment) + self._hass.loop.call_soon_threadsafe(self._async_put, segment) @callback def _async_put(self, segment: Segment) -> None: """Store output from event loop.""" # Start idle timeout when we start receiving data - if self._unsub is None: - self._unsub = async_call_later( - self._stream.hass, self.timeout, self._timeout - ) - - if segment is None: - self._event.set() - # Cleanup provider - if self._unsub is not None: - self._unsub() - self.cleanup() - return - + self._idle_timer.start() self._segments.append(segment) self._event.set() self._event.clear() - @callback - def _timeout(self, _now=None): - """Handle stream timeout.""" - self._unsub = None - if self._stream.keepalive: - self.idle = True - self._stream.check_idle() - else: - self.cleanup() - def cleanup(self): """Handle cleanup.""" + self._event.set() + self._idle_timer.clear() self._segments = deque(maxlen=MAX_SEGMENTS) - self._stream.remove_provider(self) class StreamView(HomeAssistantView): diff --git a/homeassistant/components/stream/recorder.py b/homeassistant/components/stream/recorder.py index cf923de85c2..7db9997f870 100644 --- a/homeassistant/components/stream/recorder.py +++ b/homeassistant/components/stream/recorder.py @@ -6,9 +6,9 @@ from typing import List import av -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback -from .core import PROVIDERS, Segment, StreamOutput +from .core import PROVIDERS, IdleTimer, Segment, StreamOutput _LOGGER = logging.getLogger(__name__) @@ -72,9 +72,9 @@ def recorder_save_worker(file_out: str, segments: List[Segment], container_forma class RecorderOutput(StreamOutput): """Represents HLS Output formats.""" - def __init__(self, stream, timeout: int = 30) -> None: + def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None: """Initialize recorder output.""" - super().__init__(stream, timeout) + super().__init__(hass, idle_timer) self.video_path = None self._segments = [] @@ -104,12 +104,6 @@ class RecorderOutput(StreamOutput): segments = [s for s in segments if s.sequence not in own_segments] self._segments = segments + self._segments - @callback - def _timeout(self, _now=None): - """Handle recorder timeout.""" - self._unsub = None - self.cleanup() - def cleanup(self): """Write recording and clean up.""" _LOGGER.debug("Starting recorder worker thread") @@ -120,5 +114,5 @@ class RecorderOutput(StreamOutput): ) thread.start() + super().cleanup() self._segments = [] - self._stream.remove_provider(self) diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index cccbfd1b48b..510d0ebd460 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -2,7 +2,6 @@ from collections import deque import io import logging -import time import av @@ -11,8 +10,6 @@ from .const import ( MAX_TIMESTAMP_GAP, MIN_SEGMENT_DURATION, PACKETS_TO_WAIT_FOR_AUDIO, - STREAM_RESTART_INCREMENT, - STREAM_RESTART_RESET_TIME, STREAM_TIMEOUT, ) from .core import Segment, StreamBuffer @@ -47,32 +44,6 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence): def stream_worker(hass, stream, quit_event): - """Handle consuming streams and restart keepalive streams.""" - - wait_timeout = 0 - while not quit_event.wait(timeout=wait_timeout): - start_time = time.time() - try: - _stream_worker_internal(hass, stream, quit_event) - except av.error.FFmpegError: # pylint: disable=c-extension-no-member - _LOGGER.exception("Stream connection failed: %s", stream.source) - if not stream.keepalive or quit_event.is_set(): - break - # To avoid excessive restarts, wait before restarting - # As the required recovery time may be different for different setups, start - # with trying a short wait_timeout and increase it on each reconnection attempt. - # Reset the wait_timeout after the worker has been up for several minutes - if time.time() - start_time > STREAM_RESTART_RESET_TIME: - wait_timeout = 0 - wait_timeout += STREAM_RESTART_INCREMENT - _LOGGER.debug( - "Restarting stream worker in %d seconds: %s", - wait_timeout, - stream.source, - ) - - -def _stream_worker_internal(hass, stream, quit_event): """Handle consuming streams.""" try: @@ -183,7 +154,6 @@ def _stream_worker_internal(hass, stream, quit_event): _LOGGER.error( "Error demuxing stream while finding first packet: %s", str(ex) ) - finalize_stream() return False return True @@ -220,12 +190,6 @@ def _stream_worker_internal(hass, stream, quit_event): packet.stream = output_streams[audio_stream] buffer.output.mux(packet) - def finalize_stream(): - if not stream.keepalive: - # End of stream, clear listeners and stop thread - for fmt in stream.outputs: - stream.outputs[fmt].put(None) - if not peek_first_pts(): container.close() return @@ -249,7 +213,6 @@ def _stream_worker_internal(hass, stream, quit_event): missing_dts = 0 except (av.AVError, StopIteration) as ex: _LOGGER.error("Error demuxing stream: %s", str(ex)) - finalize_stream() break # Discard packet if dts is not monotonic @@ -263,7 +226,6 @@ def _stream_worker_internal(hass, stream, quit_event): last_dts[packet.stream], packet.dts, ) - finalize_stream() break continue diff --git a/tests/components/stream/conftest.py b/tests/components/stream/conftest.py index 1b2f0645f9b..75ac9377b7c 100644 --- a/tests/components/stream/conftest.py +++ b/tests/components/stream/conftest.py @@ -9,14 +9,13 @@ nothing for the test to verify. The solution is the WorkerSync class that allows the tests to pause the worker thread before finalizing the stream so that it can inspect the output. """ - import logging import threading from unittest.mock import patch import pytest -from homeassistant.components.stream.core import Segment, StreamOutput +from homeassistant.components.stream import Stream class WorkerSync: @@ -25,7 +24,7 @@ class WorkerSync: def __init__(self): """Initialize WorkerSync.""" self._event = None - self._put_original = StreamOutput.put + self._original = Stream._worker_finished def pause(self): """Pause the worker before it finalizes the stream.""" @@ -35,17 +34,16 @@ class WorkerSync: """Allow the worker thread to finalize the stream.""" self._event.set() - def blocking_put(self, stream_output: StreamOutput, segment: Segment): - """Proxy StreamOutput.put, intercepted for test to pause worker.""" - if segment is None and self._event: - # Worker is ending the stream, which clears all output buffers. - # Block the worker thread until the test has a chance to verify - # the segments under test. - logging.error("blocking worker") - self._event.wait() + def blocking_finish(self, stream: Stream): + """Intercept call to pause stream worker.""" + # Worker is ending the stream, which clears all output buffers. + # Block the worker thread until the test has a chance to verify + # the segments under test. + logging.debug("blocking worker") + self._event.wait() - # Forward to actual StreamOutput.put - self._put_original(stream_output, segment) + # Forward to actual implementation + self._original(stream) @pytest.fixture() @@ -53,8 +51,8 @@ def stream_worker_sync(hass): """Patch StreamOutput to allow test to synchronize worker stream end.""" sync = WorkerSync() with patch( - "homeassistant.components.stream.core.StreamOutput.put", - side_effect=sync.blocking_put, + "homeassistant.components.stream.Stream._worker_finished", + side_effect=sync.blocking_finish, autospec=True, ): yield sync diff --git a/tests/components/stream/test_hls.py b/tests/components/stream/test_hls.py index 790222b1630..ab49a56ca02 100644 --- a/tests/components/stream/test_hls.py +++ b/tests/components/stream/test_hls.py @@ -98,6 +98,7 @@ async def test_stream_timeout(hass, hass_client, stream_worker_sync): # Wait 5 minutes future = dt_util.utcnow() + timedelta(minutes=5) async_fire_time_changed(hass, future) + await hass.async_block_till_done() # Ensure playlist not accessible fail_response = await http_client.get(parsed_url.path) @@ -155,9 +156,9 @@ async def test_stream_keepalive(hass): return cur_time with patch("av.open") as av_open, patch( - "homeassistant.components.stream.worker.time" + "homeassistant.components.stream.time" ) as mock_time, patch( - "homeassistant.components.stream.worker.STREAM_RESTART_INCREMENT", 0 + "homeassistant.components.stream.STREAM_RESTART_INCREMENT", 0 ): av_open.side_effect = av.error.InvalidDataError(-2, "error") mock_time.time.side_effect = time_side_effect diff --git a/tests/components/stream/test_init.py b/tests/components/stream/test_init.py index 1515ff1a490..2e13493b641 100644 --- a/tests/components/stream/test_init.py +++ b/tests/components/stream/test_init.py @@ -80,5 +80,7 @@ async def test_record_service_lookback(hass): await hass.services.async_call(DOMAIN, SERVICE_RECORD, data, blocking=True) assert stream_mock.called - stream_mock.return_value.add_provider.assert_called_once_with("recorder") + stream_mock.return_value.add_provider.assert_called_once_with( + "recorder", timeout=30 + ) assert hls_mock.recv.called diff --git a/tests/components/stream/test_recorder.py b/tests/components/stream/test_recorder.py index 1b46738c8f2..bda53a9cc17 100644 --- a/tests/components/stream/test_recorder.py +++ b/tests/components/stream/test_recorder.py @@ -106,13 +106,11 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync): stream_worker_sync.pause() - with patch( - "homeassistant.components.stream.recorder.RecorderOutput.cleanup" - ) as mock_cleanup: + with patch("homeassistant.components.stream.IdleTimer.fire") as mock_timeout: # Setup demo track source = generate_h264_video() stream = preload_stream(hass, source) - recorder = stream.add_provider("recorder") + recorder = stream.add_provider("recorder", timeout=30) stream.start() await recorder.recv() @@ -122,7 +120,7 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync): async_fire_time_changed(hass, future) await hass.async_block_till_done() - assert mock_cleanup.called + assert mock_timeout.called stream_worker_sync.resume() stream.stop() diff --git a/tests/components/stream/test_worker.py b/tests/components/stream/test_worker.py index 8196899dcf9..91d02664d74 100644 --- a/tests/components/stream/test_worker.py +++ b/tests/components/stream/test_worker.py @@ -132,7 +132,6 @@ class FakePyAvBuffer: self.segments = [] self.audio_packets = [] self.video_packets = [] - self.finished = False def add_stream(self, template=None): """Create an output buffer that captures packets for test to examine.""" @@ -162,11 +161,7 @@ class FakePyAvBuffer: def capture_output_segment(self, segment): """Capture the output segment for tests to inspect.""" - assert not self.finished - if segment is None: - self.finished = True - else: - self.segments.append(segment) + self.segments.append(segment) class MockPyAv: @@ -223,7 +218,6 @@ async def test_stream_worker_success(hass): decoded_stream = await async_decode_stream( hass, PacketSequence(TEST_SEQUENCE_LENGTH) ) - assert decoded_stream.finished segments = decoded_stream.segments # Check number of segments. A segment is only formed when a packet from the next # segment arrives, hence the subtraction of one from the sequence length. @@ -243,7 +237,6 @@ async def test_skip_out_of_order_packet(hass): packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9090 decoded_stream = await async_decode_stream(hass, iter(packets)) - assert decoded_stream.finished segments = decoded_stream.segments # Check sequence numbers assert all([segments[i].sequence == i + 1 for i in range(len(segments))]) @@ -279,7 +272,6 @@ async def test_discard_old_packets(hass): packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = 9090 decoded_stream = await async_decode_stream(hass, iter(packets)) - assert decoded_stream.finished segments = decoded_stream.segments # Check number of segments assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET) @@ -299,7 +291,6 @@ async def test_packet_overflow(hass): packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000 decoded_stream = await async_decode_stream(hass, iter(packets)) - assert decoded_stream.finished segments = decoded_stream.segments # Check number of segments assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET) @@ -321,7 +312,6 @@ async def test_skip_initial_bad_packets(hass): packets[i].dts = None decoded_stream = await async_decode_stream(hass, iter(packets)) - assert decoded_stream.finished segments = decoded_stream.segments # Check number of segments assert len(segments) == int( @@ -345,7 +335,6 @@ async def test_too_many_initial_bad_packets_fails(hass): packets[i].dts = None decoded_stream = await async_decode_stream(hass, iter(packets)) - assert decoded_stream.finished segments = decoded_stream.segments assert len(segments) == 0 assert len(decoded_stream.video_packets) == 0 @@ -363,7 +352,6 @@ async def test_skip_missing_dts(hass): packets[i].dts = None decoded_stream = await async_decode_stream(hass, iter(packets)) - assert decoded_stream.finished segments = decoded_stream.segments # Check sequence numbers assert all([segments[i].sequence == i + 1 for i in range(len(segments))]) @@ -387,7 +375,6 @@ async def test_too_many_bad_packets(hass): packets[i].dts = None decoded_stream = await async_decode_stream(hass, iter(packets)) - assert decoded_stream.finished segments = decoded_stream.segments assert len(segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET) assert len(decoded_stream.video_packets) == bad_packet_start @@ -402,7 +389,6 @@ async def test_no_video_stream(hass): hass, PacketSequence(TEST_SEQUENCE_LENGTH), py_av=py_av ) # Note: This failure scenario does not output an end of stream - assert not decoded_stream.finished segments = decoded_stream.segments assert len(segments) == 0 assert len(decoded_stream.video_packets) == 0 @@ -417,7 +403,6 @@ async def test_audio_packets_not_found(hass): packets = PacketSequence(num_packets) # Contains only video packets decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) - assert decoded_stream.finished segments = decoded_stream.segments assert len(segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET) assert len(decoded_stream.video_packets) == num_packets @@ -439,7 +424,6 @@ async def test_audio_is_first_packet(hass): packets[2].pts = packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) - assert decoded_stream.finished segments = decoded_stream.segments # The audio packets are segmented with the video packets assert len(segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET) @@ -458,7 +442,6 @@ async def test_audio_packets_found(hass): packets[1].pts = packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) - assert decoded_stream.finished segments = decoded_stream.segments # The audio packet above is buffered with the video packet assert len(segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET) @@ -477,7 +460,6 @@ async def test_pts_out_of_order(hass): packets[i].is_keyframe = False decoded_stream = await async_decode_stream(hass, iter(packets)) - assert decoded_stream.finished segments = decoded_stream.segments # Check number of segments assert len(segments) == int((TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET)