Clean up stream refactor (#51951)

* Clean up target_duration method

* Consolidate Part creation in one place

* Use BytesIO.read instead of memoryview access

* Change flush() signature
This commit is contained in:
uvjustin 2021-06-20 13:38:02 +08:00 committed by GitHub
parent 3836d46dff
commit e8b5790846
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 44 additions and 48 deletions

View File

@ -14,7 +14,7 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from .const import ATTR_STREAMS, DOMAIN from .const import ATTR_STREAMS, DOMAIN, TARGET_SEGMENT_DURATION
if TYPE_CHECKING: if TYPE_CHECKING:
from . import Stream from . import Stream
@ -28,6 +28,7 @@ class Part:
duration: float = attr.ib() duration: float = attr.ib()
has_keyframe: bool = attr.ib() has_keyframe: bool = attr.ib()
# video data (moof+mdat)
data: bytes = attr.ib() data: bytes = attr.ib()
@ -50,7 +51,7 @@ class Segment:
return self.duration > 0 return self.duration > 0
def get_bytes_without_init(self) -> bytes: def get_bytes_without_init(self) -> bytes:
"""Return reconstructed data for entire segment as bytes.""" """Return reconstructed data for all parts as bytes, without init."""
return b"".join([part.data for part in self.parts]) return b"".join([part.data for part in self.parts])
@ -141,17 +142,16 @@ class StreamOutput:
return None return None
@property @property
def target_duration(self) -> int: def target_duration(self) -> float:
"""Return the max duration of any given segment in seconds.""" """Return the max duration of any given segment in seconds."""
segment_length = len(self._segments) if not (durations := [s.duration for s in self._segments if s.complete]):
if not segment_length: return TARGET_SEGMENT_DURATION
return 1 return max(durations)
durations = [s.duration for s in self._segments]
return round(max(durations)) or 1
def get_segment(self, sequence: int) -> Segment | None: def get_segment(self, sequence: int) -> Segment | None:
"""Retrieve a specific segment.""" """Retrieve a specific segment."""
for segment in self._segments: # Most hits will come in the most recent segments, so iterate reversed
for segment in reversed(self._segments):
if segment.sequence == sequence: if segment.sequence == sequence:
return segment return segment
return None return None

View File

@ -5,7 +5,7 @@ from collections.abc import Generator
def find_box( def find_box(
mp4_bytes: bytes | memoryview, target_type: bytes, box_start: int = 0 mp4_bytes: bytes, target_type: bytes, box_start: int = 0
) -> Generator[int, None, None]: ) -> Generator[int, None, None]:
"""Find location of first box (or sub_box if box_start provided) of given type.""" """Find location of first box (or sub_box if box_start provided) of given type."""
if box_start == 0: if box_start == 0:

View File

@ -173,8 +173,9 @@ class HlsInitView(StreamView):
track = stream.add_provider(HLS_PROVIDER) track = stream.add_provider(HLS_PROVIDER)
if not (segments := track.get_segments()): if not (segments := track.get_segments()):
return web.HTTPNotFound() return web.HTTPNotFound()
headers = {"Content-Type": "video/mp4"} return web.Response(
return web.Response(body=segments[0].init, headers=headers) body=segments[0].init, headers={"Content-Type": "video/mp4"}
)
class HlsSegmentView(StreamView): class HlsSegmentView(StreamView):

View File

@ -3,7 +3,6 @@ from __future__ import annotations
from collections import deque from collections import deque
from collections.abc import Iterator, Mapping from collections.abc import Iterator, Mapping
from fractions import Fraction
from io import BytesIO from io import BytesIO
import logging import logging
from threading import Event from threading import Event
@ -49,7 +48,8 @@ class SegmentBuffer:
self._output_video_stream: av.video.VideoStream = None self._output_video_stream: av.video.VideoStream = None
self._output_audio_stream: Any | None = None # av.audio.AudioStream | None self._output_audio_stream: Any | None = None # av.audio.AudioStream | None
self._segment: Segment | None = None self._segment: Segment | None = None
self._segment_last_write_pos: int = cast(int, None) # the following 3 member variables are used for Part formation
self._memory_file_pos: int = cast(int, None)
self._part_start_dts: int = cast(int, None) self._part_start_dts: int = cast(int, None)
self._part_has_keyframe = False self._part_has_keyframe = False
@ -93,10 +93,10 @@ class SegmentBuffer:
"""Initialize a new stream segment.""" """Initialize a new stream segment."""
# Keep track of the number of segments we've processed # Keep track of the number of segments we've processed
self._sequence += 1 self._sequence += 1
self._segment_start_dts = self._part_start_dts = video_dts self._segment_start_dts = video_dts
self._segment = None self._segment = None
self._segment_last_write_pos = 0
self._memory_file = BytesIO() self._memory_file = BytesIO()
self._memory_file_pos = 0
self._av_output = self.make_new_av( self._av_output = self.make_new_av(
memory_file=self._memory_file, memory_file=self._memory_file,
sequence=self._sequence, sequence=self._sequence,
@ -120,14 +120,11 @@ class SegmentBuffer:
if ( if (
packet.is_keyframe packet.is_keyframe
and ( and (packet.dts - self._segment_start_dts) * packet.time_base
segment_duration := (packet.dts - self._segment_start_dts)
* packet.time_base
)
>= MIN_SEGMENT_DURATION >= MIN_SEGMENT_DURATION
): ):
# Flush segment (also flushes the stub part segment) # Flush segment (also flushes the stub part segment)
self.flush(segment_duration, packet) self.flush(packet, last_part=True)
# Reinitialize # Reinitialize
self.reset(packet.dts) self.reset(packet.dts)
@ -143,8 +140,7 @@ class SegmentBuffer:
def check_flush_part(self, packet: av.Packet) -> None: def check_flush_part(self, packet: av.Packet) -> None:
"""Check for and mark a part segment boundary and record its duration.""" """Check for and mark a part segment boundary and record its duration."""
byte_position = self._memory_file.tell() if self._memory_file_pos == self._memory_file.tell():
if self._segment_last_write_pos == byte_position:
return return
if self._segment is None: if self._segment is None:
# We have our first non-zero byte position. This means the init has just # We have our first non-zero byte position. This means the init has just
@ -154,43 +150,43 @@ class SegmentBuffer:
stream_id=self._stream_id, stream_id=self._stream_id,
init=self._memory_file.getvalue(), init=self._memory_file.getvalue(),
) )
self._segment_last_write_pos = byte_position self._memory_file_pos = self._memory_file.tell()
self._part_start_dts = self._segment_start_dts
# Fetch the latest StreamOutputs, which may have changed since the # Fetch the latest StreamOutputs, which may have changed since the
# worker started. # worker started.
for stream_output in self._outputs_callback().values(): for stream_output in self._outputs_callback().values():
stream_output.put(self._segment) stream_output.put(self._segment)
else: # These are the ends of the part segments else: # These are the ends of the part segments
self._segment.parts.append( self.flush(packet, last_part=False)
Part(
duration=float(
(packet.dts - self._part_start_dts) * packet.time_base
),
has_keyframe=self._part_has_keyframe,
data=self._memory_file.getbuffer()[
self._segment_last_write_pos : byte_position
].tobytes(),
)
)
self._segment_last_write_pos = byte_position
self._part_start_dts = packet.dts
self._part_has_keyframe = False
def flush(self, duration: Fraction, packet: av.Packet) -> None: def flush(self, packet: av.Packet, last_part: bool) -> None:
"""Create a segment from the buffered packets and write to output.""" """Output a part from the most recent bytes in the memory_file.
If last_part is True, also close the segment, give it a duration,
and clean up the av_output and memory_file.
"""
if last_part:
# Closing the av_output will write the remaining buffered data to the
# memory_file as a new moof/mdat.
self._av_output.close() self._av_output.close()
assert self._segment assert self._segment
self._segment.duration = float(duration) self._memory_file.seek(self._memory_file_pos)
# Also flush the part segment (need to close the output above before this)
self._segment.parts.append( self._segment.parts.append(
Part( Part(
duration=float((packet.dts - self._part_start_dts) * packet.time_base), duration=float((packet.dts - self._part_start_dts) * packet.time_base),
has_keyframe=self._part_has_keyframe, has_keyframe=self._part_has_keyframe,
data=self._memory_file.getbuffer()[ data=self._memory_file.read(),
self._segment_last_write_pos :
].tobytes(),
) )
) )
if last_part:
self._segment.duration = float(
(packet.dts - self._segment_start_dts) * packet.time_base
)
self._memory_file.close() # We don't need the BytesIO object anymore self._memory_file.close() # We don't need the BytesIO object anymore
else:
self._memory_file_pos = self._memory_file.tell()
self._part_start_dts = packet.dts
self._part_has_keyframe = False
def discontinuity(self) -> None: def discontinuity(self) -> None:
"""Mark the stream as having been restarted.""" """Mark the stream as having been restarted."""

View File

@ -264,7 +264,6 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
stream = create_stream(hass, STREAM_SOURCE, {}) stream = create_stream(hass, STREAM_SOURCE, {})
stream_worker_sync.pause() stream_worker_sync.pause()
hls = stream.add_provider(HLS_PROVIDER) hls = stream.add_provider(HLS_PROVIDER)
for i in range(2): for i in range(2):
segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME) segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME)
hls.put(segment) hls.put(segment)