diff --git a/homeassistant/components/camera/__init__.py b/homeassistant/components/camera/__init__.py index ba61f155473..d0bee4e249f 100644 --- a/homeassistant/components/camera/__init__.py +++ b/homeassistant/components/camera/__init__.py @@ -24,7 +24,11 @@ from homeassistant.components.media_player.const import ( SERVICE_PLAY_MEDIA, ) from homeassistant.components.stream import Stream, create_stream -from homeassistant.components.stream.const import FORMAT_CONTENT_TYPE, OUTPUT_FORMATS +from homeassistant.components.stream.const import ( + FORMAT_CONTENT_TYPE, + HLS_OUTPUT, + OUTPUT_FORMATS, +) from homeassistant.const import ( ATTR_ENTITY_ID, CONF_FILENAME, @@ -254,7 +258,7 @@ async def async_setup(hass, config): stream = await camera.create_stream() if not stream: continue - stream.add_provider("hls") + stream.hls_output() stream.start() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, preload_stream) @@ -702,6 +706,8 @@ async def async_handle_play_stream_service(camera, service_call): async def _async_stream_endpoint_url(hass, camera, fmt): + if fmt != HLS_OUTPUT: + raise ValueError("Only format {HLS_OUTPUT} is supported") stream = await camera.create_stream() if not stream: raise HomeAssistantError( @@ -712,9 +718,9 @@ async def _async_stream_endpoint_url(hass, camera, fmt): camera_prefs = hass.data[DATA_CAMERA_PREFS].get(camera.entity_id) stream.keepalive = camera_prefs.preload_stream - stream.add_provider(fmt) + stream.hls_output() stream.start() - return stream.endpoint_url(fmt) + return stream.endpoint_url() async def async_handle_record_service(camera, call): diff --git a/homeassistant/components/stream/__init__.py b/homeassistant/components/stream/__init__.py index cdaa0faeb95..677d01e5006 100644 --- a/homeassistant/components/stream/__init__.py +++ b/homeassistant/components/stream/__init__.py @@ -7,25 +7,25 @@ a new Stream object. Stream manages: - Home Assistant URLs for viewing a stream - Access tokens for URLs for viewing a stream -A Stream consists of a background worker, and one or more output formats each -with their own idle timeout managed by the stream component. When an output -format is no longer in use, the stream component will expire it. When there -are no active output formats, the background worker is shut down and access -tokens are expired. Alternatively, a Stream can be configured with keepalive -to always keep workers active. +A Stream consists of a background worker and multiple output streams (e.g. hls +and recorder). The worker has a callback to retrieve the current active output +streams where it writes the decoded output packets. The HLS stream has an +inactivity idle timeout that expires the access token. When all output streams +are inactive, the background worker is shut down. Alternatively, a Stream +can be configured with keepalive to always keep workers active. """ import logging import secrets import threading import time -from types import MappingProxyType +from typing import List from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError from .const import ( - ATTR_ENDPOINTS, + ATTR_HLS_ENDPOINT, ATTR_STREAMS, DOMAIN, MAX_SEGMENTS, @@ -33,8 +33,8 @@ from .const import ( STREAM_RESTART_INCREMENT, STREAM_RESTART_RESET_TIME, ) -from .core import PROVIDERS, IdleTimer -from .hls import async_setup_hls +from .core import IdleTimer, StreamOutput +from .hls import HlsStreamOutput, async_setup_hls _LOGGER = logging.getLogger(__name__) @@ -75,12 +75,10 @@ async def async_setup(hass, config): from .recorder import async_setup_recorder hass.data[DOMAIN] = {} - hass.data[DOMAIN][ATTR_ENDPOINTS] = {} hass.data[DOMAIN][ATTR_STREAMS] = [] # Setup HLS - hls_endpoint = async_setup_hls(hass) - hass.data[DOMAIN][ATTR_ENDPOINTS]["hls"] = hls_endpoint + hass.data[DOMAIN][ATTR_HLS_ENDPOINT] = async_setup_hls(hass) # Setup Recorder async_setup_recorder(hass) @@ -89,7 +87,6 @@ async def async_setup(hass, config): def shutdown(event): """Stop all stream workers.""" for stream in hass.data[DOMAIN][ATTR_STREAMS]: - stream.keepalive = False stream.stop() _LOGGER.info("Stopped stream workers") @@ -110,58 +107,53 @@ class Stream: self.access_token = None self._thread = None self._thread_quit = threading.Event() - self._outputs = {} + self._hls = None + self._hls_timer = None + self._recorder = None self._fast_restart_once = False if self.options is None: self.options = {} - def endpoint_url(self, fmt): - """Start the stream and returns a url for the output format.""" - if fmt not in self._outputs: - raise ValueError(f"Stream is not configured for format '{fmt}'") + def endpoint_url(self) -> str: + """Start the stream and returns a url for the hls endpoint.""" + if not self._hls: + raise ValueError("Stream is not configured for hls") if not self.access_token: self.access_token = secrets.token_hex() - return self.hass.data[DOMAIN][ATTR_ENDPOINTS][fmt].format(self.access_token) + return self.hass.data[DOMAIN][ATTR_HLS_ENDPOINT].format(self.access_token) - def outputs(self): - """Return a copy of the stream outputs.""" - # A copy is returned so the caller can iterate through the outputs - # without concern about self._outputs being modified from another thread. - return MappingProxyType(self._outputs.copy()) + def outputs(self) -> List[StreamOutput]: + """Return the active stream outputs.""" + return [output for output in [self._hls, self._recorder] if output] - def add_provider(self, fmt, timeout=OUTPUT_IDLE_TIMEOUT): - """Add provider output stream.""" - if not self._outputs.get(fmt): + def hls_output(self) -> StreamOutput: + """Return the hls output stream, creating if not already active.""" + if not self._hls: + self._hls = HlsStreamOutput(self.hass) + self._hls_timer = IdleTimer(self.hass, OUTPUT_IDLE_TIMEOUT, self._hls_idle) + self._hls_timer.start() + self._hls_timer.awake() + return self._hls - @callback - def idle_callback(): - if not self.keepalive and fmt in self._outputs: - self.remove_provider(self._outputs[fmt]) - self.check_idle() + @callback + def _hls_idle(self): + """Reset access token and cleanup stream due to inactivity.""" + self.access_token = None + if not self.keepalive: + self._hls.cleanup() + self._hls = None + self._hls_timer = None + self._check_idle() - provider = PROVIDERS[fmt]( - self.hass, IdleTimer(self.hass, timeout, idle_callback) - ) - self._outputs[fmt] = provider - return self._outputs[fmt] - - def remove_provider(self, provider): - """Remove provider output stream.""" - if provider.name in self._outputs: - self._outputs[provider.name].cleanup() - del self._outputs[provider.name] - - if not self._outputs: - self.stop() - - def check_idle(self): - """Reset access token if all providers are idle.""" - if all([p.idle for p in self._outputs.values()]): - self.access_token = None + def _check_idle(self): + """Check if all outputs are idle and shut down worker.""" + if self.keepalive or self.outputs(): + return + self.stop() def start(self): - """Start a stream.""" + """Start stream decode worker.""" if self._thread is None or not self._thread.is_alive(): if self._thread is not None: # The thread must have crashed/exited. Join to clean up the @@ -215,21 +207,18 @@ class Stream: def _worker_finished(self): """Schedule cleanup of all outputs.""" - - @callback - def remove_outputs(): - for provider in self.outputs().values(): - self.remove_provider(provider) - - self.hass.loop.call_soon_threadsafe(remove_outputs) + self.hass.loop.call_soon_threadsafe(self.stop) def stop(self): """Remove outputs and access token.""" - self._outputs = {} self.access_token = None - - if not self.keepalive: - self._stop() + if self._hls: + self._hls.cleanup() + self._hls = None + if self._recorder: + self._recorder.save() + self._recorder = None + self._stop() def _stop(self): """Stop worker thread.""" @@ -242,25 +231,35 @@ class Stream: async def async_record(self, video_path, duration=30, lookback=5): """Make a .mp4 recording from a provided stream.""" + # Keep import here so that we can import stream integration without installing reqs + # pylint: disable=import-outside-toplevel + from .recorder import RecorderOutput + # Check for file access if not self.hass.config.is_allowed_path(video_path): raise HomeAssistantError(f"Can't write {video_path}, no access to path!") # Add recorder - recorder = self.outputs().get("recorder") - if recorder: + if self._recorder: raise HomeAssistantError( - f"Stream already recording to {recorder.video_path}!" + f"Stream already recording to {self._recorder.video_path}!" ) - recorder = self.add_provider("recorder", timeout=duration) - recorder.video_path = video_path - + self._recorder = RecorderOutput(self.hass) + self._recorder.video_path = video_path self.start() # Take advantage of lookback - hls = self.outputs().get("hls") - if lookback > 0 and hls: - num_segments = min(int(lookback // hls.target_duration), MAX_SEGMENTS) + if lookback > 0 and self._hls: + num_segments = min(int(lookback // self._hls.target_duration), MAX_SEGMENTS) # Wait for latest segment, then add the lookback - await hls.recv() - recorder.prepend(list(hls.get_segment())[-num_segments:]) + await self._hls.recv() + self._recorder.prepend(list(self._hls.get_segment())[-num_segments:]) + + @callback + def save_recording(): + if self._recorder: + self._recorder.save() + self._recorder = None + self._check_idle() + + IdleTimer(self.hass, duration, save_recording).start() diff --git a/homeassistant/components/stream/const.py b/homeassistant/components/stream/const.py index 41df806d020..55f447a9a69 100644 --- a/homeassistant/components/stream/const.py +++ b/homeassistant/components/stream/const.py @@ -1,10 +1,14 @@ """Constants for Stream component.""" DOMAIN = "stream" -ATTR_ENDPOINTS = "endpoints" +ATTR_HLS_ENDPOINT = "hls_endpoint" ATTR_STREAMS = "streams" -OUTPUT_FORMATS = ["hls"] +HLS_OUTPUT = "hls" +OUTPUT_FORMATS = [HLS_OUTPUT] +OUTPUT_CONTAINER_FORMAT = "mp4" +OUTPUT_VIDEO_CODECS = {"hevc", "h264"} +OUTPUT_AUDIO_CODECS = {"aac", "mp3"} FORMAT_CONTENT_TYPE = {"hls": "application/vnd.apple.mpegurl"} diff --git a/homeassistant/components/stream/core.py b/homeassistant/components/stream/core.py index 31c7940b8e1..4fc70eb856f 100644 --- a/homeassistant/components/stream/core.py +++ b/homeassistant/components/stream/core.py @@ -1,8 +1,7 @@ """Provides core stream functionality.""" -import asyncio -from collections import deque +import abc import io -from typing import Any, Callable, List +from typing import Callable from aiohttp import web import attr @@ -10,11 +9,8 @@ import attr from homeassistant.components.http import HomeAssistantView from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.event import async_call_later -from homeassistant.util.decorator import Registry -from .const import ATTR_STREAMS, DOMAIN, MAX_SEGMENTS - -PROVIDERS = Registry() +from .const import ATTR_STREAMS, DOMAIN @attr.s @@ -78,86 +74,18 @@ class IdleTimer: self._callback() -class StreamOutput: +class StreamOutput(abc.ABC): """Represents a stream output.""" - def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None: + def __init__(self, hass: HomeAssistant): """Initialize a stream output.""" self._hass = hass - self._idle_timer = idle_timer - self._cursor = None - self._event = asyncio.Event() - self._segments = deque(maxlen=MAX_SEGMENTS) - - @property - def name(self) -> str: - """Return provider name.""" - return None - - @property - def idle(self) -> bool: - """Return True if the output is idle.""" - return self._idle_timer.idle - - @property - def format(self) -> str: - """Return container format.""" - return None - - @property - def audio_codecs(self) -> str: - """Return desired audio codecs.""" - return None - - @property - def video_codecs(self) -> tuple: - """Return desired video codecs.""" - return None @property def container_options(self) -> Callable[[int], dict]: """Return Callable which takes a sequence number and returns container options.""" return None - @property - def segments(self) -> List[int]: - """Return current sequence from segments.""" - return [s.sequence for s in self._segments] - - @property - def target_duration(self) -> int: - """Return the max duration of any given segment in seconds.""" - segment_length = len(self._segments) - if not segment_length: - return 1 - durations = [s.duration for s in self._segments] - return round(max(durations)) or 1 - - def get_segment(self, sequence: int = None) -> Any: - """Retrieve a specific segment, or the whole list.""" - self._idle_timer.awake() - - if not sequence: - return self._segments - - for segment in self._segments: - if segment.sequence == sequence: - return segment - return None - - async def recv(self) -> Segment: - """Wait for and retrieve the latest segment.""" - last_segment = max(self.segments, default=0) - if self._cursor is None or self._cursor <= last_segment: - await self._event.wait() - - if not self._segments: - return None - - segment = self.get_segment()[-1] - self._cursor = segment.sequence - return segment - def put(self, segment: Segment) -> None: """Store output.""" self._hass.loop.call_soon_threadsafe(self._async_put, segment) @@ -165,17 +93,6 @@ class StreamOutput: @callback def _async_put(self, segment: Segment) -> None: """Store output from event loop.""" - # Start idle timeout when we start receiving data - self._idle_timer.start() - self._segments.append(segment) - self._event.set() - self._event.clear() - - def cleanup(self): - """Handle cleanup.""" - self._event.set() - self._idle_timer.clear() - self._segments = deque(maxlen=MAX_SEGMENTS) class StreamView(HomeAssistantView): diff --git a/homeassistant/components/stream/hls.py b/homeassistant/components/stream/hls.py index bd5fbd5e9ae..57894d17711 100644 --- a/homeassistant/components/stream/hls.py +++ b/homeassistant/components/stream/hls.py @@ -1,13 +1,15 @@ """Provide functionality to stream HLS.""" +import asyncio +from collections import deque import io -from typing import Callable +from typing import Any, Callable, List from aiohttp import web from homeassistant.core import callback -from .const import FORMAT_CONTENT_TYPE, NUM_PLAYLIST_SEGMENTS -from .core import PROVIDERS, StreamOutput, StreamView +from .const import FORMAT_CONTENT_TYPE, MAX_SEGMENTS, NUM_PLAYLIST_SEGMENTS +from .core import Segment, StreamOutput, StreamView from .fmp4utils import get_codec_string, get_init, get_m4s @@ -48,8 +50,7 @@ class HlsMasterPlaylistView(StreamView): async def handle(self, request, stream, sequence): """Return m3u8 playlist.""" - track = stream.add_provider("hls") - stream.start() + track = stream.hls_output() # Wait for a segment to be ready if not track.segments: if not await track.recv(): @@ -102,8 +103,7 @@ class HlsPlaylistView(StreamView): async def handle(self, request, stream, sequence): """Return m3u8 playlist.""" - track = stream.add_provider("hls") - stream.start() + track = stream.hls_output() # Wait for a segment to be ready if not track.segments: if not await track.recv(): @@ -121,7 +121,7 @@ class HlsInitView(StreamView): async def handle(self, request, stream, sequence): """Return init.mp4.""" - track = stream.add_provider("hls") + track = stream.hls_output() segments = track.get_segment() if not segments: return web.HTTPNotFound() @@ -138,7 +138,7 @@ class HlsSegmentView(StreamView): async def handle(self, request, stream, sequence): """Return fmp4 segment.""" - track = stream.add_provider("hls") + track = stream.hls_output() segment = track.get_segment(int(sequence)) if not segment: return web.HTTPNotFound() @@ -149,29 +149,15 @@ class HlsSegmentView(StreamView): ) -@PROVIDERS.register("hls") class HlsStreamOutput(StreamOutput): """Represents HLS Output formats.""" - @property - def name(self) -> str: - """Return provider name.""" - return "hls" - - @property - def format(self) -> str: - """Return container format.""" - return "mp4" - - @property - def audio_codecs(self) -> str: - """Return desired audio codecs.""" - return {"aac", "mp3"} - - @property - def video_codecs(self) -> tuple: - """Return desired video codecs.""" - return {"hevc", "h264"} + def __init__(self, hass) -> None: + """Initialize HlsStreamOutput.""" + super().__init__(hass) + self._cursor = None + self._event = asyncio.Event() + self._segments = deque(maxlen=MAX_SEGMENTS) @property def container_options(self) -> Callable[[int], dict]: @@ -182,3 +168,51 @@ class HlsStreamOutput(StreamOutput): "avoid_negative_ts": "make_non_negative", "fragment_index": str(sequence), } + + @property + def segments(self) -> List[int]: + """Return current sequence from segments.""" + return [s.sequence for s in self._segments] + + @property + def target_duration(self) -> int: + """Return the max duration of any given segment in seconds.""" + segment_length = len(self._segments) + if not segment_length: + return 1 + durations = [s.duration for s in self._segments] + return round(max(durations)) or 1 + + def get_segment(self, sequence: int = None) -> Any: + """Retrieve a specific segment, or the whole list.""" + if not sequence: + return self._segments + + for segment in self._segments: + if segment.sequence == sequence: + return segment + return None + + async def recv(self) -> Segment: + """Wait for and retrieve the latest segment.""" + last_segment = max(self.segments, default=0) + if self._cursor is None or self._cursor <= last_segment: + await self._event.wait() + + if not self._segments: + return None + + segment = self.get_segment()[-1] + self._cursor = segment.sequence + return segment + + def _async_put(self, segment: Segment) -> None: + """Store output from event loop.""" + self._segments.append(segment) + self._event.set() + self._event.clear() + + def cleanup(self): + """Handle cleanup.""" + self._event.set() + self._segments = deque(maxlen=MAX_SEGMENTS) diff --git a/homeassistant/components/stream/recorder.py b/homeassistant/components/stream/recorder.py index 7db9997f870..0fc3d84b1b9 100644 --- a/homeassistant/components/stream/recorder.py +++ b/homeassistant/components/stream/recorder.py @@ -6,9 +6,10 @@ from typing import List import av -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import callback -from .core import PROVIDERS, IdleTimer, Segment, StreamOutput +from .const import OUTPUT_CONTAINER_FORMAT +from .core import Segment, StreamOutput _LOGGER = logging.getLogger(__name__) @@ -18,7 +19,7 @@ def async_setup_recorder(hass): """Only here so Provider Registry works.""" -def recorder_save_worker(file_out: str, segments: List[Segment], container_format: str): +def recorder_save_worker(file_out: str, segments: List[Segment], container_format): """Handle saving stream.""" if not os.path.exists(os.path.dirname(file_out)): os.makedirs(os.path.dirname(file_out), exist_ok=True) @@ -68,51 +69,31 @@ def recorder_save_worker(file_out: str, segments: List[Segment], container_forma output.close() -@PROVIDERS.register("recorder") class RecorderOutput(StreamOutput): """Represents HLS Output formats.""" - def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None: + def __init__(self, hass) -> None: """Initialize recorder output.""" - super().__init__(hass, idle_timer) + super().__init__(hass) self.video_path = None self._segments = [] - @property - def name(self) -> str: - """Return provider name.""" - return "recorder" - - @property - def format(self) -> str: - """Return container format.""" - return "mp4" - - @property - def audio_codecs(self) -> str: - """Return desired audio codec.""" - return {"aac", "mp3"} - - @property - def video_codecs(self) -> tuple: - """Return desired video codecs.""" - return {"hevc", "h264"} + def _async_put(self, segment: Segment) -> None: + """Store output.""" + self._segments.append(segment) def prepend(self, segments: List[Segment]) -> None: """Prepend segments to existing list.""" - own_segments = self.segments - segments = [s for s in segments if s.sequence not in own_segments] + segments = [s for s in segments if s.sequence not in self._segments] self._segments = segments + self._segments - def cleanup(self): + def save(self): """Write recording and clean up.""" _LOGGER.debug("Starting recorder worker thread") thread = threading.Thread( name="recorder_save_worker", target=recorder_save_worker, - args=(self.video_path, self._segments, self.format), + args=(self.video_path, self._segments, OUTPUT_CONTAINER_FORMAT), ) thread.start() - - super().cleanup() self._segments = [] diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index 2050787a714..41cb4bafd90 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -9,6 +9,9 @@ from .const import ( MAX_MISSING_DTS, MAX_TIMESTAMP_GAP, MIN_SEGMENT_DURATION, + OUTPUT_AUDIO_CODECS, + OUTPUT_CONTAINER_FORMAT, + OUTPUT_VIDEO_CODECS, PACKETS_TO_WAIT_FOR_AUDIO, STREAM_TIMEOUT, ) @@ -29,7 +32,7 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence): output = av.open( segment, mode="w", - format=stream_output.format, + format=OUTPUT_CONTAINER_FORMAT, container_options={ "video_track_timescale": str(int(1 / video_stream.time_base)), **container_options, @@ -38,7 +41,7 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence): vstream = output.add_stream(template=video_stream) # Check if audio is requested astream = None - if audio_stream and audio_stream.name in stream_output.audio_codecs: + if audio_stream and audio_stream.name in OUTPUT_AUDIO_CODECS: astream = output.add_stream(template=audio_stream) return StreamBuffer(segment, output, vstream, astream) @@ -65,8 +68,8 @@ class SegmentBuffer: # Fetch the latest StreamOutputs, which may have changed since the # worker started. self._outputs = [] - for stream_output in self._outputs_callback().values(): - if self._video_stream.name not in stream_output.video_codecs: + for stream_output in self._outputs_callback(): + if self._video_stream.name not in OUTPUT_VIDEO_CODECS: continue buffer = create_stream_buffer( stream_output, self._video_stream, self._audio_stream, self._sequence diff --git a/tests/components/stream/test_hls.py b/tests/components/stream/test_hls.py index 7811cac2a2a..2a53e2c5169 100644 --- a/tests/components/stream/test_hls.py +++ b/tests/components/stream/test_hls.py @@ -45,7 +45,7 @@ def hls_stream(hass, hass_client): async def create_client_for_stream(stream): http_client = await hass_client() - parsed_url = urlparse(stream.endpoint_url("hls")) + parsed_url = urlparse(stream.endpoint_url()) return HlsClient(http_client, parsed_url) return create_client_for_stream @@ -87,7 +87,7 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync): stream = create_stream(hass, source) # Request stream - stream.add_provider("hls") + stream.hls_output() stream.start() hls_client = await hls_stream(stream) @@ -128,9 +128,9 @@ async def test_stream_timeout(hass, hass_client, stream_worker_sync): stream = create_stream(hass, source) # Request stream - stream.add_provider("hls") + stream.hls_output() stream.start() - url = stream.endpoint_url("hls") + url = stream.endpoint_url() http_client = await hass_client() @@ -168,12 +168,10 @@ async def test_stream_ended(hass, stream_worker_sync): # Setup demo HLS track source = generate_h264_video() stream = create_stream(hass, source) - track = stream.add_provider("hls") # Request stream - stream.add_provider("hls") + track = stream.hls_output() stream.start() - stream.endpoint_url("hls") # Run it dead while True: @@ -199,7 +197,7 @@ async def test_stream_keepalive(hass): # Setup demo HLS track source = "test_stream_keepalive_source" stream = create_stream(hass, source) - track = stream.add_provider("hls") + track = stream.hls_output() track.num_segments = 2 stream.start() @@ -230,12 +228,12 @@ async def test_stream_keepalive(hass): stream.stop() -async def test_hls_playlist_view_no_output(hass, hass_client, hls_stream): +async def test_hls_playlist_view_no_output(hass, hls_stream): """Test rendering the hls playlist with no output segments.""" await async_setup_component(hass, "stream", {"stream": {}}) stream = create_stream(hass, STREAM_SOURCE) - stream.add_provider("hls") + stream.hls_output() hls_client = await hls_stream(stream) @@ -250,7 +248,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync): stream = create_stream(hass, STREAM_SOURCE) stream_worker_sync.pause() - hls = stream.add_provider("hls") + hls = stream.hls_output() hls.put(Segment(1, SEQUENCE_BYTES, DURATION)) await hass.async_block_till_done() @@ -277,7 +275,7 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync): stream = create_stream(hass, STREAM_SOURCE) stream_worker_sync.pause() - hls = stream.add_provider("hls") + hls = stream.hls_output() hls_client = await hls_stream(stream) diff --git a/tests/components/stream/test_recorder.py b/tests/components/stream/test_recorder.py index 9d418c360b1..3930a5e237d 100644 --- a/tests/components/stream/test_recorder.py +++ b/tests/components/stream/test_recorder.py @@ -1,10 +1,12 @@ """The tests for hls streams.""" +import asyncio from datetime import timedelta import logging import os import threading from unittest.mock import patch +import async_timeout import av import pytest @@ -32,23 +34,30 @@ class SaveRecordWorkerSync: def __init__(self): """Initialize SaveRecordWorkerSync.""" self.reset() + self._segments = None - def recorder_save_worker(self, *args, **kwargs): + def recorder_save_worker(self, file_out, segments, container_format): """Mock method for patch.""" logging.debug("recorder_save_worker thread started") + self._segments = segments assert self._save_thread is None self._save_thread = threading.current_thread() self._save_event.set() + async def get_segments(self): + """Verify save worker thread was invoked and return saved segments.""" + with async_timeout.timeout(TEST_TIMEOUT): + assert await self._save_event.wait() + return self._segments + def join(self): - """Verify save worker was invoked and block on shutdown.""" - assert self._save_event.wait(timeout=TEST_TIMEOUT) + """Block until the record worker thread exist to ensure cleanup.""" self._save_thread.join() def reset(self): """Reset callback state for reuse in tests.""" self._save_thread = None - self._save_event = threading.Event() + self._save_event = asyncio.Event() @pytest.fixture() @@ -63,7 +72,7 @@ def record_worker_sync(hass): yield sync -async def test_record_stream(hass, hass_client, stream_worker_sync, record_worker_sync): +async def test_record_stream(hass, hass_client, record_worker_sync): """ Test record stream. @@ -73,28 +82,14 @@ async def test_record_stream(hass, hass_client, stream_worker_sync, record_worke """ await async_setup_component(hass, "stream", {"stream": {}}) - stream_worker_sync.pause() - # Setup demo track source = generate_h264_video() stream = create_stream(hass, source) with patch.object(hass.config, "is_allowed_path", return_value=True): await stream.async_record("/example/path") - recorder = stream.add_provider("recorder") - while True: - segment = await recorder.recv() - if not segment: - break - segments = segment.sequence - if segments > 1: - stream_worker_sync.resume() - - stream.stop() - assert segments > 1 - - # Verify that the save worker was invoked, then block until its - # thread completes and is shutdown completely to avoid thread leaks. + segments = await record_worker_sync.get_segments() + assert len(segments) > 1 record_worker_sync.join() @@ -107,19 +102,24 @@ async def test_record_lookback( source = generate_h264_video() stream = create_stream(hass, source) + # Don't let the stream finish (and clean itself up) until the test has had + # a chance to perform lookback + stream_worker_sync.pause() + # Start an HLS feed to enable lookback - stream.add_provider("hls") - stream.start() + stream.hls_output() with patch.object(hass.config, "is_allowed_path", return_value=True): await stream.async_record("/example/path", lookback=4) # This test does not need recorder cleanup since it is not fully exercised - + stream_worker_sync.resume() stream.stop() -async def test_recorder_timeout(hass, hass_client, stream_worker_sync): +async def test_recorder_timeout( + hass, hass_client, stream_worker_sync, record_worker_sync +): """ Test recorder timeout. @@ -137,9 +137,8 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync): stream = create_stream(hass, source) with patch.object(hass.config, "is_allowed_path", return_value=True): await stream.async_record("/example/path") - recorder = stream.add_provider("recorder") - await recorder.recv() + assert not mock_timeout.called # Wait a minute future = dt_util.utcnow() + timedelta(minutes=1) @@ -149,9 +148,11 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync): assert mock_timeout.called stream_worker_sync.resume() + # Verify worker is invoked, and do clean shutdown of worker thread + await record_worker_sync.get_segments() + record_worker_sync.join() + stream.stop() - await hass.async_block_till_done() - await hass.async_block_till_done() async def test_record_path_not_allowed(hass, hass_client): @@ -180,9 +181,7 @@ async def test_recorder_save(tmpdir): assert os.path.exists(filename) -async def test_record_stream_audio( - hass, hass_client, stream_worker_sync, record_worker_sync -): +async def test_record_stream_audio(hass, hass_client, record_worker_sync): """ Test treatment of different audio inputs. @@ -198,7 +197,6 @@ async def test_record_stream_audio( (None, 0), # no audio stream ): record_worker_sync.reset() - stream_worker_sync.pause() # Setup demo track source = generate_h264_video( @@ -207,22 +205,14 @@ async def test_record_stream_audio( stream = create_stream(hass, source) with patch.object(hass.config, "is_allowed_path", return_value=True): await stream.async_record("/example/path") - recorder = stream.add_provider("recorder") - while True: - segment = await recorder.recv() - if not segment: - break - last_segment = segment - stream_worker_sync.resume() + segments = await record_worker_sync.get_segments() + last_segment = segments[-1] result = av.open(last_segment.segment, "r", format="mp4") assert len(result.streams.audio) == expected_audio_streams result.close() - stream.stop() - await hass.async_block_till_done() - # Verify that the save worker was invoked, then block until its - # thread completes and is shutdown completely to avoid thread leaks. + stream.stop() record_worker_sync.join() diff --git a/tests/components/stream/test_worker.py b/tests/components/stream/test_worker.py index b348d68fc86..d9006c81ad5 100644 --- a/tests/components/stream/test_worker.py +++ b/tests/components/stream/test_worker.py @@ -31,7 +31,6 @@ from homeassistant.components.stream.worker import stream_worker STREAM_SOURCE = "some-stream-source" # Formats here are arbitrary, not exercised by tests -STREAM_OUTPUT_FORMAT = "hls" AUDIO_STREAM_FORMAT = "mp3" VIDEO_STREAM_FORMAT = "h264" VIDEO_FRAME_RATE = 12 @@ -188,7 +187,7 @@ class MockPyAv: async def async_decode_stream(hass, packets, py_av=None): """Start a stream worker that decodes incoming stream packets into output segments.""" stream = Stream(hass, STREAM_SOURCE) - stream.add_provider(STREAM_OUTPUT_FORMAT) + stream.hls_output() if not py_av: py_av = MockPyAv() @@ -207,7 +206,7 @@ async def async_decode_stream(hass, packets, py_av=None): async def test_stream_open_fails(hass): """Test failure on stream open.""" stream = Stream(hass, STREAM_SOURCE) - stream.add_provider(STREAM_OUTPUT_FORMAT) + stream.hls_output() with patch("av.open") as av_open: av_open.side_effect = av.error.InvalidDataError(-2, "error") stream_worker(STREAM_SOURCE, {}, stream.outputs, threading.Event()) @@ -483,7 +482,7 @@ async def test_stream_stopped_while_decoding(hass): worker_wake = threading.Event() stream = Stream(hass, STREAM_SOURCE) - stream.add_provider(STREAM_OUTPUT_FORMAT) + stream.hls_output() py_av = MockPyAv() py_av.container.packets = PacketSequence(TEST_SEQUENCE_LENGTH) @@ -510,7 +509,7 @@ async def test_update_stream_source(hass): worker_wake = threading.Event() stream = Stream(hass, STREAM_SOURCE) - stream.add_provider(STREAM_OUTPUT_FORMAT) + stream.hls_output() # Note that keepalive is not set here. The stream is "restarted" even though # it is not stopping due to failure.