Clean up stream refactor (#51951)

* Clean up target_duration method

* Consolidate Part creation in one place

* Use BytesIO.read instead of memoryview access

* Change flush() signature
This commit is contained in:
uvjustin 2021-06-20 13:38:02 +08:00 committed by GitHub
parent 3836d46dff
commit e8b5790846
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 44 additions and 48 deletions

View File

@ -14,7 +14,7 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers.event import async_call_later
from homeassistant.util.decorator import Registry
from .const import ATTR_STREAMS, DOMAIN
from .const import ATTR_STREAMS, DOMAIN, TARGET_SEGMENT_DURATION
if TYPE_CHECKING:
from . import Stream
@ -28,6 +28,7 @@ class Part:
duration: float = attr.ib()
has_keyframe: bool = attr.ib()
# video data (moof+mdat)
data: bytes = attr.ib()
@ -50,7 +51,7 @@ class Segment:
return self.duration > 0
def get_bytes_without_init(self) -> bytes:
"""Return reconstructed data for entire segment as bytes."""
"""Return reconstructed data for all parts as bytes, without init."""
return b"".join([part.data for part in self.parts])
@ -141,17 +142,16 @@ class StreamOutput:
return None
@property
def target_duration(self) -> int:
def target_duration(self) -> float:
"""Return the max duration of any given segment in seconds."""
segment_length = len(self._segments)
if not segment_length:
return 1
durations = [s.duration for s in self._segments]
return round(max(durations)) or 1
if not (durations := [s.duration for s in self._segments if s.complete]):
return TARGET_SEGMENT_DURATION
return max(durations)
def get_segment(self, sequence: int) -> Segment | None:
"""Retrieve a specific segment."""
for segment in self._segments:
# Most hits will come in the most recent segments, so iterate reversed
for segment in reversed(self._segments):
if segment.sequence == sequence:
return segment
return None

View File

@ -5,7 +5,7 @@ from collections.abc import Generator
def find_box(
mp4_bytes: bytes | memoryview, target_type: bytes, box_start: int = 0
mp4_bytes: bytes, target_type: bytes, box_start: int = 0
) -> Generator[int, None, None]:
"""Find location of first box (or sub_box if box_start provided) of given type."""
if box_start == 0:

View File

@ -173,8 +173,9 @@ class HlsInitView(StreamView):
track = stream.add_provider(HLS_PROVIDER)
if not (segments := track.get_segments()):
return web.HTTPNotFound()
headers = {"Content-Type": "video/mp4"}
return web.Response(body=segments[0].init, headers=headers)
return web.Response(
body=segments[0].init, headers={"Content-Type": "video/mp4"}
)
class HlsSegmentView(StreamView):

View File

@ -3,7 +3,6 @@ from __future__ import annotations
from collections import deque
from collections.abc import Iterator, Mapping
from fractions import Fraction
from io import BytesIO
import logging
from threading import Event
@ -49,7 +48,8 @@ class SegmentBuffer:
self._output_video_stream: av.video.VideoStream = None
self._output_audio_stream: Any | None = None # av.audio.AudioStream | None
self._segment: Segment | None = None
self._segment_last_write_pos: int = cast(int, None)
# the following 3 member variables are used for Part formation
self._memory_file_pos: int = cast(int, None)
self._part_start_dts: int = cast(int, None)
self._part_has_keyframe = False
@ -93,10 +93,10 @@ class SegmentBuffer:
"""Initialize a new stream segment."""
# Keep track of the number of segments we've processed
self._sequence += 1
self._segment_start_dts = self._part_start_dts = video_dts
self._segment_start_dts = video_dts
self._segment = None
self._segment_last_write_pos = 0
self._memory_file = BytesIO()
self._memory_file_pos = 0
self._av_output = self.make_new_av(
memory_file=self._memory_file,
sequence=self._sequence,
@ -120,14 +120,11 @@ class SegmentBuffer:
if (
packet.is_keyframe
and (
segment_duration := (packet.dts - self._segment_start_dts)
* packet.time_base
)
and (packet.dts - self._segment_start_dts) * packet.time_base
>= MIN_SEGMENT_DURATION
):
# Flush segment (also flushes the stub part segment)
self.flush(segment_duration, packet)
self.flush(packet, last_part=True)
# Reinitialize
self.reset(packet.dts)
@ -143,8 +140,7 @@ class SegmentBuffer:
def check_flush_part(self, packet: av.Packet) -> None:
"""Check for and mark a part segment boundary and record its duration."""
byte_position = self._memory_file.tell()
if self._segment_last_write_pos == byte_position:
if self._memory_file_pos == self._memory_file.tell():
return
if self._segment is None:
# We have our first non-zero byte position. This means the init has just
@ -154,43 +150,43 @@ class SegmentBuffer:
stream_id=self._stream_id,
init=self._memory_file.getvalue(),
)
self._segment_last_write_pos = byte_position
self._memory_file_pos = self._memory_file.tell()
self._part_start_dts = self._segment_start_dts
# Fetch the latest StreamOutputs, which may have changed since the
# worker started.
for stream_output in self._outputs_callback().values():
stream_output.put(self._segment)
else: # These are the ends of the part segments
self._segment.parts.append(
Part(
duration=float(
(packet.dts - self._part_start_dts) * packet.time_base
),
has_keyframe=self._part_has_keyframe,
data=self._memory_file.getbuffer()[
self._segment_last_write_pos : byte_position
].tobytes(),
)
)
self._segment_last_write_pos = byte_position
self._part_start_dts = packet.dts
self._part_has_keyframe = False
self.flush(packet, last_part=False)
def flush(self, duration: Fraction, packet: av.Packet) -> None:
"""Create a segment from the buffered packets and write to output."""
self._av_output.close()
def flush(self, packet: av.Packet, last_part: bool) -> None:
"""Output a part from the most recent bytes in the memory_file.
If last_part is True, also close the segment, give it a duration,
and clean up the av_output and memory_file.
"""
if last_part:
# Closing the av_output will write the remaining buffered data to the
# memory_file as a new moof/mdat.
self._av_output.close()
assert self._segment
self._segment.duration = float(duration)
# Also flush the part segment (need to close the output above before this)
self._memory_file.seek(self._memory_file_pos)
self._segment.parts.append(
Part(
duration=float((packet.dts - self._part_start_dts) * packet.time_base),
has_keyframe=self._part_has_keyframe,
data=self._memory_file.getbuffer()[
self._segment_last_write_pos :
].tobytes(),
data=self._memory_file.read(),
)
)
self._memory_file.close() # We don't need the BytesIO object anymore
if last_part:
self._segment.duration = float(
(packet.dts - self._segment_start_dts) * packet.time_base
)
self._memory_file.close() # We don't need the BytesIO object anymore
else:
self._memory_file_pos = self._memory_file.tell()
self._part_start_dts = packet.dts
self._part_has_keyframe = False
def discontinuity(self) -> None:
"""Mark the stream as having been restarted."""

View File

@ -264,7 +264,6 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
stream = create_stream(hass, STREAM_SOURCE, {})
stream_worker_sync.pause()
hls = stream.add_provider(HLS_PROVIDER)
for i in range(2):
segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME)
hls.put(segment)