diff --git a/homeassistant/components/stream/__init__.py b/homeassistant/components/stream/__init__.py index c7ca853c20c..1d3a46d0273 100644 --- a/homeassistant/components/stream/__init__.py +++ b/homeassistant/components/stream/__init__.py @@ -25,24 +25,33 @@ import time from types import MappingProxyType from typing import cast +import voluptuous as vol + from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import Event, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError +import homeassistant.helpers.config_validation as cv from homeassistant.helpers.typing import ConfigType from .const import ( ATTR_ENDPOINTS, + ATTR_SETTINGS, ATTR_STREAMS, + CONF_LL_HLS, + CONF_PART_DURATION, + CONF_SEGMENT_DURATION, DOMAIN, HLS_PROVIDER, MAX_SEGMENTS, OUTPUT_IDLE_TIMEOUT, RECORDER_PROVIDER, + SEGMENT_DURATION_ADJUSTER, STREAM_RESTART_INCREMENT, STREAM_RESTART_RESET_TIME, + TARGET_SEGMENT_DURATION_NON_LL_HLS, ) -from .core import PROVIDERS, IdleTimer, StreamOutput -from .hls import async_setup_hls +from .core import PROVIDERS, IdleTimer, StreamOutput, StreamSettings +from .hls import HlsStreamOutput, async_setup_hls _LOGGER = logging.getLogger(__name__) @@ -78,6 +87,24 @@ def create_stream( return stream +CONFIG_SCHEMA = vol.Schema( + { + DOMAIN: vol.Schema( + { + vol.Optional(CONF_LL_HLS, default=False): cv.boolean, + vol.Optional(CONF_SEGMENT_DURATION, default=6): vol.All( + cv.positive_float, vol.Range(min=2, max=10) + ), + vol.Optional(CONF_PART_DURATION, default=1): vol.All( + cv.positive_float, vol.Range(min=0.2, max=1.5) + ), + } + ) + }, + extra=vol.ALLOW_EXTRA, +) + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up stream.""" # Set log level to error for libav @@ -91,6 +118,26 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: hass.data[DOMAIN] = {} hass.data[DOMAIN][ATTR_ENDPOINTS] = {} hass.data[DOMAIN][ATTR_STREAMS] = [] + if (conf := config.get(DOMAIN)) and conf[CONF_LL_HLS]: + assert isinstance(conf[CONF_SEGMENT_DURATION], float) + assert isinstance(conf[CONF_PART_DURATION], float) + hass.data[DOMAIN][ATTR_SETTINGS] = StreamSettings( + ll_hls=True, + min_segment_duration=conf[CONF_SEGMENT_DURATION] + - SEGMENT_DURATION_ADJUSTER, + part_target_duration=conf[CONF_PART_DURATION], + hls_advance_part_limit=max(int(3 / conf[CONF_PART_DURATION]), 3), + hls_part_timeout=2 * conf[CONF_PART_DURATION], + ) + else: + hass.data[DOMAIN][ATTR_SETTINGS] = StreamSettings( + ll_hls=False, + min_segment_duration=TARGET_SEGMENT_DURATION_NON_LL_HLS + - SEGMENT_DURATION_ADJUSTER, + part_target_duration=TARGET_SEGMENT_DURATION_NON_LL_HLS, + hls_advance_part_limit=3, + hls_part_timeout=TARGET_SEGMENT_DURATION_NON_LL_HLS, + ) # Setup HLS hls_endpoint = async_setup_hls(hass) @@ -206,11 +253,16 @@ class Stream: # pylint: disable=import-outside-toplevel from .worker import SegmentBuffer, stream_worker - segment_buffer = SegmentBuffer(self.outputs) + segment_buffer = SegmentBuffer(self.hass, self.outputs) wait_timeout = 0 while not self._thread_quit.wait(timeout=wait_timeout): start_time = time.time() - stream_worker(self.source, self.options, segment_buffer, self._thread_quit) + stream_worker( + self.source, + self.options, + segment_buffer, + self._thread_quit, + ) segment_buffer.discontinuity() if not self.keepalive or self._thread_quit.is_set(): if self._fast_restart_once: @@ -288,7 +340,7 @@ class Stream: _LOGGER.debug("Started a stream recording of %s seconds", duration) # Take advantage of lookback - hls = self.outputs().get(HLS_PROVIDER) + hls: HlsStreamOutput = cast(HlsStreamOutput, self.outputs().get(HLS_PROVIDER)) if lookback > 0 and hls: num_segments = min(int(lookback // hls.target_duration), MAX_SEGMENTS) # Wait for latest segment, then add the lookback diff --git a/homeassistant/components/stream/const.py b/homeassistant/components/stream/const.py index cf4a80d9705..50ae43df0d0 100644 --- a/homeassistant/components/stream/const.py +++ b/homeassistant/components/stream/const.py @@ -2,6 +2,7 @@ DOMAIN = "stream" ATTR_ENDPOINTS = "endpoints" +ATTR_SETTINGS = "settings" ATTR_STREAMS = "streams" HLS_PROVIDER = "hls" @@ -19,16 +20,15 @@ OUTPUT_IDLE_TIMEOUT = 300 # Idle timeout due to inactivity NUM_PLAYLIST_SEGMENTS = 3 # Number of segments to use in HLS playlist MAX_SEGMENTS = 5 # Max number of segments to keep around -TARGET_SEGMENT_DURATION = 2.0 # Each segment is about this many seconds -TARGET_PART_DURATION = 1.0 +TARGET_SEGMENT_DURATION_NON_LL_HLS = 2.0 # Each segment is about this many seconds SEGMENT_DURATION_ADJUSTER = 0.1 # Used to avoid missing keyframe boundaries -# Each segment is at least this many seconds -MIN_SEGMENT_DURATION = TARGET_SEGMENT_DURATION - SEGMENT_DURATION_ADJUSTER - # Number of target durations to start before the end of the playlist. # 1.5 should put us in the middle of the second to last segment even with # variable keyframe intervals. -EXT_X_START = 1.5 +EXT_X_START_NON_LL_HLS = 1.5 +# Number of part durations to start before the end of the playlist with LL-HLS +EXT_X_START_LL_HLS = 2 + PACKETS_TO_WAIT_FOR_AUDIO = 20 # Some streams have an audio stream with no audio MAX_TIMESTAMP_GAP = 10000 # seconds - anything from 10 to 50000 is probably reasonable @@ -38,3 +38,7 @@ SOURCE_TIMEOUT = 30 # Timeout for reading stream source STREAM_RESTART_INCREMENT = 10 # Increase wait_timeout by this amount each retry STREAM_RESTART_RESET_TIME = 300 # Reset wait_timeout after this many seconds + +CONF_LL_HLS = "ll_hls" +CONF_PART_DURATION = "part_duration" +CONF_SEGMENT_DURATION = "segment_duration" diff --git a/homeassistant/components/stream/core.py b/homeassistant/components/stream/core.py index d840bfaf858..77e41511b92 100644 --- a/homeassistant/components/stream/core.py +++ b/homeassistant/components/stream/core.py @@ -3,10 +3,13 @@ from __future__ import annotations import asyncio from collections import deque +from collections.abc import Generator, Iterable import datetime +import itertools from typing import TYPE_CHECKING from aiohttp import web +import async_timeout import attr from homeassistant.components.http.view import HomeAssistantView @@ -14,7 +17,7 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers.event import async_call_later from homeassistant.util.decorator import Registry -from .const import ATTR_STREAMS, DOMAIN, TARGET_SEGMENT_DURATION +from .const import ATTR_STREAMS, DOMAIN if TYPE_CHECKING: from . import Stream @@ -22,6 +25,17 @@ if TYPE_CHECKING: PROVIDERS = Registry() +@attr.s(slots=True) +class StreamSettings: + """Stream settings.""" + + ll_hls: bool = attr.ib() + min_segment_duration: float = attr.ib() + part_target_duration: float = attr.ib() + hls_advance_part_limit: int = attr.ib() + hls_part_timeout: float = attr.ib() + + @attr.s(slots=True) class Part: """Represent a segment part.""" @@ -36,23 +50,170 @@ class Part: class Segment: """Represent a segment.""" - sequence: int = attr.ib(default=0) + sequence: int = attr.ib() # the init of the mp4 the segment is based on - init: bytes = attr.ib(default=None) - duration: float = attr.ib(default=0) + init: bytes = attr.ib() # For detecting discontinuities across stream restarts - stream_id: int = attr.ib(default=0) - parts: list[Part] = attr.ib(factory=list) - start_time: datetime.datetime = attr.ib(factory=datetime.datetime.utcnow) + stream_id: int = attr.ib() + start_time: datetime.datetime = attr.ib() + _stream_outputs: Iterable[StreamOutput] = attr.ib() + duration: float = attr.ib(default=0) + # Parts are stored in a dict indexed by byterange for easy lookup + # As of Python 3.7, insertion order is preserved, and we insert + # in sequential order, so the Parts are ordered + parts_by_byterange: dict[int, Part] = attr.ib(factory=dict) + # Store text of this segment's hls playlist for reuse + # Use list[str] for easy appends + hls_playlist_template: list[str] = attr.ib(factory=list) + hls_playlist_parts: list[str] = attr.ib(factory=list) + # Number of playlist parts rendered so far + hls_num_parts_rendered: int = attr.ib(default=0) + # Set to true when all the parts are rendered + hls_playlist_complete: bool = attr.ib(default=False) + + def __attrs_post_init__(self) -> None: + """Run after init.""" + for output in self._stream_outputs: + output.put(self) @property def complete(self) -> bool: """Return whether the Segment is complete.""" return self.duration > 0 - def get_bytes_without_init(self) -> bytes: + @property + def data_size_with_init(self) -> int: + """Return the size of all part data + init in bytes.""" + return len(self.init) + self.data_size + + @property + def data_size(self) -> int: + """Return the size of all part data without init in bytes.""" + # We can use the last part to quickly calculate the total data size. + if not self.parts_by_byterange: + return 0 + last_http_range_start, last_part = next( + reversed(self.parts_by_byterange.items()) + ) + return last_http_range_start + len(last_part.data) + + @callback + def async_add_part( + self, + part: Part, + duration: float, + ) -> None: + """Add a part to the Segment. + + Duration is non zero only for the last part. + """ + self.parts_by_byterange[self.data_size] = part + self.duration = duration + for output in self._stream_outputs: + output.part_put() + + def get_data(self) -> bytes: """Return reconstructed data for all parts as bytes, without init.""" - return b"".join([part.data for part in self.parts]) + return b"".join([part.data for part in self.parts_by_byterange.values()]) + + def get_aggregating_bytes( + self, start_loc: int, end_loc: int | float + ) -> Generator[bytes, None, None]: + """Yield available remaining data until segment is complete or end_loc is reached. + + Begin at start_loc. End at end_loc (exclusive). + Used to help serve a range request on a segment. + """ + pos = start_loc + while (part := self.parts_by_byterange.get(pos)) or not self.complete: + if not part: + yield b"" + continue + pos += len(part.data) + # Check stopping condition and trim output if necessary + if pos >= end_loc: + assert isinstance(end_loc, int) + # Trimming is probably not necessary, but it doesn't hurt + yield part.data[: len(part.data) + end_loc - pos] + return + yield part.data + + def _render_hls_template(self, last_stream_id: int, render_parts: bool) -> str: + """Render the HLS playlist section for the Segment. + + The Segment may still be in progress. + This method stores intermediate data in hls_playlist_parts, hls_num_parts_rendered, + and hls_playlist_complete to avoid redoing work on subsequent calls. + """ + if self.hls_playlist_complete: + return self.hls_playlist_template[0] + if not self.hls_playlist_template: + # This is a placeholder where the rendered parts will be inserted + self.hls_playlist_template.append("{}") + if render_parts: + for http_range_start, part in itertools.islice( + self.parts_by_byterange.items(), + self.hls_num_parts_rendered, + None, + ): + self.hls_playlist_parts.append( + f"#EXT-X-PART:DURATION={part.duration:.3f},URI=" + f'"./segment/{self.sequence}.m4s",BYTERANGE="{len(part.data)}' + f'@{http_range_start}"{",INDEPENDENT=YES" if part.has_keyframe else ""}' + ) + if self.complete: + # Construct the final playlist_template. The placeholder will share a line with + # the first element to avoid an extra newline when we don't render any parts. + # Append an empty string to create a trailing newline when we do render parts + self.hls_playlist_parts.append("") + self.hls_playlist_template = [] + # Logically EXT-X-DISCONTINUITY would make sense above the parts, but Apple's + # media stream validator seems to only want it before the segment + if last_stream_id != self.stream_id: + self.hls_playlist_template.append("#EXT-X-DISCONTINUITY") + # Add the remaining segment metadata + self.hls_playlist_template.extend( + [ + "#EXT-X-PROGRAM-DATE-TIME:" + + self.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + + "Z", + f"#EXTINF:{self.duration:.3f},\n./segment/{self.sequence}.m4s", + ] + ) + # The placeholder now goes on the same line as the first element + self.hls_playlist_template[0] = "{}" + self.hls_playlist_template[0] + + # Store intermediate playlist data in member variables for reuse + self.hls_playlist_template = ["\n".join(self.hls_playlist_template)] + # lstrip discards extra preceding newline in case first render was empty + self.hls_playlist_parts = ["\n".join(self.hls_playlist_parts).lstrip()] + self.hls_num_parts_rendered = len(self.parts_by_byterange) + self.hls_playlist_complete = self.complete + + return self.hls_playlist_template[0] + + def render_hls( + self, last_stream_id: int, render_parts: bool, add_hint: bool + ) -> str: + """Render the HLS playlist section for the Segment including a hint if requested.""" + playlist_template = self._render_hls_template(last_stream_id, render_parts) + playlist = playlist_template.format( + self.hls_playlist_parts[0] if render_parts else "" + ) + if not add_hint: + return playlist + # Preload hints help save round trips by informing the client about the next part. + # The next part will usually be in this segment but will be first part of the next + # segment if this segment is already complete. + # pylint: disable=undefined-loop-variable + if self.complete: # Next part belongs to next segment + sequence = self.sequence + 1 + start = 0 + else: # Next part is in the same segment + sequence = self.sequence + start = self.data_size + hint = f'#EXT-X-PRELOAD-HINT:TYPE=PART,URI="./segment/{sequence}.m4s",BYTERANGE-START={start}' + return (playlist + "\n" + hint) if playlist else hint class IdleTimer: @@ -110,6 +271,7 @@ class StreamOutput: self._hass = hass self.idle_timer = idle_timer self._event = asyncio.Event() + self._part_event = asyncio.Event() self._segments: deque[Segment] = deque(maxlen=deque_maxlen) @property @@ -141,13 +303,6 @@ class StreamOutput: return self._segments[-1] return None - @property - def target_duration(self) -> float: - """Return the max duration of any given segment in seconds.""" - if not (durations := [s.duration for s in self._segments if s.complete]): - return TARGET_SEGMENT_DURATION - return max(durations) - def get_segment(self, sequence: int) -> Segment | None: """Retrieve a specific segment.""" # Most hits will come in the most recent segments, so iterate reversed @@ -160,8 +315,23 @@ class StreamOutput: """Retrieve all segments.""" return self._segments + async def part_recv(self, timeout: float | None = None) -> bool: + """Wait for an event signalling the latest part segment.""" + try: + async with async_timeout.timeout(timeout): + await self._part_event.wait() + except asyncio.TimeoutError: + return False + return True + + def part_put(self) -> None: + """Set event signalling the latest part segment.""" + # Start idle timeout when we start receiving data + self._part_event.set() + self._part_event.clear() + async def recv(self) -> bool: - """Wait for and retrieve the latest segment.""" + """Wait for the latest segment.""" await self._event.wait() return self.last_segment is not None diff --git a/homeassistant/components/stream/hls.py b/homeassistant/components/stream/hls.py index 7f11bc09655..9b154e9236b 100644 --- a/homeassistant/components/stream/hls.py +++ b/homeassistant/components/stream/hls.py @@ -1,25 +1,31 @@ """Provide functionality to stream HLS.""" from __future__ import annotations -from typing import TYPE_CHECKING +import logging +from typing import TYPE_CHECKING, cast from aiohttp import web from homeassistant.core import HomeAssistant, callback from .const import ( - EXT_X_START, + ATTR_SETTINGS, + DOMAIN, + EXT_X_START_LL_HLS, + EXT_X_START_NON_LL_HLS, FORMAT_CONTENT_TYPE, HLS_PROVIDER, MAX_SEGMENTS, NUM_PLAYLIST_SEGMENTS, ) -from .core import PROVIDERS, IdleTimer, StreamOutput, StreamView +from .core import PROVIDERS, IdleTimer, StreamOutput, StreamSettings, StreamView from .fmp4utils import get_codec_string if TYPE_CHECKING: from . import Stream +_LOGGER = logging.getLogger(__name__) + @callback def async_setup_hls(hass: HomeAssistant) -> str: @@ -31,6 +37,38 @@ def async_setup_hls(hass: HomeAssistant) -> str: return "/api/hls/{}/master_playlist.m3u8" +@PROVIDERS.register(HLS_PROVIDER) +class HlsStreamOutput(StreamOutput): + """Represents HLS Output formats.""" + + def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None: + """Initialize HLS output.""" + super().__init__(hass, idle_timer, deque_maxlen=MAX_SEGMENTS) + self.stream_settings: StreamSettings = hass.data[DOMAIN][ATTR_SETTINGS] + self._target_duration = 0.0 + + @property + def name(self) -> str: + """Return provider name.""" + return HLS_PROVIDER + + @property + def target_duration(self) -> float: + """ + Return the target duration. + + The target duration is calculated as the max duration of any given segment, + and it is calculated only one time to avoid changing during playback. + """ + if self._target_duration: + return self._target_duration + durations = [s.duration for s in self._segments if s.complete] + if len(durations) < 2: + return self.stream_settings.min_segment_duration + self._target_duration = max(durations) + return self._target_duration + + class HlsMasterPlaylistView(StreamView): """Stream view used only for Chromecast compatibility.""" @@ -46,12 +84,7 @@ class HlsMasterPlaylistView(StreamView): # hls spec already allows for 25% variation if not (segment := track.get_segment(track.sequences[-2])): return "" - bandwidth = round( - (len(segment.init) + sum(len(part.data) for part in segment.parts)) - * 8 - / segment.duration - * 1.2 - ) + bandwidth = round(segment.data_size_with_init * 8 / segment.duration * 1.2) codecs = get_codec_string(segment.init) lines = [ "#EXTM3U", @@ -71,8 +104,14 @@ class HlsMasterPlaylistView(StreamView): return web.HTTPNotFound() if len(track.sequences) == 1 and not await track.recv(): return web.HTTPNotFound() - headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]} - return web.Response(body=self.render(track).encode("utf-8"), headers=headers) + response = web.Response( + body=self.render(track).encode("utf-8"), + headers={ + "Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER], + }, + ) + response.enable_compression(web.ContentCoding.gzip) + return response class HlsPlaylistView(StreamView): @@ -82,9 +121,9 @@ class HlsPlaylistView(StreamView): name = "api:stream:hls:playlist" cors_allowed = True - @staticmethod - def render(track: StreamOutput) -> str: - """Render playlist.""" + @classmethod + def render(cls, track: HlsStreamOutput) -> str: + """Render HLS playlist file.""" # NUM_PLAYLIST_SEGMENTS+1 because most recent is probably not yet complete segments = list(track.get_segments())[-(NUM_PLAYLIST_SEGMENTS + 1) :] @@ -102,9 +141,17 @@ class HlsPlaylistView(StreamView): f"#EXT-X-TARGETDURATION:{track.target_duration:.0f}", f"#EXT-X-MEDIA-SEQUENCE:{first_segment.sequence}", f"#EXT-X-DISCONTINUITY-SEQUENCE:{first_segment.stream_id}", - "#EXT-X-PROGRAM-DATE-TIME:" - + first_segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] - + "Z", + ] + + if track.stream_settings.ll_hls: + playlist.extend( + [ + f"#EXT-X-PART-INF:PART-TARGET={track.stream_settings.part_target_duration:.3f}", + f"#EXT-X-SERVER-CONTROL:CAN-BLOCK-RELOAD=YES,PART-HOLD-BACK={2*track.stream_settings.part_target_duration:.3f}", + f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START_LL_HLS*track.stream_settings.part_target_duration:.3f},PRECISE=YES", + ] + ) + else: # Since our window doesn't have many segments, we don't want to start # at the beginning or we risk a behind live window exception in Exoplayer. # EXT-X-START is not supposed to be within 3 target durations of the end, @@ -113,47 +160,147 @@ class HlsPlaylistView(StreamView): # don't autoplay. Also, hls.js uses the player parameter liveSyncDuration # which seems to take precedence for setting target delay. Yet it also # doesn't seem to hurt, so we can stick with it for now. - f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START * track.target_duration:.3f}", - ] + playlist.append( + f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START_NON_LL_HLS*track.target_duration:.3f},PRECISE=YES" + ) last_stream_id = first_segment.stream_id - # Add playlist sections - for segment in segments: - # Skip last segment if it is not complete - if segment.complete: - if last_stream_id != segment.stream_id: - playlist.extend( - [ - "#EXT-X-DISCONTINUITY", - "#EXT-X-PROGRAM-DATE-TIME:" - + segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] - + "Z", - ] - ) - playlist.extend( - [ - f"#EXTINF:{segment.duration:.3f},", - f"./segment/{segment.sequence}.m4s", - ] + + # Add playlist sections for completed segments + # Enumeration used to only include EXT-X-PART data for last 3 segments. + # The RFC seems to suggest removing parts after 3 full segments, but Apple's + # own example shows removing after 2 full segments and 1 part one. + for i, segment in enumerate(segments[:-1], 3 - len(segments)): + playlist.append( + segment.render_hls( + last_stream_id=last_stream_id, + render_parts=i >= 0 and track.stream_settings.ll_hls, + add_hint=False, ) - last_stream_id = segment.stream_id + ) + last_stream_id = segment.stream_id + + playlist.append( + segments[-1].render_hls( + last_stream_id=last_stream_id, + render_parts=track.stream_settings.ll_hls, + add_hint=track.stream_settings.ll_hls, + ) + ) return "\n".join(playlist) + "\n" + @staticmethod + def bad_request(blocking: bool, target_duration: float) -> web.Response: + """Return a HTTP Bad Request response.""" + return web.Response( + body=None, + status=400, + # From Appendix B.1 of the RFC: + # Successful responses to blocking Playlist requests should be cached + # for six Target Durations. Unsuccessful responses (such as 404s) should + # be cached for four Target Durations. Successful responses to non-blocking + # Playlist requests should be cached for half the Target Duration. + # Unsuccessful responses to non-blocking Playlist requests should be + # cached for for one Target Duration. + headers={ + "Cache-Control": f"max-age={(4 if blocking else 1)*target_duration:.0f}" + }, + ) + + @staticmethod + def not_found(blocking: bool, target_duration: float) -> web.Response: + """Return a HTTP Not Found response.""" + return web.Response( + body=None, + status=404, + headers={ + "Cache-Control": f"max-age={(4 if blocking else 1)*target_duration:.0f}" + }, + ) + async def handle( self, request: web.Request, stream: Stream, sequence: str ) -> web.Response: """Return m3u8 playlist.""" - track = stream.add_provider(HLS_PROVIDER) + track: HlsStreamOutput = cast( + HlsStreamOutput, stream.add_provider(HLS_PROVIDER) + ) stream.start() - # Make sure at least two segments are ready (last one may not be complete) - if not track.sequences and not await track.recv(): - return web.HTTPNotFound() - if len(track.sequences) == 1 and not await track.recv(): - return web.HTTPNotFound() - headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]} + + hls_msn: str | int | None = request.query.get("_HLS_msn") + hls_part: str | int | None = request.query.get("_HLS_part") + blocking_request = bool(hls_msn or hls_part) + + # If the Playlist URI contains an _HLS_part directive but no _HLS_msn + # directive, the Server MUST return Bad Request, such as HTTP 400. + if hls_msn is None and hls_part: + return web.HTTPBadRequest() + + hls_msn = int(hls_msn or 0) + + # If the _HLS_msn is greater than the Media Sequence Number of the last + # Media Segment in the current Playlist plus two, or if the _HLS_part + # exceeds the last Part Segment in the current Playlist by the + # Advance Part Limit, then the server SHOULD immediately return Bad + # Request, such as HTTP 400. + if hls_msn > track.last_sequence + 2: + return self.bad_request(blocking_request, track.target_duration) + + if hls_part is None: + # We need to wait for the whole segment, so effectively the next msn + hls_part = -1 + hls_msn += 1 + else: + hls_part = int(hls_part) + + while hls_msn > track.last_sequence: + if not await track.recv(): + return self.not_found(blocking_request, track.target_duration) + if track.last_segment is None: + return self.not_found(blocking_request, 0) + if ( + (last_segment := track.last_segment) + and hls_msn == last_segment.sequence + and hls_part + >= len(last_segment.parts_by_byterange) + - 1 + + track.stream_settings.hls_advance_part_limit + ): + return self.bad_request(blocking_request, track.target_duration) + + # Receive parts until msn and part are met + while ( + (last_segment := track.last_segment) + and hls_msn == last_segment.sequence + and hls_part >= len(last_segment.parts_by_byterange) + ): + if not await track.part_recv( + timeout=track.stream_settings.hls_part_timeout + ): + return self.not_found(blocking_request, track.target_duration) + # Now we should have msn.part >= hls_msn.hls_part. However, in the case + # that we have a rollover part request from the previous segment, we need + # to make sure that the new segment has a part. From 6.2.5.2 of the RFC: + # If the Client requests a Part Index greater than that of the final + # Partial Segment of the Parent Segment, the Server MUST treat the + # request as one for Part Index 0 of the following Parent Segment. + if hls_msn + 1 == last_segment.sequence: + if not (previous_segment := track.get_segment(hls_msn)) or ( + hls_part >= len(previous_segment.parts_by_byterange) + and not last_segment.parts_by_byterange + and not await track.part_recv( + timeout=track.stream_settings.hls_part_timeout + ) + ): + return self.not_found(blocking_request, track.target_duration) + response = web.Response( - body=self.render(track).encode("utf-8"), headers=headers + body=self.render(track).encode("utf-8"), + headers={ + "Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER], + "Cache-Control": f"max-age={(6 if blocking_request else 0.5)*track.target_duration:.0f}", + }, ) response.enable_compression(web.ContentCoding.gzip) return response @@ -171,10 +318,11 @@ class HlsInitView(StreamView): ) -> web.Response: """Return init.mp4.""" track = stream.add_provider(HLS_PROVIDER) - if not (segments := track.get_segments()): + if not (segments := track.get_segments()) or not (body := segments[0].init): return web.HTTPNotFound() return web.Response( - body=segments[0].init, headers={"Content-Type": "video/mp4"} + body=body, + headers={"Content-Type": "video/mp4"}, ) @@ -187,28 +335,102 @@ class HlsSegmentView(StreamView): async def handle( self, request: web.Request, stream: Stream, sequence: str - ) -> web.Response: - """Return fmp4 segment.""" - track = stream.add_provider(HLS_PROVIDER) - track.idle_timer.awake() - if not (segment := track.get_segment(int(sequence))): - return web.HTTPNotFound() - headers = {"Content-Type": "video/iso.segment"} - return web.Response( - body=segment.get_bytes_without_init(), - headers=headers, + ) -> web.StreamResponse: + """Handle segments, part segments, and hinted segments. + + For part and hinted segments, the start of the requested range must align + with a part boundary. + """ + track: HlsStreamOutput = cast( + HlsStreamOutput, stream.add_provider(HLS_PROVIDER) ) - - -@PROVIDERS.register(HLS_PROVIDER) -class HlsStreamOutput(StreamOutput): - """Represents HLS Output formats.""" - - def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None: - """Initialize recorder output.""" - super().__init__(hass, idle_timer, deque_maxlen=MAX_SEGMENTS) - - @property - def name(self) -> str: - """Return provider name.""" - return HLS_PROVIDER + track.idle_timer.awake() + # Ensure that we have a segment. If the request is from a hint for part 0 + # of a segment, there is a small chance it may have arrived before the + # segment has been put. If this happens, wait for one part and retry. + if not ( + (segment := track.get_segment(int(sequence))) + or ( + await track.part_recv(timeout=track.stream_settings.hls_part_timeout) + and (segment := track.get_segment(int(sequence))) + ) + ): + return web.Response( + body=None, + status=404, + headers={"Cache-Control": f"max-age={track.target_duration:.0f}"}, + ) + # If the segment is ready or has been hinted, the http_range start should be at most + # equal to the end of the currently available data. + # If the segment is complete, the http_range start should be less than the end of the + # currently available data. + # If these conditions aren't met then we return a 416. + # http_range_start can be None, so use a copy that uses 0 instead of None + if (http_start := request.http_range.start or 0) > segment.data_size or ( + segment.complete and http_start >= segment.data_size + ): + return web.HTTPRequestRangeNotSatisfiable( + headers={ + "Cache-Control": f"max-age={track.target_duration:.0f}", + "Content-Range": f"bytes */{segment.data_size}", + } + ) + headers = { + "Content-Type": "video/iso.segment", + "Cache-Control": f"max-age={6*track.target_duration:.0f}", + } + # For most cases we have a 206 partial content response. + status = 206 + # For the 206 responses we need to set a Content-Range header + # See https://datatracker.ietf.org/doc/html/rfc8673#section-2 + if request.http_range.stop is None: + if request.http_range.start is None: + status = 200 + if segment.complete: + # This is a request for a full segment which is already complete + # We should return a standard 200 response. + return web.Response( + body=segment.get_data(), headers=headers, status=status + ) + # Otherwise we still return a 200 response, but it is aggregating + http_stop = float("inf") + else: + # See https://datatracker.ietf.org/doc/html/rfc7233#section-2.1 + headers[ + "Content-Range" + ] = f"bytes {http_start}-{(http_stop:=segment.data_size)-1}/*" + else: # The remaining cases are all 206 responses + if segment.complete: + # If the segment is complete we have total size + headers["Content-Range"] = ( + f"bytes {http_start}-" + + str( + (http_stop := min(request.http_range.stop, segment.data_size)) + - 1 + ) + + f"/{segment.data_size}" + ) + else: + # If we don't have the total size we use a * + headers[ + "Content-Range" + ] = f"bytes {http_start}-{(http_stop:=request.http_range.stop)-1}/*" + # Set up streaming response that we can write to as data becomes available + response = web.StreamResponse(headers=headers, status=status) + # Waiting until we write to prepare *might* give clients more accurate TTFB + # and ABR measurements, but it is probably not very useful for us since we + # only have one rendition anyway. Just prepare here for now. + await response.prepare(request) + try: + for bytes_to_write in segment.get_aggregating_bytes( + start_loc=http_start, end_loc=http_stop + ): + if bytes_to_write: + await response.write(bytes_to_write) + elif not await track.part_recv( + timeout=track.stream_settings.hls_part_timeout + ): + break + except ConnectionResetError: + _LOGGER.warning("Connection reset while serving HLS partial segment") + return response diff --git a/homeassistant/components/stream/recorder.py b/homeassistant/components/stream/recorder.py index 99276d9763c..2fa612e631c 100644 --- a/homeassistant/components/stream/recorder.py +++ b/homeassistant/components/stream/recorder.py @@ -57,7 +57,7 @@ def recorder_save_worker(file_out: str, segments: deque[Segment]) -> None: # Open segment source = av.open( - BytesIO(segment.init + segment.get_bytes_without_init()), + BytesIO(segment.init + segment.get_data()), "r", format=SEGMENT_CONTAINER_FORMAT, ) diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index 039163c6cf5..314e4f33e80 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections import defaultdict, deque from collections.abc import Generator, Iterator, Mapping +import datetime from io import BytesIO import logging from threading import Event @@ -10,18 +11,20 @@ from typing import Any, Callable, cast import av +from homeassistant.core import HomeAssistant + from . import redact_credentials from .const import ( + ATTR_SETTINGS, AUDIO_CODECS, + DOMAIN, MAX_MISSING_DTS, MAX_TIMESTAMP_GAP, - MIN_SEGMENT_DURATION, PACKETS_TO_WAIT_FOR_AUDIO, SEGMENT_CONTAINER_FORMAT, SOURCE_TIMEOUT, - TARGET_PART_DURATION, ) -from .core import Part, Segment, StreamOutput +from .core import Part, Segment, StreamOutput, StreamSettings _LOGGER = logging.getLogger(__name__) @@ -30,10 +33,13 @@ class SegmentBuffer: """Buffer for writing a sequence of packets to the output as a segment.""" def __init__( - self, outputs_callback: Callable[[], Mapping[str, StreamOutput]] + self, + hass: HomeAssistant, + outputs_callback: Callable[[], Mapping[str, StreamOutput]], ) -> None: """Initialize SegmentBuffer.""" self._stream_id: int = 0 + self._hass = hass self._outputs_callback: Callable[ [], Mapping[str, StreamOutput] ] = outputs_callback @@ -52,10 +58,14 @@ class SegmentBuffer: self._memory_file_pos: int = cast(int, None) self._part_start_dts: int = cast(int, None) self._part_has_keyframe = False + self._stream_settings: StreamSettings = hass.data[DOMAIN][ATTR_SETTINGS] + self._start_time = datetime.datetime.utcnow() - @staticmethod def make_new_av( - memory_file: BytesIO, sequence: int, input_vstream: av.video.VideoStream + self, + memory_file: BytesIO, + sequence: int, + input_vstream: av.video.VideoStream, ) -> av.container.OutputContainer: """Make a new av OutputContainer.""" return av.open( @@ -63,19 +73,38 @@ class SegmentBuffer: mode="w", format=SEGMENT_CONTAINER_FORMAT, container_options={ - # Removed skip_sidx - see https://github.com/home-assistant/core/pull/39970 - # "cmaf" flag replaces several of the movflags used, but too recent to use for now - "movflags": "empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer", - # Sometimes the first segment begins with negative timestamps, and this setting just - # adjusts the timestamps in the output from that segment to start from 0. Helps from - # having to make some adjustments in test_durations - "avoid_negative_ts": "make_non_negative", - "fragment_index": str(sequence + 1), - "video_track_timescale": str(int(1 / input_vstream.time_base)), - # Create a fragments every TARGET_PART_DURATION. The data from each fragment is stored in - # a "Part" that can be combined with the data from all the other "Part"s, plus an init - # section, to reconstitute the data in a "Segment". - "frag_duration": str(int(TARGET_PART_DURATION * 1e6)), + **{ + # Removed skip_sidx - see https://github.com/home-assistant/core/pull/39970 + # "cmaf" flag replaces several of the movflags used, but too recent to use for now + "movflags": "frag_custom+empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer", + # Sometimes the first segment begins with negative timestamps, and this setting just + # adjusts the timestamps in the output from that segment to start from 0. Helps from + # having to make some adjustments in test_durations + "avoid_negative_ts": "make_non_negative", + "fragment_index": str(sequence + 1), + "video_track_timescale": str(int(1 / input_vstream.time_base)), + }, + # Only do extra fragmenting if we are using ll_hls + # Let ffmpeg do the work using frag_duration + # Fragment durations may exceed the 15% allowed variance but it seems ok + **( + { + "movflags": "empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer", + # Create a fragment every TARGET_PART_DURATION. The data from each fragment is stored in + # a "Part" that can be combined with the data from all the other "Part"s, plus an init + # section, to reconstitute the data in a "Segment". + # frag_duration seems to be a minimum threshold for determining part boundaries, so some + # parts may have a higher duration. Since Part Target Duration is used in LL-HLS as a + # maximum threshold for part durations, we scale that number down here by .85 and hope + # that the output part durations stay below the maximum Part Target Duration threshold. + # See https://datatracker.ietf.org/doc/html/draft-pantos-hls-rfc8216bis#section-4.4.4.9 + "frag_duration": str( + self._stream_settings.part_target_duration * 1e6 + ), + } + if self._stream_settings.ll_hls + else {} + ), }, ) @@ -120,7 +149,7 @@ class SegmentBuffer: if ( packet.is_keyframe and (packet.dts - self._segment_start_dts) * packet.time_base - >= MIN_SEGMENT_DURATION + >= self._stream_settings.min_segment_duration ): # Flush segment (also flushes the stub part segment) self.flush(packet, last_part=True) @@ -148,13 +177,16 @@ class SegmentBuffer: sequence=self._sequence, stream_id=self._stream_id, init=self._memory_file.getvalue(), + # Fetch the latest StreamOutputs, which may have changed since the + # worker started. + stream_outputs=self._outputs_callback().values(), + start_time=self._start_time + + datetime.timedelta( + seconds=float(self._segment_start_dts * packet.time_base) + ), ) self._memory_file_pos = self._memory_file.tell() self._part_start_dts = self._segment_start_dts - # Fetch the latest StreamOutputs, which may have changed since the - # worker started. - for stream_output in self._outputs_callback().values(): - stream_output.put(self._segment) else: # These are the ends of the part segments self.flush(packet, last_part=False) @@ -164,27 +196,41 @@ class SegmentBuffer: If last_part is True, also close the segment, give it a duration, and clean up the av_output and memory_file. """ + # In some cases using the current packet's dts (which is the start + # dts of the next part) to calculate the part duration will result in a + # value which exceeds the part_target_duration. This can muck up the + # duration of both this part and the next part. An easy fix is to just + # use the current packet dts and cap it by the part target duration. + current_dts = min( + packet.dts, + self._part_start_dts + + self._stream_settings.part_target_duration / packet.time_base, + ) if last_part: # Closing the av_output will write the remaining buffered data to the # memory_file as a new moof/mdat. self._av_output.close() assert self._segment self._memory_file.seek(self._memory_file_pos) - self._segment.parts.append( + self._hass.loop.call_soon_threadsafe( + self._segment.async_add_part, Part( - duration=float((packet.dts - self._part_start_dts) * packet.time_base), + duration=float((current_dts - self._part_start_dts) * packet.time_base), has_keyframe=self._part_has_keyframe, data=self._memory_file.read(), - ) + ), + float((current_dts - self._segment_start_dts) * packet.time_base) + if last_part + else 0, ) if last_part: - self._segment.duration = float( - (packet.dts - self._segment_start_dts) * packet.time_base - ) + # If we've written the last part, we can close the memory_file. self._memory_file.close() # We don't need the BytesIO object anymore else: + # For the last part, these will get set again elsewhere so we can skip + # setting them here. self._memory_file_pos = self._memory_file.tell() - self._part_start_dts = packet.dts + self._part_start_dts = current_dts self._part_has_keyframe = False def discontinuity(self) -> None: diff --git a/tests/components/stream/common.py b/tests/components/stream/common.py index a39e8bdca21..19a4d2a9e6f 100644 --- a/tests/components/stream/common.py +++ b/tests/components/stream/common.py @@ -1,10 +1,25 @@ """Collection of test helpers.""" +from datetime import datetime from fractions import Fraction +from functools import partial import io import av import numpy as np +from homeassistant.components.stream.core import Segment + +FAKE_TIME = datetime.utcnow() +# Segment with defaults filled in for use in tests + +DefaultSegment = partial( + Segment, + init=None, + stream_id=0, + start_time=FAKE_TIME, + stream_outputs=[], +) + AUDIO_SAMPLE_RATE = 8000 @@ -22,14 +37,13 @@ def generate_audio_frame(pcm_mulaw=False): return audio_frame -def generate_h264_video(container_format="mp4"): +def generate_h264_video(container_format="mp4", duration=5): """ Generate a test video. See: http://docs.mikeboers.com/pyav/develop/cookbook/numpy.html """ - duration = 5 fps = 24 total_frames = duration * fps diff --git a/tests/components/stream/conftest.py b/tests/components/stream/conftest.py index a73678d763f..746cc05fcbd 100644 --- a/tests/components/stream/conftest.py +++ b/tests/components/stream/conftest.py @@ -17,11 +17,12 @@ import logging import threading from unittest.mock import patch +from aiohttp import web import async_timeout import pytest from homeassistant.components.stream import Stream -from homeassistant.components.stream.core import Segment +from homeassistant.components.stream.core import Segment, StreamOutput TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout @@ -120,3 +121,95 @@ def record_worker_sync(hass): autospec=True, ): yield sync + + +class HLSSync: + """Test fixture that intercepts stream worker calls to StreamOutput.""" + + def __init__(self): + """Initialize HLSSync.""" + self._request_event = asyncio.Event() + self._original_recv = StreamOutput.recv + self._original_part_recv = StreamOutput.part_recv + self._original_bad_request = web.HTTPBadRequest + self._original_not_found = web.HTTPNotFound + self._original_response = web.Response + self._num_requests = 0 + self._num_recvs = 0 + self._num_finished = 0 + + def reset_request_pool(self, num_requests: int, reset_finished=True): + """Use to reset the request counter between segments.""" + self._num_recvs = 0 + if reset_finished: + self._num_finished = 0 + self._num_requests = num_requests + + async def wait_for_handler(self): + """Set up HLSSync to block calls to put until requests are set up.""" + if not self.check_requests_ready(): + await self._request_event.wait() + self.reset_request_pool(num_requests=self._num_requests, reset_finished=False) + + def check_requests_ready(self): + """Unblock the pending put call if the requests are all finished or blocking.""" + if self._num_recvs + self._num_finished == self._num_requests: + self._request_event.set() + self._request_event.clear() + return True + return False + + def bad_request(self): + """Intercept the HTTPBadRequest call so we know when the web handler is finished.""" + self._num_finished += 1 + self.check_requests_ready() + return self._original_bad_request() + + def not_found(self): + """Intercept the HTTPNotFound call so we know when the web handler is finished.""" + self._num_finished += 1 + self.check_requests_ready() + return self._original_not_found() + + def response(self, body, headers, status=200): + """Intercept the Response call so we know when the web handler is finished.""" + self._num_finished += 1 + self.check_requests_ready() + return self._original_response(body=body, headers=headers, status=status) + + async def recv(self, output: StreamOutput, **kw): + """Intercept the recv call so we know when the response is blocking on recv.""" + self._num_recvs += 1 + self.check_requests_ready() + return await self._original_recv(output) + + async def part_recv(self, output: StreamOutput, **kw): + """Intercept the recv call so we know when the response is blocking on recv.""" + self._num_recvs += 1 + self.check_requests_ready() + return await self._original_part_recv(output) + + +@pytest.fixture() +def hls_sync(): + """Patch HLSOutput to allow test to synchronize playlist requests and responses.""" + sync = HLSSync() + with patch( + "homeassistant.components.stream.core.StreamOutput.recv", + side_effect=sync.recv, + autospec=True, + ), patch( + "homeassistant.components.stream.core.StreamOutput.part_recv", + side_effect=sync.part_recv, + autospec=True, + ), patch( + "homeassistant.components.stream.hls.web.HTTPBadRequest", + side_effect=sync.bad_request, + ), patch( + "homeassistant.components.stream.hls.web.HTTPNotFound", + side_effect=sync.not_found, + ), patch( + "homeassistant.components.stream.hls.web.Response", + side_effect=sync.response, + ): + yield sync diff --git a/tests/components/stream/test_hls.py b/tests/components/stream/test_hls.py index 919f71c8509..4b0cb0322ce 100644 --- a/tests/components/stream/test_hls.py +++ b/tests/components/stream/test_hls.py @@ -1,5 +1,5 @@ """The tests for hls streams.""" -from datetime import datetime, timedelta +from datetime import timedelta from unittest.mock import patch from urllib.parse import urlparse @@ -8,17 +8,23 @@ import pytest from homeassistant.components.stream import create_stream from homeassistant.components.stream.const import ( + EXT_X_START_LL_HLS, + EXT_X_START_NON_LL_HLS, HLS_PROVIDER, MAX_SEGMENTS, NUM_PLAYLIST_SEGMENTS, ) -from homeassistant.components.stream.core import Part, Segment +from homeassistant.components.stream.core import Part from homeassistant.const import HTTP_NOT_FOUND from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util from tests.common import async_fire_time_changed -from tests.components.stream.common import generate_h264_video +from tests.components.stream.common import ( + FAKE_TIME, + DefaultSegment as Segment, + generate_h264_video, +) STREAM_SOURCE = "some-stream-source" INIT_BYTES = b"init" @@ -26,7 +32,6 @@ FAKE_PAYLOAD = b"fake-payload" SEGMENT_DURATION = 10 TEST_TIMEOUT = 5.0 # Lower than 9s home assistant timeout MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever -FAKE_TIME = datetime.utcnow() class HlsClient: @@ -37,13 +42,13 @@ class HlsClient: self.http_client = http_client self.parsed_url = parsed_url - async def get(self, path=None): + async def get(self, path=None, headers=None): """Fetch the hls stream for the specified path.""" url = self.parsed_url.path if path: # Strip off the master playlist suffix and replace with path url = "/".join(self.parsed_url.path.split("/")[:-1]) + path - return await self.http_client.get(url) + return await self.http_client.get(url, headers=headers) @pytest.fixture @@ -60,36 +65,52 @@ def hls_stream(hass, hass_client): def make_segment(segment, discontinuity=False): """Create a playlist response for a segment.""" - response = [] - if discontinuity: - response.extend( - [ - "#EXT-X-DISCONTINUITY", - "#EXT-X-PROGRAM-DATE-TIME:" - + FAKE_TIME.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] - + "Z", - ] - ) - response.extend([f"#EXTINF:{SEGMENT_DURATION:.3f},", f"./segment/{segment}.m4s"]) + response = ["#EXT-X-DISCONTINUITY"] if discontinuity else [] + response.extend( + [ + "#EXT-X-PROGRAM-DATE-TIME:" + + FAKE_TIME.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + + "Z", + f"#EXTINF:{SEGMENT_DURATION:.3f},", + f"./segment/{segment}.m4s", + ] + ) return "\n".join(response) -def make_playlist(sequence, segments, discontinuity_sequence=0): +def make_playlist( + sequence, + discontinuity_sequence=0, + segments=None, + hint=None, + part_target_duration=None, +): """Create a an hls playlist response for tests to assert on.""" response = [ "#EXTM3U", "#EXT-X-VERSION:6", "#EXT-X-INDEPENDENT-SEGMENTS", '#EXT-X-MAP:URI="init.mp4"', - "#EXT-X-TARGETDURATION:10", + f"#EXT-X-TARGETDURATION:{SEGMENT_DURATION}", f"#EXT-X-MEDIA-SEQUENCE:{sequence}", f"#EXT-X-DISCONTINUITY-SEQUENCE:{discontinuity_sequence}", - "#EXT-X-PROGRAM-DATE-TIME:" - + FAKE_TIME.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] - + "Z", - f"#EXT-X-START:TIME-OFFSET=-{1.5*SEGMENT_DURATION:.3f}", ] - response.extend(segments) + if hint: + response.extend( + [ + f"#EXT-X-PART-INF:PART-TARGET={part_target_duration:.3f}", + f"#EXT-X-SERVER-CONTROL:CAN-BLOCK-RELOAD=YES,PART-HOLD-BACK={2*part_target_duration:.3f}", + f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START_LL_HLS*part_target_duration:.3f},PRECISE=YES", + ] + ) + else: + response.append( + f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START_NON_LL_HLS*SEGMENT_DURATION:.3f},PRECISE=YES", + ) + if segments: + response.extend(segments) + if hint: + response.append(hint) response.append("") return "\n".join(response) @@ -115,18 +136,23 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync): hls_client = await hls_stream(stream) - # Fetch playlist - playlist_response = await hls_client.get() - assert playlist_response.status == 200 + # Fetch master playlist + master_playlist_response = await hls_client.get() + assert master_playlist_response.status == 200 # Fetch init - playlist = await playlist_response.text() + master_playlist = await master_playlist_response.text() init_response = await hls_client.get("/init.mp4") assert init_response.status == 200 + # Fetch playlist + playlist_url = "/" + master_playlist.splitlines()[-1] + playlist_response = await hls_client.get(playlist_url) + assert playlist_response.status == 200 + # Fetch segment playlist = await playlist_response.text() - segment_url = "/" + playlist.splitlines()[-1] + segment_url = "/" + [line for line in playlist.splitlines() if line][-1] segment_response = await hls_client.get(segment_url) assert segment_response.status == 200 @@ -243,7 +269,7 @@ async def test_stream_keepalive(hass): stream.stop() -async def test_hls_playlist_view_no_output(hass, hass_client, hls_stream): +async def test_hls_playlist_view_no_output(hass, hls_stream): """Test rendering the hls playlist with no output segments.""" await async_setup_component(hass, "stream", {"stream": {}}) @@ -265,7 +291,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync): stream_worker_sync.pause() hls = stream.add_provider(HLS_PROVIDER) for i in range(2): - segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME) + segment = Segment(sequence=i, duration=SEGMENT_DURATION) hls.put(segment) await hass.async_block_till_done() @@ -277,7 +303,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync): sequence=0, segments=[make_segment(0), make_segment(1)] ) - segment = Segment(sequence=2, duration=SEGMENT_DURATION, start_time=FAKE_TIME) + segment = Segment(sequence=2, duration=SEGMENT_DURATION) hls.put(segment) await hass.async_block_till_done() resp = await hls_client.get("/playlist.m3u8") @@ -302,9 +328,7 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync): # Produce enough segments to overfill the output buffer by one for sequence in range(MAX_SEGMENTS + 1): - segment = Segment( - sequence=sequence, duration=SEGMENT_DURATION, start_time=FAKE_TIME - ) + segment = Segment(sequence=sequence, duration=SEGMENT_DURATION) hls.put(segment) await hass.async_block_till_done() @@ -321,16 +345,17 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync): # Fetch the actual segments with a fake byte payload for segment in hls.get_segments(): segment.init = INIT_BYTES - segment.parts = [ - Part( + segment.parts_by_byterange = { + 0: Part( duration=SEGMENT_DURATION, has_keyframe=True, data=FAKE_PAYLOAD, ) - ] + } # The segment that fell off the buffer is not accessible - segment_response = await hls_client.get("/segment/0.m4s") + with patch.object(hls.stream_settings, "hls_part_timeout", 0.1): + segment_response = await hls_client.get("/segment/0.m4s") assert segment_response.status == 404 # However all segments in the buffer are accessible, even those that were not in the playlist. @@ -350,19 +375,14 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s stream_worker_sync.pause() hls = stream.add_provider(HLS_PROVIDER) - segment = Segment( - sequence=0, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME - ) + segment = Segment(sequence=0, stream_id=0, duration=SEGMENT_DURATION) hls.put(segment) - segment = Segment( - sequence=1, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME - ) + segment = Segment(sequence=1, stream_id=0, duration=SEGMENT_DURATION) hls.put(segment) segment = Segment( sequence=2, stream_id=1, duration=SEGMENT_DURATION, - start_time=FAKE_TIME, ) hls.put(segment) await hass.async_block_till_done() @@ -394,9 +414,7 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy hls_client = await hls_stream(stream) - segment = Segment( - sequence=0, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME - ) + segment = Segment(sequence=0, stream_id=0, duration=SEGMENT_DURATION) hls.put(segment) # Produce enough segments to overfill the output buffer by one @@ -405,7 +423,6 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy sequence=sequence, stream_id=1, duration=SEGMENT_DURATION, - start_time=FAKE_TIME, ) hls.put(segment) await hass.async_block_till_done() diff --git a/tests/components/stream/test_ll_hls.py b/tests/components/stream/test_ll_hls.py new file mode 100644 index 00000000000..8e512e0723e --- /dev/null +++ b/tests/components/stream/test_ll_hls.py @@ -0,0 +1,731 @@ +"""The tests for hls streams.""" +import asyncio +import itertools +import re +from urllib.parse import urlparse + +import pytest + +from homeassistant.components.stream import create_stream +from homeassistant.components.stream.const import ( + ATTR_SETTINGS, + CONF_LL_HLS, + CONF_PART_DURATION, + CONF_SEGMENT_DURATION, + DOMAIN, + HLS_PROVIDER, +) +from homeassistant.components.stream.core import Part +from homeassistant.const import HTTP_NOT_FOUND +from homeassistant.setup import async_setup_component + +from .test_hls import SEGMENT_DURATION, STREAM_SOURCE, HlsClient, make_playlist + +from tests.components.stream.common import ( + FAKE_TIME, + DefaultSegment as Segment, + generate_h264_video, +) + +TEST_PART_DURATION = 1 +NUM_PART_SEGMENTS = int(-(-SEGMENT_DURATION // TEST_PART_DURATION)) +PART_INDEPENDENT_PERIOD = int(1 / TEST_PART_DURATION) or 1 +BYTERANGE_LENGTH = 1 +INIT_BYTES = b"init" +SEQUENCE_BYTES = bytearray(range(NUM_PART_SEGMENTS * BYTERANGE_LENGTH)) +ALT_SEQUENCE_BYTES = bytearray(range(20, 20 + NUM_PART_SEGMENTS * BYTERANGE_LENGTH)) +VERY_LARGE_LAST_BYTE_POS = 9007199254740991 + + +@pytest.fixture +def hls_stream(hass, hass_client): + """Create test fixture for creating an HLS client for a stream.""" + + async def create_client_for_stream(stream): + stream.ll_hls = True + http_client = await hass_client() + parsed_url = urlparse(stream.endpoint_url(HLS_PROVIDER)) + return HlsClient(http_client, parsed_url) + + return create_client_for_stream + + +def create_segment(sequence): + """Create an empty segment.""" + segment = Segment(sequence=sequence) + segment.init = INIT_BYTES + return segment + + +def complete_segment(segment): + """Completes a segment by setting its duration.""" + segment.duration = sum( + part.duration for part in segment.parts_by_byterange.values() + ) + + +def create_parts(source): + """Create parts from a source.""" + independent_cycle = itertools.cycle( + [True] + [False] * (PART_INDEPENDENT_PERIOD - 1) + ) + return [ + Part( + duration=TEST_PART_DURATION, + has_keyframe=next(independent_cycle), + data=bytes(source[i * BYTERANGE_LENGTH : (i + 1) * BYTERANGE_LENGTH]), + ) + for i in range(NUM_PART_SEGMENTS) + ] + + +def http_range_from_part(part): + """Return dummy byterange (length, start) given part number.""" + return BYTERANGE_LENGTH, part * BYTERANGE_LENGTH + + +def make_segment_with_parts( + segment, num_parts, independent_period, discontinuity=False +): + """Create a playlist response for a segment including part segments.""" + response = [] + for i in range(num_parts): + length, start = http_range_from_part(i) + response.append( + f'#EXT-X-PART:DURATION={TEST_PART_DURATION:.3f},URI="./segment/{segment}.m4s",BYTERANGE="{length}@{start}"{",INDEPENDENT=YES" if i%independent_period==0 else ""}' + ) + if discontinuity: + response.append("#EXT-X-DISCONTINUITY") + response.extend( + [ + "#EXT-X-PROGRAM-DATE-TIME:" + + FAKE_TIME.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + + "Z", + f"#EXTINF:{SEGMENT_DURATION:.3f},", + f"./segment/{segment}.m4s", + ] + ) + return "\n".join(response) + + +def make_hint(segment, part): + """Create a playlist response for the preload hint.""" + _, start = http_range_from_part(part) + return f'#EXT-X-PRELOAD-HINT:TYPE=PART,URI="./segment/{segment}.m4s",BYTERANGE-START={start}' + + +async def test_ll_hls_stream(hass, hls_stream, stream_worker_sync): + """ + Test hls stream. + + Purposefully not mocking anything here to test full + integration with the stream component. + """ + await async_setup_component( + hass, + "stream", + { + "stream": { + CONF_LL_HLS: True, + CONF_SEGMENT_DURATION: SEGMENT_DURATION, + CONF_PART_DURATION: TEST_PART_DURATION, + } + }, + ) + + stream_worker_sync.pause() + + # Setup demo HLS track + source = generate_h264_video(duration=SEGMENT_DURATION + 1) + stream = create_stream(hass, source, {}) + + # Request stream + stream.add_provider(HLS_PROVIDER) + stream.start() + + hls_client = await hls_stream(stream) + + # Fetch playlist + master_playlist_response = await hls_client.get() + assert master_playlist_response.status == 200 + + # Fetch init + master_playlist = await master_playlist_response.text() + init_response = await hls_client.get("/init.mp4") + assert init_response.status == 200 + + # Fetch playlist + playlist_url = "/" + master_playlist.splitlines()[-1] + playlist_response = await hls_client.get(playlist_url) + assert playlist_response.status == 200 + + # Fetch segments + playlist = await playlist_response.text() + segment_re = re.compile(r"^(?P./segment/\d+\.m4s)") + for line in playlist.splitlines(): + match = segment_re.match(line) + if match: + segment_url = "/" + match.group("segment_url") + segment_response = await hls_client.get(segment_url) + assert segment_response.status == 200 + + def check_part_is_moof_mdat(data: bytes): + if len(data) < 8 or data[4:8] != b"moof": + return False + moof_length = int.from_bytes(data[0:4], byteorder="big") + if ( + len(data) < moof_length + 8 + or data[moof_length + 4 : moof_length + 8] != b"mdat" + ): + return False + mdat_length = int.from_bytes( + data[moof_length : moof_length + 4], byteorder="big" + ) + if mdat_length + moof_length != len(data): + return False + return True + + # Fetch all completed part segments + part_re = re.compile( + r'#EXT-X-PART:DURATION=[0-9].[0-9]{5,5},URI="(?P.+?)",BYTERANGE="(?P[0-9]+?)@(?P[0-9]+?)"(,INDEPENDENT=YES)?' + ) + for line in playlist.splitlines(): + match = part_re.match(line) + if match: + part_segment_url = "/" + match.group("part_url") + byterange_end = ( + int(match.group("byterange_length")) + + int(match.group("byterange_start")) + - 1 + ) + part_segment_response = await hls_client.get( + part_segment_url, + headers={ + "Range": f'bytes={match.group("byterange_start")}-{byterange_end}' + }, + ) + assert part_segment_response.status == 206 + assert check_part_is_moof_mdat(await part_segment_response.read()) + + stream_worker_sync.resume() + + # Stop stream, if it hasn't quit already + stream.stop() + + # Ensure playlist not accessible after stream ends + fail_response = await hls_client.get() + assert fail_response.status == HTTP_NOT_FOUND + + +async def test_ll_hls_playlist_view(hass, hls_stream, stream_worker_sync): + """Test rendering the hls playlist with 1 and 2 output segments.""" + await async_setup_component( + hass, + "stream", + { + "stream": { + CONF_LL_HLS: True, + CONF_SEGMENT_DURATION: SEGMENT_DURATION, + CONF_PART_DURATION: TEST_PART_DURATION, + } + }, + ) + + stream = create_stream(hass, STREAM_SOURCE, {}) + stream_worker_sync.pause() + hls = stream.add_provider(HLS_PROVIDER) + + # Add 2 complete segments to output + for sequence in range(2): + segment = create_segment(sequence=sequence) + hls.put(segment) + for part in create_parts(SEQUENCE_BYTES): + segment.async_add_part(part, 0) + hls.part_put() + complete_segment(segment) + await hass.async_block_till_done() + + hls_client = await hls_stream(stream) + + resp = await hls_client.get("/playlist.m3u8") + assert resp.status == 200 + assert await resp.text() == make_playlist( + sequence=0, + segments=[ + make_segment_with_parts( + i, len(segment.parts_by_byterange), PART_INDEPENDENT_PERIOD + ) + for i in range(2) + ], + hint=make_hint(2, 0), + part_target_duration=hls.stream_settings.part_target_duration, + ) + + # add one more segment + segment = create_segment(sequence=2) + hls.put(segment) + for part in create_parts(SEQUENCE_BYTES): + segment.async_add_part(part, 0) + hls.part_put() + complete_segment(segment) + + await hass.async_block_till_done() + resp = await hls_client.get("/playlist.m3u8") + assert resp.status == 200 + assert await resp.text() == make_playlist( + sequence=0, + segments=[ + make_segment_with_parts( + i, len(segment.parts_by_byterange), PART_INDEPENDENT_PERIOD + ) + for i in range(3) + ], + hint=make_hint(3, 0), + part_target_duration=hls.stream_settings.part_target_duration, + ) + + stream_worker_sync.resume() + stream.stop() + + +async def test_ll_hls_msn(hass, hls_stream, stream_worker_sync, hls_sync): + """Test that requests using _HLS_msn get held and returned or rejected.""" + await async_setup_component( + hass, + "stream", + { + "stream": { + CONF_LL_HLS: True, + CONF_SEGMENT_DURATION: SEGMENT_DURATION, + CONF_PART_DURATION: TEST_PART_DURATION, + } + }, + ) + + stream = create_stream(hass, STREAM_SOURCE, {}) + stream_worker_sync.pause() + + hls = stream.add_provider(HLS_PROVIDER) + + hls_client = await hls_stream(stream) + + # Create 4 requests for sequences 0 through 3 + # 0 and 1 should hold then go through and 2 and 3 should fail immediately. + + hls_sync.reset_request_pool(4) + msn_requests = asyncio.gather( + *(hls_client.get(f"/playlist.m3u8?_HLS_msn={i}") for i in range(4)) + ) + + for sequence in range(3): + await hls_sync.wait_for_handler() + segment = Segment(sequence=sequence, duration=SEGMENT_DURATION) + hls.put(segment) + + msn_responses = await msn_requests + + assert msn_responses[0].status == 200 + assert msn_responses[1].status == 200 + assert msn_responses[2].status == 400 + assert msn_responses[3].status == 400 + + # Sequence number is now 2. Create six more requests for sequences 0 through 5. + # Calls for msn 0 through 4 should work, 5 should fail. + + hls_sync.reset_request_pool(6) + msn_requests = asyncio.gather( + *(hls_client.get(f"/playlist.m3u8?_HLS_msn={i}") for i in range(6)) + ) + for sequence in range(3, 6): + await hls_sync.wait_for_handler() + segment = Segment(sequence=sequence, duration=SEGMENT_DURATION) + hls.put(segment) + + msn_responses = await msn_requests + assert msn_responses[0].status == 200 + assert msn_responses[1].status == 200 + assert msn_responses[2].status == 200 + assert msn_responses[3].status == 200 + assert msn_responses[4].status == 200 + assert msn_responses[5].status == 400 + + stream_worker_sync.resume() + + +async def test_ll_hls_playlist_bad_msn_part(hass, hls_stream, stream_worker_sync): + """Test some playlist requests with invalid _HLS_msn/_HLS_part.""" + + await async_setup_component( + hass, + "stream", + { + "stream": { + CONF_LL_HLS: True, + CONF_SEGMENT_DURATION: SEGMENT_DURATION, + CONF_PART_DURATION: TEST_PART_DURATION, + } + }, + ) + + stream = create_stream(hass, STREAM_SOURCE, {}) + stream_worker_sync.pause() + + hls = stream.add_provider(HLS_PROVIDER) + + hls_client = await hls_stream(stream) + + # If the Playlist URI contains an _HLS_part directive but no _HLS_msn + # directive, the Server MUST return Bad Request, such as HTTP 400. + + assert (await hls_client.get("/playlist.m3u8?_HLS_part=1")).status == 400 + + # Seed hls with 1 complete segment and 1 in process segment + segment = create_segment(sequence=0) + hls.put(segment) + for part in create_parts(SEQUENCE_BYTES): + segment.async_add_part(part, 0) + hls.part_put() + complete_segment(segment) + + segment = create_segment(sequence=1) + hls.put(segment) + remaining_parts = create_parts(SEQUENCE_BYTES) + num_completed_parts = len(remaining_parts) // 2 + for part in remaining_parts[:num_completed_parts]: + segment.async_add_part(part, 0) + + # If the _HLS_msn is greater than the Media Sequence Number of the last + # Media Segment in the current Playlist plus two, or if the _HLS_part + # exceeds the last Partial Segment in the current Playlist by the + # Advance Part Limit, then the server SHOULD immediately return Bad + # Request, such as HTTP 400. The Advance Part Limit is three divided + # by the Part Target Duration if the Part Target Duration is less than + # one second, or three otherwise. + + # Current sequence number is 1 and part number is num_completed_parts-1 + # The following two tests should fail immediately: + # - request with a _HLS_msn of 4 + # - request with a _HLS_msn of 1 and a _HLS_part of num_completed_parts-1+advance_part_limit + assert (await hls_client.get("/playlist.m3u8?_HLS_msn=4")).status == 400 + assert ( + await hls_client.get( + f"/playlist.m3u8?_HLS_msn=1&_HLS_part={num_completed_parts-1+hass.data[DOMAIN][ATTR_SETTINGS].hls_advance_part_limit}" + ) + ).status == 400 + stream_worker_sync.resume() + + +async def test_ll_hls_playlist_rollover_part( + hass, hls_stream, stream_worker_sync, hls_sync +): + """Test playlist request rollover.""" + + await async_setup_component( + hass, + "stream", + { + "stream": { + CONF_LL_HLS: True, + CONF_SEGMENT_DURATION: SEGMENT_DURATION, + CONF_PART_DURATION: TEST_PART_DURATION, + } + }, + ) + + stream = create_stream(hass, STREAM_SOURCE, {}) + stream_worker_sync.pause() + + hls = stream.add_provider(HLS_PROVIDER) + + hls_client = await hls_stream(stream) + + # Seed hls with 1 complete segment and 1 in process segment + for sequence in range(2): + segment = create_segment(sequence=sequence) + hls.put(segment) + + for part in create_parts(SEQUENCE_BYTES): + segment.async_add_part(part, 0) + hls.part_put() + complete_segment(segment) + + await hass.async_block_till_done() + + hls_sync.reset_request_pool(4) + segment = hls.get_segment(1) + # the first request corresponds to the last part of segment 1 + # the remaining requests correspond to part 0 of segment 2 + requests = asyncio.gather( + *( + [ + hls_client.get( + f"/playlist.m3u8?_HLS_msn=1&_HLS_part={len(segment.parts_by_byterange)-1}" + ), + hls_client.get( + f"/playlist.m3u8?_HLS_msn=1&_HLS_part={len(segment.parts_by_byterange)}" + ), + hls_client.get( + f"/playlist.m3u8?_HLS_msn=1&_HLS_part={len(segment.parts_by_byterange)+1}" + ), + hls_client.get("/playlist.m3u8?_HLS_msn=2&_HLS_part=0"), + ] + ) + ) + + await hls_sync.wait_for_handler() + + segment = create_segment(sequence=2) + hls.put(segment) + await hass.async_block_till_done() + + remaining_parts = create_parts(SEQUENCE_BYTES) + segment.async_add_part(remaining_parts.pop(0), 0) + hls.part_put() + + await hls_sync.wait_for_handler() + + different_response, *same_responses = await requests + + assert different_response.status == 200 + assert all(response.status == 200 for response in same_responses) + different_playlist = await different_response.read() + same_playlists = [await response.read() for response in same_responses] + assert different_playlist != same_playlists[0] + assert all(playlist == same_playlists[0] for playlist in same_playlists[1:]) + + stream_worker_sync.resume() + + +async def test_ll_hls_playlist_msn_part(hass, hls_stream, stream_worker_sync, hls_sync): + """Test that requests using _HLS_msn and _HLS_part get held and returned.""" + + await async_setup_component( + hass, + "stream", + { + "stream": { + CONF_LL_HLS: True, + CONF_SEGMENT_DURATION: SEGMENT_DURATION, + CONF_PART_DURATION: TEST_PART_DURATION, + } + }, + ) + + stream = create_stream(hass, STREAM_SOURCE, {}) + stream_worker_sync.pause() + + hls = stream.add_provider(HLS_PROVIDER) + + hls_client = await hls_stream(stream) + + # Seed hls with 1 complete segment and 1 in process segment + segment = create_segment(sequence=0) + hls.put(segment) + for part in create_parts(SEQUENCE_BYTES): + segment.async_add_part(part, 0) + hls.part_put() + complete_segment(segment) + + segment = create_segment(sequence=1) + hls.put(segment) + remaining_parts = create_parts(SEQUENCE_BYTES) + num_completed_parts = len(remaining_parts) // 2 + for part in remaining_parts[:num_completed_parts]: + segment.async_add_part(part, 0) + del remaining_parts[:num_completed_parts] + + # Make requests for all the part segments up to n+ADVANCE_PART_LIMIT + hls_sync.reset_request_pool( + num_completed_parts + + int(-(-hass.data[DOMAIN][ATTR_SETTINGS].hls_advance_part_limit // 1)) + ) + msn_requests = asyncio.gather( + *( + hls_client.get(f"/playlist.m3u8?_HLS_msn=1&_HLS_part={i}") + for i in range( + num_completed_parts + + int(-(-hass.data[DOMAIN][ATTR_SETTINGS].hls_advance_part_limit // 1)) + ) + ) + ) + + while remaining_parts: + await hls_sync.wait_for_handler() + segment.async_add_part(remaining_parts.pop(0), 0) + hls.part_put() + + msn_responses = await msn_requests + + # All the responses should succeed except the last one which fails + assert all(response.status == 200 for response in msn_responses[:-1]) + assert msn_responses[-1].status == 400 + + stream_worker_sync.resume() + + +async def test_get_part_segments(hass, hls_stream, stream_worker_sync, hls_sync): + """Test requests for part segments and hinted parts.""" + await async_setup_component( + hass, + "stream", + { + "stream": { + CONF_LL_HLS: True, + CONF_SEGMENT_DURATION: SEGMENT_DURATION, + CONF_PART_DURATION: TEST_PART_DURATION, + } + }, + ) + + stream = create_stream(hass, STREAM_SOURCE, {}) + stream_worker_sync.pause() + + hls = stream.add_provider(HLS_PROVIDER) + + hls_client = await hls_stream(stream) + + # Seed hls with 1 complete segment and 1 in process segment + segment = create_segment(sequence=0) + hls.put(segment) + for part in create_parts(SEQUENCE_BYTES): + segment.async_add_part(part, 0) + hls.part_put() + complete_segment(segment) + + segment = create_segment(sequence=1) + hls.put(segment) + remaining_parts = create_parts(SEQUENCE_BYTES) + num_completed_parts = len(remaining_parts) // 2 + for _ in range(num_completed_parts): + segment.async_add_part(remaining_parts.pop(0), 0) + + # Make requests for all the existing part segments + # These should succeed with a status of 206 + requests = asyncio.gather( + *( + hls_client.get( + "/segment/1.m4s", + headers={ + "Range": f"bytes={http_range_from_part(part)[1]}-" + + str( + http_range_from_part(part)[0] + + http_range_from_part(part)[1] + - 1 + ) + }, + ) + for part in range(num_completed_parts) + ) + ) + responses = await requests + assert all(response.status == 206 for response in responses) + assert all( + responses[part].headers["Content-Range"] + == f"bytes {http_range_from_part(part)[1]}-" + + str(http_range_from_part(part)[0] + http_range_from_part(part)[1] - 1) + + "/*" + for part in range(num_completed_parts) + ) + parts = list(segment.parts_by_byterange.values()) + assert all( + [await responses[i].read() == parts[i].data for i in range(len(responses))] + ) + + # Make some non standard range requests. + # Request past end of previous closed segment + # Request should succeed but length will be limited to the segment length + response = await hls_client.get( + "/segment/0.m4s", + headers={"Range": f"bytes=0-{hls.get_segment(0).data_size+1}"}, + ) + assert response.status == 206 + assert ( + response.headers["Content-Range"] + == f"bytes 0-{hls.get_segment(0).data_size-1}/{hls.get_segment(0).data_size}" + ) + assert (await response.read()) == hls.get_segment(0).get_data() + + # Request with start range past end of current segment + # Since this is beyond the data we have (the largest starting position will be + # from a hinted request, and even that will have a starting position at + # segment.data_size), we expect a 416. + response = await hls_client.get( + "/segment/1.m4s", + headers={"Range": f"bytes={segment.data_size+1}-{VERY_LARGE_LAST_BYTE_POS}"}, + ) + assert response.status == 416 + + # Request for next segment which has not yet been hinted (we will only hint + # for this segment after segment 1 is complete). + # This should fail, but it will hold for one more part_put before failing. + hls_sync.reset_request_pool(1) + request = asyncio.create_task( + hls_client.get( + "/segment/2.m4s", headers={"Range": f"bytes=0-{VERY_LARGE_LAST_BYTE_POS}"} + ) + ) + await hls_sync.wait_for_handler() + hls.part_put() + response = await request + assert response.status == 404 + + # Make valid request for the current hint. This should succeed, but since + # it is open ended, it won't finish until the segment is complete. + hls_sync.reset_request_pool(1) + request_start = segment.data_size + request = asyncio.create_task( + hls_client.get( + "/segment/1.m4s", + headers={"Range": f"bytes={request_start}-{VERY_LARGE_LAST_BYTE_POS}"}, + ) + ) + # Put the remaining parts and complete the segment + while remaining_parts: + await hls_sync.wait_for_handler() + # Put one more part segment + segment.async_add_part(remaining_parts.pop(0), 0) + hls.part_put() + complete_segment(segment) + # Check the response + response = await request + assert response.status == 206 + assert ( + response.headers["Content-Range"] + == f"bytes {request_start}-{VERY_LARGE_LAST_BYTE_POS}/*" + ) + assert await response.read() == SEQUENCE_BYTES[request_start:] + + # Now the hint should have moved to segment 2 + # The request for segment 2 which failed before should work now + # Also make an equivalent request with no Range parameters that + # will return the same content but with different headers + hls_sync.reset_request_pool(2) + requests = asyncio.gather( + hls_client.get( + "/segment/2.m4s", headers={"Range": f"bytes=0-{VERY_LARGE_LAST_BYTE_POS}"} + ), + hls_client.get("/segment/2.m4s"), + ) + # Put an entire segment and its parts. + segment = create_segment(sequence=2) + hls.put(segment) + remaining_parts = create_parts(ALT_SEQUENCE_BYTES) + for part in remaining_parts: + await hls_sync.wait_for_handler() + segment.async_add_part(part, 0) + hls.part_put() + complete_segment(segment) + # Check the response + responses = await requests + assert responses[0].status == 206 + assert ( + responses[0].headers["Content-Range"] == f"bytes 0-{VERY_LARGE_LAST_BYTE_POS}/*" + ) + assert responses[1].status == 200 + assert "Content-Range" not in responses[1].headers + assert ( + await response.read() == ALT_SEQUENCE_BYTES[: hls.get_segment(2).data_size] + for response in responses + ) + + stream_worker_sync.resume() diff --git a/tests/components/stream/test_recorder.py b/tests/components/stream/test_recorder.py index 31661db3886..b8521205920 100644 --- a/tests/components/stream/test_recorder.py +++ b/tests/components/stream/test_recorder.py @@ -9,7 +9,7 @@ import pytest from homeassistant.components.stream import create_stream from homeassistant.components.stream.const import HLS_PROVIDER, RECORDER_PROVIDER -from homeassistant.components.stream.core import Part, Segment +from homeassistant.components.stream.core import Part from homeassistant.components.stream.fmp4utils import find_box from homeassistant.components.stream.recorder import recorder_save_worker from homeassistant.exceptions import HomeAssistantError @@ -17,7 +17,11 @@ from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util from tests.common import async_fire_time_changed -from tests.components.stream.common import generate_h264_video, remux_with_audio +from tests.components.stream.common import ( + DefaultSegment as Segment, + generate_h264_video, + remux_with_audio, +) MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever @@ -122,15 +126,14 @@ def add_parts_to_segment(segment, source): """Add relevant part data to segment for testing recorder.""" moof_locs = list(find_box(source.getbuffer(), b"moof")) + [len(source.getbuffer())] segment.init = source.getbuffer()[: moof_locs[0]].tobytes() - segment.parts = [ - Part( + segment.parts_by_byterange = { + moof_locs[i]: Part( duration=None, has_keyframe=None, - http_range_start=None, data=source.getbuffer()[moof_locs[i] : moof_locs[i + 1]], ) - for i in range(1, len(moof_locs) - 1) - ] + for i in range(len(moof_locs) - 1) + } async def test_recorder_save(tmpdir): @@ -219,7 +222,7 @@ async def test_record_stream_audio( stream_worker_sync.resume() result = av.open( - BytesIO(last_segment.init + last_segment.get_bytes_without_init()), + BytesIO(last_segment.init + last_segment.get_data()), "r", format="mp4", ) diff --git a/tests/components/stream/test_worker.py b/tests/components/stream/test_worker.py index e62a190d7be..16412b28468 100644 --- a/tests/components/stream/test_worker.py +++ b/tests/components/stream/test_worker.py @@ -21,18 +21,27 @@ import threading from unittest.mock import patch import av +import pytest from homeassistant.components.stream import Stream, create_stream from homeassistant.components.stream.const import ( + ATTR_SETTINGS, + CONF_LL_HLS, + CONF_PART_DURATION, + CONF_SEGMENT_DURATION, + DOMAIN, HLS_PROVIDER, MAX_MISSING_DTS, PACKETS_TO_WAIT_FOR_AUDIO, - TARGET_SEGMENT_DURATION, + SEGMENT_DURATION_ADJUSTER, + TARGET_SEGMENT_DURATION_NON_LL_HLS, ) +from homeassistant.components.stream.core import StreamSettings from homeassistant.components.stream.worker import SegmentBuffer, stream_worker from homeassistant.setup import async_setup_component from tests.components.stream.common import generate_h264_video +from tests.components.stream.test_ll_hls import TEST_PART_DURATION STREAM_SOURCE = "some-stream-source" # Formats here are arbitrary, not exercised by tests @@ -43,7 +52,8 @@ AUDIO_SAMPLE_RATE = 11025 KEYFRAME_INTERVAL = 1 # in seconds PACKET_DURATION = fractions.Fraction(1, VIDEO_FRAME_RATE) # in seconds SEGMENT_DURATION = ( - math.ceil(TARGET_SEGMENT_DURATION / KEYFRAME_INTERVAL) * KEYFRAME_INTERVAL + math.ceil(TARGET_SEGMENT_DURATION_NON_LL_HLS / KEYFRAME_INTERVAL) + * KEYFRAME_INTERVAL ) # in seconds TEST_SEQUENCE_LENGTH = 5 * VIDEO_FRAME_RATE LONGER_TEST_SEQUENCE_LENGTH = 20 * VIDEO_FRAME_RATE @@ -53,6 +63,21 @@ SEGMENTS_PER_PACKET = PACKET_DURATION / SEGMENT_DURATION TIMEOUT = 15 +@pytest.fixture(autouse=True) +def mock_stream_settings(hass): + """Set the stream settings data in hass before each test.""" + hass.data[DOMAIN] = { + ATTR_SETTINGS: StreamSettings( + ll_hls=False, + min_segment_duration=TARGET_SEGMENT_DURATION_NON_LL_HLS + - SEGMENT_DURATION_ADJUSTER, + part_target_duration=TARGET_SEGMENT_DURATION_NON_LL_HLS, + hls_advance_part_limit=3, + hls_part_timeout=TARGET_SEGMENT_DURATION_NON_LL_HLS, + ) + } + + class FakeAvInputStream: """A fake pyav Stream.""" @@ -235,7 +260,7 @@ async def async_decode_stream(hass, packets, py_av=None): "homeassistant.components.stream.core.StreamOutput.put", side_effect=py_av.capture_buffer.capture_output_segment, ): - segment_buffer = SegmentBuffer(stream.outputs) + segment_buffer = SegmentBuffer(hass, stream.outputs) stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) await hass.async_block_till_done() @@ -248,7 +273,7 @@ async def test_stream_open_fails(hass): stream.add_provider(HLS_PROVIDER) with patch("av.open") as av_open: av_open.side_effect = av.error.InvalidDataError(-2, "error") - segment_buffer = SegmentBuffer(stream.outputs) + segment_buffer = SegmentBuffer(hass, stream.outputs) stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) await hass.async_block_till_done() av_open.assert_called_once() @@ -638,7 +663,7 @@ async def test_worker_log(hass, caplog): stream.add_provider(HLS_PROVIDER) with patch("av.open") as av_open: av_open.side_effect = av.error.InvalidDataError(-2, "error") - segment_buffer = SegmentBuffer(stream.outputs) + segment_buffer = SegmentBuffer(hass, stream.outputs) stream_worker( "https://abcd:efgh@foo.bar", {}, segment_buffer, threading.Event() ) @@ -649,7 +674,17 @@ async def test_worker_log(hass, caplog): async def test_durations(hass, record_worker_sync): """Test that the duration metadata matches the media.""" - await async_setup_component(hass, "stream", {"stream": {}}) + await async_setup_component( + hass, + "stream", + { + "stream": { + CONF_LL_HLS: True, + CONF_SEGMENT_DURATION: SEGMENT_DURATION, + CONF_PART_DURATION: TEST_PART_DURATION, + } + }, + ) source = generate_h264_video() stream = create_stream(hass, source, {}) @@ -664,7 +699,7 @@ async def test_durations(hass, record_worker_sync): # check that the Part duration metadata matches the durations in the media running_metadata_duration = 0 for segment in complete_segments: - for part in segment.parts: + for part in segment.parts_by_byterange.values(): av_part = av.open(io.BytesIO(segment.init + part.data)) running_metadata_duration += part.duration # av_part.duration will just return the largest dts in av_part. @@ -678,7 +713,9 @@ async def test_durations(hass, record_worker_sync): # check that the Part durations are consistent with the Segment durations for segment in complete_segments: assert math.isclose( - sum(part.duration for part in segment.parts), segment.duration, abs_tol=1e-6 + sum(part.duration for part in segment.parts_by_byterange.values()), + segment.duration, + abs_tol=1e-6, ) await record_worker_sync.join() @@ -688,7 +725,19 @@ async def test_durations(hass, record_worker_sync): async def test_has_keyframe(hass, record_worker_sync): """Test that the has_keyframe metadata matches the media.""" - await async_setup_component(hass, "stream", {"stream": {}}) + await async_setup_component( + hass, + "stream", + { + "stream": { + CONF_LL_HLS: True, + CONF_SEGMENT_DURATION: SEGMENT_DURATION, + # Our test video has keyframes every second. Use smaller parts so we have more + # part boundaries to better test keyframe logic. + CONF_PART_DURATION: 0.25, + } + }, + ) source = generate_h264_video() stream = create_stream(hass, source, {}) @@ -697,15 +746,12 @@ async def test_has_keyframe(hass, record_worker_sync): with patch.object(hass.config, "is_allowed_path", return_value=True): await stream.async_record("/example/path") - # Our test video has keyframes every second. Use smaller parts so we have more - # part boundaries to better test keyframe logic. - with patch("homeassistant.components.stream.worker.TARGET_PART_DURATION", 0.25): - complete_segments = list(await record_worker_sync.get_segments())[:-1] + complete_segments = list(await record_worker_sync.get_segments())[:-1] assert len(complete_segments) >= 1 # check that the Part has_keyframe metadata matches the keyframes in the media for segment in complete_segments: - for part in segment.parts: + for part in segment.parts_by_byterange.values(): av_part = av.open(io.BytesIO(segment.init + part.data)) media_has_keyframe = any( packet.is_keyframe for packet in av_part.demux(av_part.streams.video[0])