From e8b5790846480d73aa4ce6adbf88546fe5e5931e Mon Sep 17 00:00:00 2001 From: uvjustin <46082645+uvjustin@users.noreply.github.com> Date: Sun, 20 Jun 2021 13:38:02 +0800 Subject: [PATCH] Clean up stream refactor (#51951) * Clean up target_duration method * Consolidate Part creation in one place * Use BytesIO.read instead of memoryview access * Change flush() signature --- homeassistant/components/stream/core.py | 18 +++--- homeassistant/components/stream/fmp4utils.py | 2 +- homeassistant/components/stream/hls.py | 5 +- homeassistant/components/stream/worker.py | 66 +++++++++----------- tests/components/stream/test_hls.py | 1 - 5 files changed, 44 insertions(+), 48 deletions(-) diff --git a/homeassistant/components/stream/core.py b/homeassistant/components/stream/core.py index 5f8bb736761..d840bfaf858 100644 --- a/homeassistant/components/stream/core.py +++ b/homeassistant/components/stream/core.py @@ -14,7 +14,7 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers.event import async_call_later from homeassistant.util.decorator import Registry -from .const import ATTR_STREAMS, DOMAIN +from .const import ATTR_STREAMS, DOMAIN, TARGET_SEGMENT_DURATION if TYPE_CHECKING: from . import Stream @@ -28,6 +28,7 @@ class Part: duration: float = attr.ib() has_keyframe: bool = attr.ib() + # video data (moof+mdat) data: bytes = attr.ib() @@ -50,7 +51,7 @@ class Segment: return self.duration > 0 def get_bytes_without_init(self) -> bytes: - """Return reconstructed data for entire segment as bytes.""" + """Return reconstructed data for all parts as bytes, without init.""" return b"".join([part.data for part in self.parts]) @@ -141,17 +142,16 @@ class StreamOutput: return None @property - def target_duration(self) -> int: + def target_duration(self) -> float: """Return the max duration of any given segment in seconds.""" - segment_length = len(self._segments) - if not segment_length: - return 1 - durations = [s.duration for s in self._segments] - return round(max(durations)) or 1 + if not (durations := [s.duration for s in self._segments if s.complete]): + return TARGET_SEGMENT_DURATION + return max(durations) def get_segment(self, sequence: int) -> Segment | None: """Retrieve a specific segment.""" - for segment in self._segments: + # Most hits will come in the most recent segments, so iterate reversed + for segment in reversed(self._segments): if segment.sequence == sequence: return segment return None diff --git a/homeassistant/components/stream/fmp4utils.py b/homeassistant/components/stream/fmp4utils.py index ef01158be62..f136784cf87 100644 --- a/homeassistant/components/stream/fmp4utils.py +++ b/homeassistant/components/stream/fmp4utils.py @@ -5,7 +5,7 @@ from collections.abc import Generator def find_box( - mp4_bytes: bytes | memoryview, target_type: bytes, box_start: int = 0 + mp4_bytes: bytes, target_type: bytes, box_start: int = 0 ) -> Generator[int, None, None]: """Find location of first box (or sub_box if box_start provided) of given type.""" if box_start == 0: diff --git a/homeassistant/components/stream/hls.py b/homeassistant/components/stream/hls.py index d7167e0b7de..7f11bc09655 100644 --- a/homeassistant/components/stream/hls.py +++ b/homeassistant/components/stream/hls.py @@ -173,8 +173,9 @@ class HlsInitView(StreamView): track = stream.add_provider(HLS_PROVIDER) if not (segments := track.get_segments()): return web.HTTPNotFound() - headers = {"Content-Type": "video/mp4"} - return web.Response(body=segments[0].init, headers=headers) + return web.Response( + body=segments[0].init, headers={"Content-Type": "video/mp4"} + ) class HlsSegmentView(StreamView): diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index 3023b8cd85c..04be79e668e 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -3,7 +3,6 @@ from __future__ import annotations from collections import deque from collections.abc import Iterator, Mapping -from fractions import Fraction from io import BytesIO import logging from threading import Event @@ -49,7 +48,8 @@ class SegmentBuffer: self._output_video_stream: av.video.VideoStream = None self._output_audio_stream: Any | None = None # av.audio.AudioStream | None self._segment: Segment | None = None - self._segment_last_write_pos: int = cast(int, None) + # the following 3 member variables are used for Part formation + self._memory_file_pos: int = cast(int, None) self._part_start_dts: int = cast(int, None) self._part_has_keyframe = False @@ -93,10 +93,10 @@ class SegmentBuffer: """Initialize a new stream segment.""" # Keep track of the number of segments we've processed self._sequence += 1 - self._segment_start_dts = self._part_start_dts = video_dts + self._segment_start_dts = video_dts self._segment = None - self._segment_last_write_pos = 0 self._memory_file = BytesIO() + self._memory_file_pos = 0 self._av_output = self.make_new_av( memory_file=self._memory_file, sequence=self._sequence, @@ -120,14 +120,11 @@ class SegmentBuffer: if ( packet.is_keyframe - and ( - segment_duration := (packet.dts - self._segment_start_dts) - * packet.time_base - ) + and (packet.dts - self._segment_start_dts) * packet.time_base >= MIN_SEGMENT_DURATION ): # Flush segment (also flushes the stub part segment) - self.flush(segment_duration, packet) + self.flush(packet, last_part=True) # Reinitialize self.reset(packet.dts) @@ -143,8 +140,7 @@ class SegmentBuffer: def check_flush_part(self, packet: av.Packet) -> None: """Check for and mark a part segment boundary and record its duration.""" - byte_position = self._memory_file.tell() - if self._segment_last_write_pos == byte_position: + if self._memory_file_pos == self._memory_file.tell(): return if self._segment is None: # We have our first non-zero byte position. This means the init has just @@ -154,43 +150,43 @@ class SegmentBuffer: stream_id=self._stream_id, init=self._memory_file.getvalue(), ) - self._segment_last_write_pos = byte_position + self._memory_file_pos = self._memory_file.tell() + self._part_start_dts = self._segment_start_dts # Fetch the latest StreamOutputs, which may have changed since the # worker started. for stream_output in self._outputs_callback().values(): stream_output.put(self._segment) else: # These are the ends of the part segments - self._segment.parts.append( - Part( - duration=float( - (packet.dts - self._part_start_dts) * packet.time_base - ), - has_keyframe=self._part_has_keyframe, - data=self._memory_file.getbuffer()[ - self._segment_last_write_pos : byte_position - ].tobytes(), - ) - ) - self._segment_last_write_pos = byte_position - self._part_start_dts = packet.dts - self._part_has_keyframe = False + self.flush(packet, last_part=False) - def flush(self, duration: Fraction, packet: av.Packet) -> None: - """Create a segment from the buffered packets and write to output.""" - self._av_output.close() + def flush(self, packet: av.Packet, last_part: bool) -> None: + """Output a part from the most recent bytes in the memory_file. + + If last_part is True, also close the segment, give it a duration, + and clean up the av_output and memory_file. + """ + if last_part: + # Closing the av_output will write the remaining buffered data to the + # memory_file as a new moof/mdat. + self._av_output.close() assert self._segment - self._segment.duration = float(duration) - # Also flush the part segment (need to close the output above before this) + self._memory_file.seek(self._memory_file_pos) self._segment.parts.append( Part( duration=float((packet.dts - self._part_start_dts) * packet.time_base), has_keyframe=self._part_has_keyframe, - data=self._memory_file.getbuffer()[ - self._segment_last_write_pos : - ].tobytes(), + data=self._memory_file.read(), ) ) - self._memory_file.close() # We don't need the BytesIO object anymore + if last_part: + self._segment.duration = float( + (packet.dts - self._segment_start_dts) * packet.time_base + ) + self._memory_file.close() # We don't need the BytesIO object anymore + else: + self._memory_file_pos = self._memory_file.tell() + self._part_start_dts = packet.dts + self._part_has_keyframe = False def discontinuity(self) -> None: """Mark the stream as having been restarted.""" diff --git a/tests/components/stream/test_hls.py b/tests/components/stream/test_hls.py index 89c07083b17..919f71c8509 100644 --- a/tests/components/stream/test_hls.py +++ b/tests/components/stream/test_hls.py @@ -264,7 +264,6 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync): stream = create_stream(hass, STREAM_SOURCE, {}) stream_worker_sync.pause() hls = stream.add_provider(HLS_PROVIDER) - for i in range(2): segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME) hls.put(segment)