Refactor stream to create partial segments (#51282)

This commit is contained in:
uvjustin 2021-06-14 00:41:21 +08:00 committed by GitHub
parent 1adeb82930
commit 123e8f01a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 499 additions and 305 deletions

View File

@ -18,8 +18,9 @@ FORMAT_CONTENT_TYPE = {HLS_PROVIDER: "application/vnd.apple.mpegurl"}
OUTPUT_IDLE_TIMEOUT = 300 # Idle timeout due to inactivity OUTPUT_IDLE_TIMEOUT = 300 # Idle timeout due to inactivity
NUM_PLAYLIST_SEGMENTS = 3 # Number of segments to use in HLS playlist NUM_PLAYLIST_SEGMENTS = 3 # Number of segments to use in HLS playlist
MAX_SEGMENTS = 4 # Max number of segments to keep around MAX_SEGMENTS = 5 # Max number of segments to keep around
TARGET_SEGMENT_DURATION = 2.0 # Each segment is about this many seconds TARGET_SEGMENT_DURATION = 2.0 # Each segment is about this many seconds
TARGET_PART_DURATION = 1.0
SEGMENT_DURATION_ADJUSTER = 0.1 # Used to avoid missing keyframe boundaries SEGMENT_DURATION_ADJUSTER = 0.1 # Used to avoid missing keyframe boundaries
# Each segment is at least this many seconds # Each segment is at least this many seconds
MIN_SEGMENT_DURATION = TARGET_SEGMENT_DURATION - SEGMENT_DURATION_ADJUSTER MIN_SEGMENT_DURATION = TARGET_SEGMENT_DURATION - SEGMENT_DURATION_ADJUSTER

View File

@ -19,20 +19,37 @@ from .const import ATTR_STREAMS, DOMAIN
PROVIDERS = Registry() PROVIDERS = Registry()
@attr.s(slots=True)
class Part:
"""Represent a segment part."""
duration: float = attr.ib()
has_keyframe: bool = attr.ib()
data: bytes = attr.ib()
@attr.s(slots=True) @attr.s(slots=True)
class Segment: class Segment:
"""Represent a segment.""" """Represent a segment."""
sequence: int = attr.ib() sequence: int = attr.ib(default=0)
# the init of the mp4 # the init of the mp4 the segment is based on
init: bytes = attr.ib() init: bytes = attr.ib(default=None)
# the video data (moof + mddat)s of the mp4 duration: float = attr.ib(default=0)
moof_data: bytes = attr.ib()
duration: float = attr.ib()
# For detecting discontinuities across stream restarts # For detecting discontinuities across stream restarts
stream_id: int = attr.ib(default=0) stream_id: int = attr.ib(default=0)
parts: list[Part] = attr.ib(factory=list)
start_time: datetime.datetime = attr.ib(factory=datetime.datetime.utcnow) start_time: datetime.datetime = attr.ib(factory=datetime.datetime.utcnow)
@property
def complete(self) -> bool:
"""Return whether the Segment is complete."""
return self.duration > 0
def get_bytes_without_init(self) -> bytes:
"""Return reconstructed data for entire segment as bytes."""
return b"".join([part.data for part in self.parts])
class IdleTimer: class IdleTimer:
"""Invoke a callback after an inactivity timeout. """Invoke a callback after an inactivity timeout.

View File

@ -25,16 +25,6 @@ def find_box(
index += int.from_bytes(box_header[0:4], byteorder="big") index += int.from_bytes(box_header[0:4], byteorder="big")
def get_init_and_moof_data(segment: memoryview) -> tuple[bytes, bytes]:
"""Get the init and moof data from a segment."""
moof_location = next(find_box(segment, b"moof"), 0)
mfra_location = next(find_box(segment, b"mfra"), len(segment))
return (
segment[:moof_location].tobytes(),
segment[moof_location:mfra_location].tobytes(),
)
def get_codec_string(mp4_bytes: bytes) -> str: def get_codec_string(mp4_bytes: bytes) -> str:
"""Get RFC 6381 codec string.""" """Get RFC 6381 codec string."""
codecs = [] codecs = []

View File

@ -37,9 +37,12 @@ class HlsMasterPlaylistView(StreamView):
# Need to calculate max bandwidth as input_container.bit_rate doesn't seem to work # Need to calculate max bandwidth as input_container.bit_rate doesn't seem to work
# Calculate file size / duration and use a small multiplier to account for variation # Calculate file size / duration and use a small multiplier to account for variation
# hls spec already allows for 25% variation # hls spec already allows for 25% variation
segment = track.get_segment(track.sequences[-1]) segment = track.get_segment(track.sequences[-2])
bandwidth = round( bandwidth = round(
(len(segment.init) + len(segment.moof_data)) * 8 / segment.duration * 1.2 (len(segment.init) + sum(len(part.data) for part in segment.parts))
* 8
/ segment.duration
* 1.2
) )
codecs = get_codec_string(segment.init) codecs = get_codec_string(segment.init)
lines = [ lines = [
@ -53,9 +56,11 @@ class HlsMasterPlaylistView(StreamView):
"""Return m3u8 playlist.""" """Return m3u8 playlist."""
track = stream.add_provider(HLS_PROVIDER) track = stream.add_provider(HLS_PROVIDER)
stream.start() stream.start()
# Wait for a segment to be ready # Make sure at least two segments are ready (last one may not be complete)
if not track.sequences and not await track.recv(): if not track.sequences and not await track.recv():
return web.HTTPNotFound() return web.HTTPNotFound()
if len(track.sequences) == 1 and not await track.recv():
return web.HTTPNotFound()
headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]} headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]}
return web.Response(body=self.render(track).encode("utf-8"), headers=headers) return web.Response(body=self.render(track).encode("utf-8"), headers=headers)
@ -68,38 +73,44 @@ class HlsPlaylistView(StreamView):
cors_allowed = True cors_allowed = True
@staticmethod @staticmethod
def render_preamble(track): def render(track):
"""Render preamble."""
return [
"#EXT-X-VERSION:6",
f"#EXT-X-TARGETDURATION:{track.target_duration}",
'#EXT-X-MAP:URI="init.mp4"',
]
@staticmethod
def render_playlist(track):
"""Render playlist.""" """Render playlist."""
segments = list(track.get_segments())[-NUM_PLAYLIST_SEGMENTS:] # NUM_PLAYLIST_SEGMENTS+1 because most recent is probably not yet complete
segments = list(track.get_segments())[-(NUM_PLAYLIST_SEGMENTS + 1) :]
if not segments: # To cap the number of complete segments at NUM_PLAYLIST_SEGMENTS,
return [] # remove the first segment if the last segment is actually complete
if segments[-1].complete:
segments = segments[-NUM_PLAYLIST_SEGMENTS:]
first_segment = segments[0] first_segment = segments[0]
playlist = [ playlist = [
"#EXTM3U",
"#EXT-X-VERSION:6",
"#EXT-X-INDEPENDENT-SEGMENTS",
'#EXT-X-MAP:URI="init.mp4"',
f"#EXT-X-TARGETDURATION:{track.target_duration:.0f}",
f"#EXT-X-MEDIA-SEQUENCE:{first_segment.sequence}", f"#EXT-X-MEDIA-SEQUENCE:{first_segment.sequence}",
f"#EXT-X-DISCONTINUITY-SEQUENCE:{first_segment.stream_id}", f"#EXT-X-DISCONTINUITY-SEQUENCE:{first_segment.stream_id}",
"#EXT-X-PROGRAM-DATE-TIME:" "#EXT-X-PROGRAM-DATE-TIME:"
+ first_segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + first_segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+ "Z", + "Z",
# Since our window doesn't have many segments, we don't want to start # Since our window doesn't have many segments, we don't want to start
# at the beginning or we risk a behind live window exception in exoplayer. # at the beginning or we risk a behind live window exception in Exoplayer.
# EXT-X-START is not supposed to be within 3 target durations of the end, # EXT-X-START is not supposed to be within 3 target durations of the end,
# but this seems ok # but a value as low as 1.5 doesn't seem to hurt.
f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START * track.target_duration:.3f},PRECISE=YES", # A value below 3 may not be as useful for hls.js as many hls.js clients
# don't autoplay. Also, hls.js uses the player parameter liveSyncDuration
# which seems to take precedence for setting target delay. Yet it also
# doesn't seem to hurt, so we can stick with it for now.
f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START * track.target_duration:.3f}",
] ]
last_stream_id = first_segment.stream_id last_stream_id = first_segment.stream_id
# Add playlist sections
for segment in segments: for segment in segments:
# Skip last segment if it is not complete
if segment.complete:
if last_stream_id != segment.stream_id: if last_stream_id != segment.stream_id:
playlist.extend( playlist.extend(
[ [
@ -111,26 +122,23 @@ class HlsPlaylistView(StreamView):
) )
playlist.extend( playlist.extend(
[ [
f"#EXTINF:{float(segment.duration):.04f},", f"#EXTINF:{segment.duration:.3f},",
f"./segment/{segment.sequence}.m4s", f"./segment/{segment.sequence}.m4s",
] ]
) )
last_stream_id = segment.stream_id last_stream_id = segment.stream_id
return playlist return "\n".join(playlist) + "\n"
def render(self, track):
"""Render M3U8 file."""
lines = ["#EXTM3U"] + self.render_preamble(track) + self.render_playlist(track)
return "\n".join(lines) + "\n"
async def handle(self, request, stream, sequence): async def handle(self, request, stream, sequence):
"""Return m3u8 playlist.""" """Return m3u8 playlist."""
track = stream.add_provider(HLS_PROVIDER) track = stream.add_provider(HLS_PROVIDER)
stream.start() stream.start()
# Wait for a segment to be ready # Make sure at least two segments are ready (last one may not be complete)
if not track.sequences and not await track.recv(): if not track.sequences and not await track.recv():
return web.HTTPNotFound() return web.HTTPNotFound()
if len(track.sequences) == 1 and not await track.recv():
return web.HTTPNotFound()
headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]} headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]}
response = web.Response( response = web.Response(
body=self.render(track).encode("utf-8"), headers=headers body=self.render(track).encode("utf-8"), headers=headers
@ -170,7 +178,7 @@ class HlsSegmentView(StreamView):
return web.HTTPNotFound() return web.HTTPNotFound()
headers = {"Content-Type": "video/iso.segment"} headers = {"Content-Type": "video/iso.segment"}
return web.Response( return web.Response(
body=segment.moof_data, body=segment.get_bytes_without_init(),
headers=headers, headers=headers,
) )

View File

@ -57,7 +57,7 @@ def recorder_save_worker(file_out: str, segments: deque[Segment]):
# Open segment # Open segment
source = av.open( source = av.open(
BytesIO(segment.init + segment.moof_data), BytesIO(segment.init + segment.get_bytes_without_init()),
"r", "r",
format=SEGMENT_CONTAINER_FORMAT, format=SEGMENT_CONTAINER_FORMAT,
) )

View File

@ -2,9 +2,12 @@
from __future__ import annotations from __future__ import annotations
from collections import deque from collections import deque
from collections.abc import Iterator, Mapping
from fractions import Fraction
from io import BytesIO from io import BytesIO
import logging import logging
from typing import cast from threading import Event
from typing import Callable, cast
import av import av
@ -17,9 +20,9 @@ from .const import (
PACKETS_TO_WAIT_FOR_AUDIO, PACKETS_TO_WAIT_FOR_AUDIO,
SEGMENT_CONTAINER_FORMAT, SEGMENT_CONTAINER_FORMAT,
SOURCE_TIMEOUT, SOURCE_TIMEOUT,
TARGET_PART_DURATION,
) )
from .core import Segment, StreamOutput from .core import Part, Segment, StreamOutput
from .fmp4utils import get_init_and_moof_data
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -27,22 +30,28 @@ _LOGGER = logging.getLogger(__name__)
class SegmentBuffer: class SegmentBuffer:
"""Buffer for writing a sequence of packets to the output as a segment.""" """Buffer for writing a sequence of packets to the output as a segment."""
def __init__(self, outputs_callback) -> None: def __init__(
self, outputs_callback: Callable[[], Mapping[str, StreamOutput]]
) -> None:
"""Initialize SegmentBuffer.""" """Initialize SegmentBuffer."""
self._stream_id = 0 self._stream_id: int = 0
self._outputs_callback = outputs_callback self._outputs_callback: Callable[
self._outputs: list[StreamOutput] = [] [], Mapping[str, StreamOutput]
] = outputs_callback
# sequence gets incremented before the first segment so the first segment # sequence gets incremented before the first segment so the first segment
# has a sequence number of 0. # has a sequence number of 0.
self._sequence = -1 self._sequence = -1
self._segment_start_pts = None self._segment_start_dts: int = cast(int, None)
self._memory_file: BytesIO = cast(BytesIO, None) self._memory_file: BytesIO = cast(BytesIO, None)
self._av_output: av.container.OutputContainer = None self._av_output: av.container.OutputContainer = None
self._input_video_stream: av.video.VideoStream = None self._input_video_stream: av.video.VideoStream = None
self._input_audio_stream = None # av.audio.AudioStream | None self._input_audio_stream = None # av.audio.AudioStream | None
self._output_video_stream: av.video.VideoStream = None self._output_video_stream: av.video.VideoStream = None
self._output_audio_stream = None # av.audio.AudioStream | None self._output_audio_stream = None # av.audio.AudioStream | None
self._segment: Segment = cast(Segment, None) self._segment: Segment | None = None
self._segment_last_write_pos: int = cast(int, None)
self._part_start_dts: int = cast(int, None)
self._part_has_keyframe = False
@staticmethod @staticmethod
def make_new_av( def make_new_av(
@ -56,10 +65,17 @@ class SegmentBuffer:
container_options={ container_options={
# Removed skip_sidx - see https://github.com/home-assistant/core/pull/39970 # Removed skip_sidx - see https://github.com/home-assistant/core/pull/39970
# "cmaf" flag replaces several of the movflags used, but too recent to use for now # "cmaf" flag replaces several of the movflags used, but too recent to use for now
"movflags": "frag_custom+empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer", "movflags": "empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer",
"avoid_negative_ts": "disabled", # Sometimes the first segment begins with negative timestamps, and this setting just
# adjusts the timestamps in the output from that segment to start from 0. Helps from
# having to make some adjustments in test_durations
"avoid_negative_ts": "make_non_negative",
"fragment_index": str(sequence + 1), "fragment_index": str(sequence + 1),
"video_track_timescale": str(int(1 / input_vstream.time_base)), "video_track_timescale": str(int(1 / input_vstream.time_base)),
# Create a fragments every TARGET_PART_DURATION. The data from each fragment is stored in
# a "Part" that can be combined with the data from all the other "Part"s, plus an init
# section, to reconstitute the data in a "Segment".
"frag_duration": str(int(TARGET_PART_DURATION * 1e6)),
}, },
) )
@ -73,15 +89,13 @@ class SegmentBuffer:
self._input_video_stream = video_stream self._input_video_stream = video_stream
self._input_audio_stream = audio_stream self._input_audio_stream = audio_stream
def reset(self, video_pts): def reset(self, video_dts: int) -> None:
"""Initialize a new stream segment.""" """Initialize a new stream segment."""
# Keep track of the number of segments we've processed # Keep track of the number of segments we've processed
self._sequence += 1 self._sequence += 1
self._segment_start_pts = video_pts self._segment_start_dts = self._part_start_dts = video_dts
self._segment = None
# Fetch the latest StreamOutputs, which may have changed since the self._segment_last_write_pos = 0
# worker started.
self._outputs = self._outputs_callback().values()
self._memory_file = BytesIO() self._memory_file = BytesIO()
self._av_output = self.make_new_av( self._av_output = self.make_new_av(
memory_file=self._memory_file, memory_file=self._memory_file,
@ -98,54 +112,102 @@ class SegmentBuffer:
template=self._input_audio_stream template=self._input_audio_stream
) )
def mux_packet(self, packet): def mux_packet(self, packet: av.Packet) -> None:
"""Mux a packet to the appropriate output stream.""" """Mux a packet to the appropriate output stream."""
# Check for end of segment # Check for end of segment
if packet.stream == self._input_video_stream and packet.is_keyframe: if packet.stream == self._input_video_stream:
duration = (packet.pts - self._segment_start_pts) * packet.time_base
if duration >= MIN_SEGMENT_DURATION:
# Save segment to outputs
self.flush(duration)
if (
packet.is_keyframe
and (
segment_duration := (packet.dts - self._segment_start_dts)
* packet.time_base
)
>= MIN_SEGMENT_DURATION
):
# Flush segment (also flushes the stub part segment)
self.flush(segment_duration, packet)
# Reinitialize # Reinitialize
self.reset(packet.pts) self.reset(packet.dts)
# Mux the packet # Mux the packet
if packet.stream == self._input_video_stream:
packet.stream = self._output_video_stream packet.stream = self._output_video_stream
self._av_output.mux(packet) self._av_output.mux(packet)
self.check_flush_part(packet)
self._part_has_keyframe |= packet.is_keyframe
elif packet.stream == self._input_audio_stream: elif packet.stream == self._input_audio_stream:
packet.stream = self._output_audio_stream packet.stream = self._output_audio_stream
self._av_output.mux(packet) self._av_output.mux(packet)
def flush(self, duration): def check_flush_part(self, packet: av.Packet) -> None:
"""Check for and mark a part segment boundary and record its duration."""
byte_position = self._memory_file.tell()
if self._segment_last_write_pos == byte_position:
return
if self._segment is None:
# We have our first non-zero byte position. This means the init has just
# been written. Create a Segment and put it to the queue of each output.
self._segment = Segment(
sequence=self._sequence,
stream_id=self._stream_id,
init=self._memory_file.getvalue(),
)
self._segment_last_write_pos = byte_position
# Fetch the latest StreamOutputs, which may have changed since the
# worker started.
for stream_output in self._outputs_callback().values():
stream_output.put(self._segment)
else: # These are the ends of the part segments
self._segment.parts.append(
Part(
duration=float(
(packet.dts - self._part_start_dts) * packet.time_base
),
has_keyframe=self._part_has_keyframe,
data=self._memory_file.getbuffer()[
self._segment_last_write_pos : byte_position
].tobytes(),
)
)
self._segment_last_write_pos = byte_position
self._part_start_dts = packet.dts
self._part_has_keyframe = False
def flush(self, duration: Fraction, packet: av.Packet) -> None:
"""Create a segment from the buffered packets and write to output.""" """Create a segment from the buffered packets and write to output."""
self._av_output.close() self._av_output.close()
segment = Segment( assert self._segment
self._sequence, self._segment.duration = float(duration)
*get_init_and_moof_data(self._memory_file.getbuffer()), # Also flush the part segment (need to close the output above before this)
duration, self._segment.parts.append(
self._stream_id, Part(
duration=float((packet.dts - self._part_start_dts) * packet.time_base),
has_keyframe=self._part_has_keyframe,
data=self._memory_file.getbuffer()[
self._segment_last_write_pos :
].tobytes(),
) )
self._memory_file.close() )
for stream_output in self._outputs: self._memory_file.close() # We don't need the BytesIO object anymore
stream_output.put(segment)
def discontinuity(self): def discontinuity(self) -> None:
"""Mark the stream as having been restarted.""" """Mark the stream as having been restarted."""
# Preserving sequence and stream_id here keep the HLS playlist logic # Preserving sequence and stream_id here keep the HLS playlist logic
# simple to check for discontinuity at output time, and to determine # simple to check for discontinuity at output time, and to determine
# the discontinuity sequence number. # the discontinuity sequence number.
self._stream_id += 1 self._stream_id += 1
def close(self): def close(self) -> None:
"""Close stream buffer.""" """Close stream buffer."""
self._av_output.close() self._av_output.close()
self._memory_file.close() self._memory_file.close()
def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901 def stream_worker( # noqa: C901
source: str, options: dict, segment_buffer: SegmentBuffer, quit_event: Event
) -> None:
"""Handle consuming streams.""" """Handle consuming streams."""
try: try:
@ -172,27 +234,27 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
audio_stream = None audio_stream = None
# Iterator for demuxing # Iterator for demuxing
container_packets = None container_packets: Iterator[av.Packet]
# The decoder timestamps of the latest packet in each stream we processed # The decoder timestamps of the latest packet in each stream we processed
last_dts = {video_stream: float("-inf"), audio_stream: float("-inf")} last_dts = {video_stream: float("-inf"), audio_stream: float("-inf")}
# Keep track of consecutive packets without a dts to detect end of stream. # Keep track of consecutive packets without a dts to detect end of stream.
missing_dts = 0 missing_dts = 0
# The video pts at the beginning of the segment # The video dts at the beginning of the segment
segment_start_pts = None segment_start_dts: int | None = None
# Because of problems 1 and 2 below, we need to store the first few packets and replay them # Because of problems 1 and 2 below, we need to store the first few packets and replay them
initial_packets = deque() initial_packets: deque[av.Packet] = deque()
# Have to work around two problems with RTSP feeds in ffmpeg # Have to work around two problems with RTSP feeds in ffmpeg
# 1 - first frame has bad pts/dts https://trac.ffmpeg.org/ticket/5018 # 1 - first frame has bad pts/dts https://trac.ffmpeg.org/ticket/5018
# 2 - seeking can be problematic https://trac.ffmpeg.org/ticket/7815 # 2 - seeking can be problematic https://trac.ffmpeg.org/ticket/7815
def peek_first_pts(): def peek_first_dts() -> bool:
"""Initialize by peeking into the first few packets of the stream. """Initialize by peeking into the first few packets of the stream.
Deal with problem #1 above (bad first packet pts/dts) by recalculating using pts/dts from second packet. Deal with problem #1 above (bad first packet pts/dts) by recalculating using pts/dts from second packet.
Also load the first video keyframe pts into segment_start_pts and check if the audio stream really exists. Also load the first video keyframe dts into segment_start_dts and check if the audio stream really exists.
""" """
nonlocal segment_start_pts, audio_stream, container_packets nonlocal segment_start_dts, audio_stream, container_packets
missing_dts = 0 missing_dts = 0
found_audio = False found_audio = False
try: try:
@ -215,8 +277,8 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
elif packet.is_keyframe: # video_keyframe elif packet.is_keyframe: # video_keyframe
first_packet = packet first_packet = packet
initial_packets.append(packet) initial_packets.append(packet)
# Get first_pts from subsequent frame to first keyframe # Get first_dts from subsequent frame to first keyframe
while segment_start_pts is None or ( while segment_start_dts is None or (
audio_stream audio_stream
and not found_audio and not found_audio
and len(initial_packets) < PACKETS_TO_WAIT_FOR_AUDIO and len(initial_packets) < PACKETS_TO_WAIT_FOR_AUDIO
@ -244,11 +306,10 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
continue continue
found_audio = True found_audio = True
elif ( elif (
segment_start_pts is None segment_start_dts is None
): # This is the second video frame to calculate first_pts from ): # This is the second video frame to calculate first_dts from
segment_start_pts = packet.dts - packet.duration segment_start_dts = packet.dts - packet.duration
first_packet.pts = segment_start_pts first_packet.pts = first_packet.dts = segment_start_dts
first_packet.dts = segment_start_pts
initial_packets.append(packet) initial_packets.append(packet)
if audio_stream and not found_audio: if audio_stream and not found_audio:
_LOGGER.warning( _LOGGER.warning(
@ -263,12 +324,13 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
return False return False
return True return True
if not peek_first_pts(): if not peek_first_dts():
container.close() container.close()
return return
segment_buffer.set_streams(video_stream, audio_stream) segment_buffer.set_streams(video_stream, audio_stream)
segment_buffer.reset(segment_start_pts) assert isinstance(segment_start_dts, int)
segment_buffer.reset(segment_start_dts)
while not quit_event.is_set(): while not quit_event.is_set():
try: try:

View File

@ -9,13 +9,21 @@ nothing for the test to verify. The solution is the WorkerSync class that
allows the tests to pause the worker thread before finalizing the stream allows the tests to pause the worker thread before finalizing the stream
so that it can inspect the output. so that it can inspect the output.
""" """
from __future__ import annotations
import asyncio
from collections import deque
import logging import logging
import threading import threading
from unittest.mock import patch from unittest.mock import patch
import async_timeout
import pytest import pytest
from homeassistant.components.stream import Stream from homeassistant.components.stream import Stream
from homeassistant.components.stream.core import Segment
TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout
class WorkerSync: class WorkerSync:
@ -58,3 +66,57 @@ def stream_worker_sync(hass):
autospec=True, autospec=True,
): ):
yield sync yield sync
class SaveRecordWorkerSync:
"""
Test fixture to manage RecordOutput thread for recorder_save_worker.
This is used to assert that the worker is started and stopped cleanly
to avoid thread leaks in tests.
"""
def __init__(self):
"""Initialize SaveRecordWorkerSync."""
self._save_event = None
self._segments = None
self._save_thread = None
self.reset()
def recorder_save_worker(self, file_out: str, segments: deque[Segment]):
"""Mock method for patch."""
logging.debug("recorder_save_worker thread started")
assert self._save_thread is None
self._segments = segments
self._save_thread = threading.current_thread()
self._save_event.set()
async def get_segments(self):
"""Return the recorded video segments."""
with async_timeout.timeout(TEST_TIMEOUT):
await self._save_event.wait()
return self._segments
async def join(self):
"""Verify save worker was invoked and block on shutdown."""
with async_timeout.timeout(TEST_TIMEOUT):
await self._save_event.wait()
self._save_thread.join(timeout=TEST_TIMEOUT)
assert not self._save_thread.is_alive()
def reset(self):
"""Reset callback state for reuse in tests."""
self._save_thread = None
self._save_event = asyncio.Event()
@pytest.fixture()
def record_worker_sync(hass):
"""Patch recorder_save_worker for clean thread shutdown for test."""
sync = SaveRecordWorkerSync()
with patch(
"homeassistant.components.stream.recorder.recorder_save_worker",
side_effect=sync.recorder_save_worker,
autospec=True,
):
yield sync

View File

@ -12,7 +12,7 @@ from homeassistant.components.stream.const import (
MAX_SEGMENTS, MAX_SEGMENTS,
NUM_PLAYLIST_SEGMENTS, NUM_PLAYLIST_SEGMENTS,
) )
from homeassistant.components.stream.core import Segment from homeassistant.components.stream.core import Part, Segment
from homeassistant.const import HTTP_NOT_FOUND from homeassistant.const import HTTP_NOT_FOUND
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -22,7 +22,7 @@ from tests.components.stream.common import generate_h264_video
STREAM_SOURCE = "some-stream-source" STREAM_SOURCE = "some-stream-source"
INIT_BYTES = b"init" INIT_BYTES = b"init"
MOOF_BYTES = b"some-bytes" FAKE_PAYLOAD = b"fake-payload"
SEGMENT_DURATION = 10 SEGMENT_DURATION = 10
TEST_TIMEOUT = 5.0 # Lower than 9s home assistant timeout TEST_TIMEOUT = 5.0 # Lower than 9s home assistant timeout
MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever
@ -70,23 +70,24 @@ def make_segment(segment, discontinuity=False):
+ "Z", + "Z",
] ]
) )
response.extend(["#EXTINF:10.0000,", f"./segment/{segment}.m4s"]), response.extend([f"#EXTINF:{SEGMENT_DURATION:.3f},", f"./segment/{segment}.m4s"])
return "\n".join(response) return "\n".join(response)
def make_playlist(sequence, discontinuity_sequence=0, segments=[]): def make_playlist(sequence, segments, discontinuity_sequence=0):
"""Create a an hls playlist response for tests to assert on.""" """Create a an hls playlist response for tests to assert on."""
response = [ response = [
"#EXTM3U", "#EXTM3U",
"#EXT-X-VERSION:6", "#EXT-X-VERSION:6",
"#EXT-X-TARGETDURATION:10", "#EXT-X-INDEPENDENT-SEGMENTS",
'#EXT-X-MAP:URI="init.mp4"', '#EXT-X-MAP:URI="init.mp4"',
"#EXT-X-TARGETDURATION:10",
f"#EXT-X-MEDIA-SEQUENCE:{sequence}", f"#EXT-X-MEDIA-SEQUENCE:{sequence}",
f"#EXT-X-DISCONTINUITY-SEQUENCE:{discontinuity_sequence}", f"#EXT-X-DISCONTINUITY-SEQUENCE:{discontinuity_sequence}",
"#EXT-X-PROGRAM-DATE-TIME:" "#EXT-X-PROGRAM-DATE-TIME:"
+ FAKE_TIME.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + FAKE_TIME.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+ "Z", + "Z",
f"#EXT-X-START:TIME-OFFSET=-{1.5*SEGMENT_DURATION:.3f},PRECISE=YES", f"#EXT-X-START:TIME-OFFSET=-{1.5*SEGMENT_DURATION:.3f}",
] ]
response.extend(segments) response.extend(segments)
response.append("") response.append("")
@ -264,21 +265,26 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
stream_worker_sync.pause() stream_worker_sync.pause()
hls = stream.add_provider(HLS_PROVIDER) hls = stream.add_provider(HLS_PROVIDER)
hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, SEGMENT_DURATION, start_time=FAKE_TIME)) for i in range(2):
segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME)
hls.put(segment)
await hass.async_block_till_done() await hass.async_block_till_done()
hls_client = await hls_stream(stream) hls_client = await hls_stream(stream)
resp = await hls_client.get("/playlist.m3u8") resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200 assert resp.status == 200
assert await resp.text() == make_playlist(sequence=1, segments=[make_segment(1)]) assert await resp.text() == make_playlist(
sequence=0, segments=[make_segment(0), make_segment(1)]
)
hls.put(Segment(2, INIT_BYTES, MOOF_BYTES, SEGMENT_DURATION, start_time=FAKE_TIME)) segment = Segment(sequence=2, duration=SEGMENT_DURATION, start_time=FAKE_TIME)
hls.put(segment)
await hass.async_block_till_done() await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8") resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200 assert resp.status == 200
assert await resp.text() == make_playlist( assert await resp.text() == make_playlist(
sequence=1, segments=[make_segment(1), make_segment(2)] sequence=0, segments=[make_segment(0), make_segment(1), make_segment(2)]
) )
stream_worker_sync.resume() stream_worker_sync.resume()
@ -296,37 +302,40 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
hls_client = await hls_stream(stream) hls_client = await hls_stream(stream)
# Produce enough segments to overfill the output buffer by one # Produce enough segments to overfill the output buffer by one
for sequence in range(1, MAX_SEGMENTS + 2): for sequence in range(MAX_SEGMENTS + 1):
hls.put( segment = Segment(
Segment( sequence=sequence, duration=SEGMENT_DURATION, start_time=FAKE_TIME
sequence,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
start_time=FAKE_TIME,
)
) )
hls.put(segment)
await hass.async_block_till_done() await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8") resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200 assert resp.status == 200
# Only NUM_PLAYLIST_SEGMENTS are returned in the playlist. # Only NUM_PLAYLIST_SEGMENTS are returned in the playlist.
start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS start = MAX_SEGMENTS + 1 - NUM_PLAYLIST_SEGMENTS
segments = [] segments = []
for sequence in range(start, MAX_SEGMENTS + 2): for sequence in range(start, MAX_SEGMENTS + 1):
segments.append(make_segment(sequence)) segments.append(make_segment(sequence))
assert await resp.text() == make_playlist( assert await resp.text() == make_playlist(sequence=start, segments=segments)
sequence=start,
segments=segments, # Fetch the actual segments with a fake byte payload
for segment in hls.get_segments():
segment.init = INIT_BYTES
segment.parts = [
Part(
duration=SEGMENT_DURATION,
has_keyframe=True,
data=FAKE_PAYLOAD,
) )
]
# The segment that fell off the buffer is not accessible # The segment that fell off the buffer is not accessible
segment_response = await hls_client.get("/segment/1.m4s") segment_response = await hls_client.get("/segment/0.m4s")
assert segment_response.status == 404 assert segment_response.status == 404
# However all segments in the buffer are accessible, even those that were not in the playlist. # However all segments in the buffer are accessible, even those that were not in the playlist.
for sequence in range(2, MAX_SEGMENTS + 2): for sequence in range(1, MAX_SEGMENTS + 1):
segment_response = await hls_client.get(f"/segment/{sequence}.m4s") segment_response = await hls_client.get(f"/segment/{sequence}.m4s")
assert segment_response.status == 200 assert segment_response.status == 200
@ -342,36 +351,21 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s
stream_worker_sync.pause() stream_worker_sync.pause()
hls = stream.add_provider(HLS_PROVIDER) hls = stream.add_provider(HLS_PROVIDER)
hls.put( segment = Segment(
Segment( sequence=0, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME
1,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
stream_id=0,
start_time=FAKE_TIME,
) )
hls.put(segment)
segment = Segment(
sequence=1, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME
) )
hls.put( hls.put(segment)
Segment( segment = Segment(
2, sequence=2,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
stream_id=0,
start_time=FAKE_TIME,
)
)
hls.put(
Segment(
3,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
stream_id=1, stream_id=1,
duration=SEGMENT_DURATION,
start_time=FAKE_TIME, start_time=FAKE_TIME,
) )
) hls.put(segment)
await hass.async_block_till_done() await hass.async_block_till_done()
hls_client = await hls_stream(stream) hls_client = await hls_stream(stream)
@ -379,11 +373,11 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s
resp = await hls_client.get("/playlist.m3u8") resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200 assert resp.status == 200
assert await resp.text() == make_playlist( assert await resp.text() == make_playlist(
sequence=1, sequence=0,
segments=[ segments=[
make_segment(0),
make_segment(1), make_segment(1),
make_segment(2), make_segment(2, discontinuity=True),
make_segment(3, discontinuity=True),
], ],
) )
@ -401,29 +395,20 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy
hls_client = await hls_stream(stream) hls_client = await hls_stream(stream)
hls.put( segment = Segment(
Segment( sequence=0, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME
1,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
stream_id=0,
start_time=FAKE_TIME,
)
) )
hls.put(segment)
# Produce enough segments to overfill the output buffer by one # Produce enough segments to overfill the output buffer by one
for sequence in range(1, MAX_SEGMENTS + 2): for sequence in range(MAX_SEGMENTS + 1):
hls.put( segment = Segment(
Segment( sequence=sequence,
sequence,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
stream_id=1, stream_id=1,
duration=SEGMENT_DURATION,
start_time=FAKE_TIME, start_time=FAKE_TIME,
) )
) hls.put(segment)
await hass.async_block_till_done() await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8") resp = await hls_client.get("/playlist.m3u8")
@ -432,9 +417,9 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy
# Only NUM_PLAYLIST_SEGMENTS are returned in the playlist causing the # Only NUM_PLAYLIST_SEGMENTS are returned in the playlist causing the
# EXT-X-DISCONTINUITY tag to be omitted and EXT-X-DISCONTINUITY-SEQUENCE # EXT-X-DISCONTINUITY tag to be omitted and EXT-X-DISCONTINUITY-SEQUENCE
# returned instead. # returned instead.
start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS start = MAX_SEGMENTS + 1 - NUM_PLAYLIST_SEGMENTS
segments = [] segments = []
for sequence in range(start, MAX_SEGMENTS + 2): for sequence in range(start, MAX_SEGMENTS + 1):
segments.append(make_segment(sequence)) segments.append(make_segment(sequence))
assert await resp.text() == make_playlist( assert await resp.text() == make_playlist(
sequence=start, sequence=start,

View File

@ -1,23 +1,16 @@
"""The tests for hls streams.""" """The tests for hls streams."""
from __future__ import annotations
import asyncio
from collections import deque
from datetime import timedelta from datetime import timedelta
from io import BytesIO from io import BytesIO
import logging
import os import os
import threading
from unittest.mock import patch from unittest.mock import patch
import async_timeout
import av import av
import pytest import pytest
from homeassistant.components.stream import create_stream from homeassistant.components.stream import create_stream
from homeassistant.components.stream.const import HLS_PROVIDER, RECORDER_PROVIDER from homeassistant.components.stream.const import HLS_PROVIDER, RECORDER_PROVIDER
from homeassistant.components.stream.core import Segment from homeassistant.components.stream.core import Part, Segment
from homeassistant.components.stream.fmp4utils import get_init_and_moof_data from homeassistant.components.stream.fmp4utils import find_box
from homeassistant.components.stream.recorder import recorder_save_worker from homeassistant.components.stream.recorder import recorder_save_worker
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -26,63 +19,9 @@ import homeassistant.util.dt as dt_util
from tests.common import async_fire_time_changed from tests.common import async_fire_time_changed
from tests.components.stream.common import generate_h264_video from tests.components.stream.common import generate_h264_video
TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout
MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever
class SaveRecordWorkerSync:
"""
Test fixture to manage RecordOutput thread for recorder_save_worker.
This is used to assert that the worker is started and stopped cleanly
to avoid thread leaks in tests.
"""
def __init__(self):
"""Initialize SaveRecordWorkerSync."""
self.reset()
self._segments = None
self._save_thread = None
def recorder_save_worker(self, file_out: str, segments: deque[Segment]):
"""Mock method for patch."""
logging.debug("recorder_save_worker thread started")
assert self._save_thread is None
self._segments = segments
self._save_thread = threading.current_thread()
self._save_event.set()
async def get_segments(self):
"""Return the recorded video segments."""
with async_timeout.timeout(TEST_TIMEOUT):
await self._save_event.wait()
return self._segments
async def join(self):
"""Verify save worker was invoked and block on shutdown."""
with async_timeout.timeout(TEST_TIMEOUT):
await self._save_event.wait()
self._save_thread.join(timeout=TEST_TIMEOUT)
assert not self._save_thread.is_alive()
def reset(self):
"""Reset callback state for reuse in tests."""
self._save_thread = None
self._save_event = asyncio.Event()
@pytest.fixture()
def record_worker_sync(hass):
"""Patch recorder_save_worker for clean thread shutdown for test."""
sync = SaveRecordWorkerSync()
with patch(
"homeassistant.components.stream.recorder.recorder_save_worker",
side_effect=sync.recorder_save_worker,
autospec=True,
):
yield sync
async def test_record_stream(hass, hass_client, record_worker_sync): async def test_record_stream(hass, hass_client, record_worker_sync):
""" """
Test record stream. Test record stream.
@ -179,6 +118,21 @@ async def test_record_path_not_allowed(hass, hass_client):
await stream.async_record("/example/path") await stream.async_record("/example/path")
def add_parts_to_segment(segment, source):
"""Add relevant part data to segment for testing recorder."""
moof_locs = list(find_box(source.getbuffer(), b"moof")) + [len(source.getbuffer())]
segment.init = source.getbuffer()[: moof_locs[0]].tobytes()
segment.parts = [
Part(
duration=None,
has_keyframe=None,
http_range_start=None,
data=source.getbuffer()[moof_locs[i] : moof_locs[i + 1]],
)
for i in range(1, len(moof_locs) - 1)
]
async def test_recorder_save(tmpdir): async def test_recorder_save(tmpdir):
"""Test recorder save.""" """Test recorder save."""
# Setup # Setup
@ -186,9 +140,10 @@ async def test_recorder_save(tmpdir):
filename = f"{tmpdir}/test.mp4" filename = f"{tmpdir}/test.mp4"
# Run # Run
recorder_save_worker( segment = Segment(sequence=1)
filename, [Segment(1, *get_init_and_moof_data(source.getbuffer()), 4)] add_parts_to_segment(segment, source)
) segment.duration = 4
recorder_save_worker(filename, [segment])
# Assert # Assert
assert os.path.exists(filename) assert os.path.exists(filename)
@ -201,15 +156,13 @@ async def test_recorder_discontinuity(tmpdir):
filename = f"{tmpdir}/test.mp4" filename = f"{tmpdir}/test.mp4"
# Run # Run
init, moof_data = get_init_and_moof_data(source.getbuffer()) segment_1 = Segment(sequence=1, stream_id=0)
recorder_save_worker( add_parts_to_segment(segment_1, source)
filename, segment_1.duration = 4
[ segment_2 = Segment(sequence=2, stream_id=1)
Segment(1, init, moof_data, 4, 0), add_parts_to_segment(segment_2, source)
Segment(2, init, moof_data, 4, 1), segment_2.duration = 4
], recorder_save_worker(filename, [segment_1, segment_2])
)
# Assert # Assert
assert os.path.exists(filename) assert os.path.exists(filename)
@ -263,7 +216,9 @@ async def test_record_stream_audio(
stream_worker_sync.resume() stream_worker_sync.resume()
result = av.open( result = av.open(
BytesIO(last_segment.init + last_segment.moof_data), "r", format="mp4" BytesIO(last_segment.init + last_segment.get_bytes_without_init()),
"r",
format="mp4",
) )
assert len(result.streams.audio) == expected_audio_streams assert len(result.streams.audio) == expected_audio_streams

View File

@ -21,7 +21,7 @@ from unittest.mock import patch
import av import av
from homeassistant.components.stream import Stream from homeassistant.components.stream import Stream, create_stream
from homeassistant.components.stream.const import ( from homeassistant.components.stream.const import (
HLS_PROVIDER, HLS_PROVIDER,
MAX_MISSING_DTS, MAX_MISSING_DTS,
@ -29,6 +29,9 @@ from homeassistant.components.stream.const import (
TARGET_SEGMENT_DURATION, TARGET_SEGMENT_DURATION,
) )
from homeassistant.components.stream.worker import SegmentBuffer, stream_worker from homeassistant.components.stream.worker import SegmentBuffer, stream_worker
from homeassistant.setup import async_setup_component
from tests.components.stream.common import generate_h264_video
STREAM_SOURCE = "some-stream-source" STREAM_SOURCE = "some-stream-source"
# Formats here are arbitrary, not exercised by tests # Formats here are arbitrary, not exercised by tests
@ -99,9 +102,9 @@ class PacketSequence:
super().__init__(3) super().__init__(3)
time_base = fractions.Fraction(1, VIDEO_FRAME_RATE) time_base = fractions.Fraction(1, VIDEO_FRAME_RATE)
dts = self.packet * PACKET_DURATION / time_base dts = int(self.packet * PACKET_DURATION / time_base)
pts = self.packet * PACKET_DURATION / time_base pts = int(self.packet * PACKET_DURATION / time_base)
duration = PACKET_DURATION / time_base duration = int(PACKET_DURATION / time_base)
stream = VIDEO_STREAM stream = VIDEO_STREAM
# Pretend we get 1 keyframe every second # Pretend we get 1 keyframe every second
is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL) is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL)
@ -177,6 +180,11 @@ class FakePyAvBuffer:
"""Capture the output segment for tests to inspect.""" """Capture the output segment for tests to inspect."""
self.segments.append(segment) self.segments.append(segment)
@property
def complete_segments(self):
"""Return only the complete segments."""
return [segment for segment in self.segments if segment.complete]
class MockPyAv: class MockPyAv:
"""Mocks out av.open.""" """Mocks out av.open."""
@ -197,6 +205,19 @@ class MockPyAv:
return self.container return self.container
class MockFlushPart:
"""Class to hold a wrapper function for check_flush_part."""
# Wrap this method with a preceding write so the BytesIO pointer moves
check_flush_part = SegmentBuffer.check_flush_part
@classmethod
def wrapped_check_flush_part(cls, segment_buffer, packet):
"""Wrap check_flush_part to also advance the memory_file pointer."""
segment_buffer._memory_file.write(b"0")
return cls.check_flush_part(segment_buffer, packet)
async def async_decode_stream(hass, packets, py_av=None): async def async_decode_stream(hass, packets, py_av=None):
"""Start a stream worker that decodes incoming stream packets into output segments.""" """Start a stream worker that decodes incoming stream packets into output segments."""
stream = Stream(hass, STREAM_SOURCE) stream = Stream(hass, STREAM_SOURCE)
@ -209,6 +230,10 @@ async def async_decode_stream(hass, packets, py_av=None):
with patch("av.open", new=py_av.open), patch( with patch("av.open", new=py_av.open), patch(
"homeassistant.components.stream.core.StreamOutput.put", "homeassistant.components.stream.core.StreamOutput.put",
side_effect=py_av.capture_buffer.capture_output_segment, side_effect=py_av.capture_buffer.capture_output_segment,
), patch(
"homeassistant.components.stream.worker.SegmentBuffer.check_flush_part",
side_effect=MockFlushPart.wrapped_check_flush_part,
autospec=True,
): ):
segment_buffer = SegmentBuffer(stream.outputs) segment_buffer = SegmentBuffer(stream.outputs)
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
@ -235,13 +260,16 @@ async def test_stream_worker_success(hass):
hass, PacketSequence(TEST_SEQUENCE_LENGTH) hass, PacketSequence(TEST_SEQUENCE_LENGTH)
) )
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check number of segments. A segment is only formed when a packet from the next # Check number of segments. A segment is only formed when a packet from the next
# segment arrives, hence the subtraction of one from the sequence length. # segment arrives, hence the subtraction of one from the sequence length.
assert len(segments) == int((TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int(
(TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET
)
# Check sequence numbers # Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments))) assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations # Check segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments) assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert len(decoded_stream.video_packets) == TEST_SEQUENCE_LENGTH assert len(decoded_stream.video_packets) == TEST_SEQUENCE_LENGTH
assert len(decoded_stream.audio_packets) == 0 assert len(decoded_stream.audio_packets) == 0
@ -259,6 +287,7 @@ async def test_skip_out_of_order_packet(hass):
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check sequence numbers # Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments))) assert all(segments[i].sequence == i for i in range(len(segments)))
# If skipped packet would have been the first packet of a segment, the previous # If skipped packet would have been the first packet of a segment, the previous
@ -273,12 +302,14 @@ async def test_skip_out_of_order_packet(hass):
) )
del segments[longer_segment_index] del segments[longer_segment_index]
# Check number of segments # Check number of segments
assert len(segments) == int((len(packets) - 1 - 1) * SEGMENTS_PER_PACKET - 1) assert len(complete_segments) == int(
(len(packets) - 1 - 1) * SEGMENTS_PER_PACKET - 1
)
else: # Otherwise segment durations and number of segments are unaffected else: # Otherwise segment durations and number of segments are unaffected
# Check number of segments # Check number of segments
assert len(segments) == int((len(packets) - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int((len(packets) - 1) * SEGMENTS_PER_PACKET)
# Check remaining segment durations # Check remaining segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments) assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert len(decoded_stream.video_packets) == len(packets) - 1 assert len(decoded_stream.video_packets) == len(packets) - 1
assert len(decoded_stream.audio_packets) == 0 assert len(decoded_stream.audio_packets) == 0
@ -292,12 +323,15 @@ async def test_discard_old_packets(hass):
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check number of segments # Check number of segments
assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int(
(OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET
)
# Check sequence numbers # Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments))) assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations # Check segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments) assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert len(decoded_stream.video_packets) == OUT_OF_ORDER_PACKET_INDEX assert len(decoded_stream.video_packets) == OUT_OF_ORDER_PACKET_INDEX
assert len(decoded_stream.audio_packets) == 0 assert len(decoded_stream.audio_packets) == 0
@ -311,12 +345,15 @@ async def test_packet_overflow(hass):
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check number of segments # Check number of segments
assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int(
(OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET
)
# Check sequence numbers # Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments))) assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations # Check segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments) assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert len(decoded_stream.video_packets) == OUT_OF_ORDER_PACKET_INDEX assert len(decoded_stream.video_packets) == OUT_OF_ORDER_PACKET_INDEX
assert len(decoded_stream.audio_packets) == 0 assert len(decoded_stream.audio_packets) == 0
@ -332,10 +369,11 @@ async def test_skip_initial_bad_packets(hass):
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check sequence numbers # Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments))) assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations # Check segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments) assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert ( assert (
len(decoded_stream.video_packets) len(decoded_stream.video_packets)
== num_packets == num_packets
@ -344,7 +382,7 @@ async def test_skip_initial_bad_packets(hass):
* KEYFRAME_INTERVAL * KEYFRAME_INTERVAL
) )
# Check number of segments # Check number of segments
assert len(segments) == int( assert len(complete_segments) == int(
(len(decoded_stream.video_packets) - 1) * SEGMENTS_PER_PACKET (len(decoded_stream.video_packets) - 1) * SEGMENTS_PER_PACKET
) )
assert len(decoded_stream.audio_packets) == 0 assert len(decoded_stream.audio_packets) == 0
@ -381,13 +419,11 @@ async def test_skip_missing_dts(hass):
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check sequence numbers # Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments))) assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations (not counting the last segment) # Check segment durations (not counting the last segment)
assert ( assert sum(segment.duration for segment in complete_segments) >= len(segments) - 1
sum([segments[i].duration == SEGMENT_DURATION for i in range(len(segments))])
>= len(segments) - 1
)
assert len(decoded_stream.video_packets) == num_packets - num_bad_packets assert len(decoded_stream.video_packets) == num_packets - num_bad_packets
assert len(decoded_stream.audio_packets) == 0 assert len(decoded_stream.audio_packets) == 0
@ -403,8 +439,8 @@ async def test_too_many_bad_packets(hass):
packets[i].dts = None packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments complete_segments = decoded_stream.complete_segments
assert len(segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == bad_packet_start assert len(decoded_stream.video_packets) == bad_packet_start
assert len(decoded_stream.audio_packets) == 0 assert len(decoded_stream.audio_packets) == 0
@ -431,8 +467,8 @@ async def test_audio_packets_not_found(hass):
packets = PacketSequence(num_packets) # Contains only video packets packets = PacketSequence(num_packets) # Contains only video packets
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
segments = decoded_stream.segments complete_segments = decoded_stream.complete_segments
assert len(segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == num_packets assert len(decoded_stream.video_packets) == num_packets
assert len(decoded_stream.audio_packets) == 0 assert len(decoded_stream.audio_packets) == 0
@ -444,8 +480,8 @@ async def test_adts_aac_audio(hass):
num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1 num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
packets = list(PacketSequence(num_packets)) packets = list(PacketSequence(num_packets))
packets[1].stream = AUDIO_STREAM packets[1].stream = AUDIO_STREAM
packets[1].dts = packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[1].pts = packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
# The following is packet data is a sign of ADTS AAC # The following is packet data is a sign of ADTS AAC
packets[1][0] = 255 packets[1][0] = 255
packets[1][1] = 241 packets[1][1] = 241
@ -462,17 +498,17 @@ async def test_audio_is_first_packet(hass):
packets = list(PacketSequence(num_packets)) packets = list(PacketSequence(num_packets))
# Pair up an audio packet for each video packet # Pair up an audio packet for each video packet
packets[0].stream = AUDIO_STREAM packets[0].stream = AUDIO_STREAM
packets[0].dts = packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE packets[0].dts = int(packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[0].pts = packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE packets[0].pts = int(packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[1].is_keyframe = True # Move the video keyframe from packet 0 to packet 1 packets[1].is_keyframe = True # Move the video keyframe from packet 0 to packet 1
packets[2].stream = AUDIO_STREAM packets[2].stream = AUDIO_STREAM
packets[2].dts = packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[2].pts = packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
segments = decoded_stream.segments complete_segments = decoded_stream.complete_segments
# The audio packets are segmented with the video packets # The audio packets are segmented with the video packets
assert len(segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == num_packets - 2 assert len(decoded_stream.video_packets) == num_packets - 2
assert len(decoded_stream.audio_packets) == 1 assert len(decoded_stream.audio_packets) == 1
@ -484,13 +520,13 @@ async def test_audio_packets_found(hass):
num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1 num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
packets = list(PacketSequence(num_packets)) packets = list(PacketSequence(num_packets))
packets[1].stream = AUDIO_STREAM packets[1].stream = AUDIO_STREAM
packets[1].dts = packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[1].pts = packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
segments = decoded_stream.segments complete_segments = decoded_stream.complete_segments
# The audio packet above is buffered with the video packet # The audio packet above is buffered with the video packet
assert len(segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == num_packets - 1 assert len(decoded_stream.video_packets) == num_packets - 1
assert len(decoded_stream.audio_packets) == 1 assert len(decoded_stream.audio_packets) == 1
@ -507,12 +543,15 @@ async def test_pts_out_of_order(hass):
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check number of segments # Check number of segments
assert len(segments) == int((TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int(
(TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET
)
# Check sequence numbers # Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments))) assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations # Check segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments) assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert len(decoded_stream.video_packets) == len(packets) assert len(decoded_stream.video_packets) == len(packets)
assert len(decoded_stream.audio_packets) == 0 assert len(decoded_stream.audio_packets) == 0
@ -573,7 +612,11 @@ async def test_update_stream_source(hass):
worker_wake.wait() worker_wake.wait()
return py_av.open(stream_source, args, kwargs) return py_av.open(stream_source, args, kwargs)
with patch("av.open", new=blocking_open): with patch("av.open", new=blocking_open), patch(
"homeassistant.components.stream.worker.SegmentBuffer.check_flush_part",
side_effect=MockFlushPart.wrapped_check_flush_part,
autospec=True,
):
stream.start() stream.start()
assert worker_open.wait(TIMEOUT) assert worker_open.wait(TIMEOUT)
assert last_stream_source == STREAM_SOURCE assert last_stream_source == STREAM_SOURCE
@ -604,3 +647,74 @@ async def test_worker_log(hass, caplog):
await hass.async_block_till_done() await hass.async_block_till_done()
assert "https://abcd:efgh@foo.bar" not in caplog.text assert "https://abcd:efgh@foo.bar" not in caplog.text
assert "https://****:****@foo.bar" in caplog.text assert "https://****:****@foo.bar" in caplog.text
async def test_durations(hass, record_worker_sync):
"""Test that the duration metadata matches the media."""
await async_setup_component(hass, "stream", {"stream": {}})
source = generate_h264_video()
stream = create_stream(hass, source)
# use record_worker_sync to grab output segments
with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path")
complete_segments = list(await record_worker_sync.get_segments())[:-1]
assert len(complete_segments) >= 1
# check that the Part duration metadata matches the durations in the media
running_metadata_duration = 0
for segment in complete_segments:
for part in segment.parts:
av_part = av.open(io.BytesIO(segment.init + part.data))
running_metadata_duration += part.duration
# av_part.duration will just return the largest dts in av_part.
# When we normalize by av.time_base this should equal the running duration
assert math.isclose(
running_metadata_duration,
av_part.duration / av.time_base,
abs_tol=1e-6,
)
av_part.close()
# check that the Part durations are consistent with the Segment durations
for segment in complete_segments:
assert math.isclose(
sum(part.duration for part in segment.parts), segment.duration, abs_tol=1e-6
)
await record_worker_sync.join()
stream.stop()
async def test_has_keyframe(hass, record_worker_sync):
"""Test that the has_keyframe metadata matches the media."""
await async_setup_component(hass, "stream", {"stream": {}})
source = generate_h264_video()
stream = create_stream(hass, source)
# use record_worker_sync to grab output segments
with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path")
# Our test video has keyframes every second. Use smaller parts so we have more
# part boundaries to better test keyframe logic.
with patch("homeassistant.components.stream.worker.TARGET_PART_DURATION", 0.25):
complete_segments = list(await record_worker_sync.get_segments())[:-1]
assert len(complete_segments) >= 1
# check that the Part has_keyframe metadata matches the keyframes in the media
for segment in complete_segments:
for part in segment.parts:
av_part = av.open(io.BytesIO(segment.init + part.data))
media_has_keyframe = any(
packet.is_keyframe for packet in av_part.demux(av_part.streams.video[0])
)
av_part.close()
assert part.has_keyframe == media_has_keyframe
await record_worker_sync.join()
stream.stop()