mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Clean up stream refactor (#51951)
* Clean up target_duration method * Consolidate Part creation in one place * Use BytesIO.read instead of memoryview access * Change flush() signature
This commit is contained in:
parent
3836d46dff
commit
e8b5790846
@ -14,7 +14,7 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
from .const import ATTR_STREAMS, DOMAIN
|
||||
from .const import ATTR_STREAMS, DOMAIN, TARGET_SEGMENT_DURATION
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import Stream
|
||||
@ -28,6 +28,7 @@ class Part:
|
||||
|
||||
duration: float = attr.ib()
|
||||
has_keyframe: bool = attr.ib()
|
||||
# video data (moof+mdat)
|
||||
data: bytes = attr.ib()
|
||||
|
||||
|
||||
@ -50,7 +51,7 @@ class Segment:
|
||||
return self.duration > 0
|
||||
|
||||
def get_bytes_without_init(self) -> bytes:
|
||||
"""Return reconstructed data for entire segment as bytes."""
|
||||
"""Return reconstructed data for all parts as bytes, without init."""
|
||||
return b"".join([part.data for part in self.parts])
|
||||
|
||||
|
||||
@ -141,17 +142,16 @@ class StreamOutput:
|
||||
return None
|
||||
|
||||
@property
|
||||
def target_duration(self) -> int:
|
||||
def target_duration(self) -> float:
|
||||
"""Return the max duration of any given segment in seconds."""
|
||||
segment_length = len(self._segments)
|
||||
if not segment_length:
|
||||
return 1
|
||||
durations = [s.duration for s in self._segments]
|
||||
return round(max(durations)) or 1
|
||||
if not (durations := [s.duration for s in self._segments if s.complete]):
|
||||
return TARGET_SEGMENT_DURATION
|
||||
return max(durations)
|
||||
|
||||
def get_segment(self, sequence: int) -> Segment | None:
|
||||
"""Retrieve a specific segment."""
|
||||
for segment in self._segments:
|
||||
# Most hits will come in the most recent segments, so iterate reversed
|
||||
for segment in reversed(self._segments):
|
||||
if segment.sequence == sequence:
|
||||
return segment
|
||||
return None
|
||||
|
@ -5,7 +5,7 @@ from collections.abc import Generator
|
||||
|
||||
|
||||
def find_box(
|
||||
mp4_bytes: bytes | memoryview, target_type: bytes, box_start: int = 0
|
||||
mp4_bytes: bytes, target_type: bytes, box_start: int = 0
|
||||
) -> Generator[int, None, None]:
|
||||
"""Find location of first box (or sub_box if box_start provided) of given type."""
|
||||
if box_start == 0:
|
||||
|
@ -173,8 +173,9 @@ class HlsInitView(StreamView):
|
||||
track = stream.add_provider(HLS_PROVIDER)
|
||||
if not (segments := track.get_segments()):
|
||||
return web.HTTPNotFound()
|
||||
headers = {"Content-Type": "video/mp4"}
|
||||
return web.Response(body=segments[0].init, headers=headers)
|
||||
return web.Response(
|
||||
body=segments[0].init, headers={"Content-Type": "video/mp4"}
|
||||
)
|
||||
|
||||
|
||||
class HlsSegmentView(StreamView):
|
||||
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Iterator, Mapping
|
||||
from fractions import Fraction
|
||||
from io import BytesIO
|
||||
import logging
|
||||
from threading import Event
|
||||
@ -49,7 +48,8 @@ class SegmentBuffer:
|
||||
self._output_video_stream: av.video.VideoStream = None
|
||||
self._output_audio_stream: Any | None = None # av.audio.AudioStream | None
|
||||
self._segment: Segment | None = None
|
||||
self._segment_last_write_pos: int = cast(int, None)
|
||||
# the following 3 member variables are used for Part formation
|
||||
self._memory_file_pos: int = cast(int, None)
|
||||
self._part_start_dts: int = cast(int, None)
|
||||
self._part_has_keyframe = False
|
||||
|
||||
@ -93,10 +93,10 @@ class SegmentBuffer:
|
||||
"""Initialize a new stream segment."""
|
||||
# Keep track of the number of segments we've processed
|
||||
self._sequence += 1
|
||||
self._segment_start_dts = self._part_start_dts = video_dts
|
||||
self._segment_start_dts = video_dts
|
||||
self._segment = None
|
||||
self._segment_last_write_pos = 0
|
||||
self._memory_file = BytesIO()
|
||||
self._memory_file_pos = 0
|
||||
self._av_output = self.make_new_av(
|
||||
memory_file=self._memory_file,
|
||||
sequence=self._sequence,
|
||||
@ -120,14 +120,11 @@ class SegmentBuffer:
|
||||
|
||||
if (
|
||||
packet.is_keyframe
|
||||
and (
|
||||
segment_duration := (packet.dts - self._segment_start_dts)
|
||||
* packet.time_base
|
||||
)
|
||||
and (packet.dts - self._segment_start_dts) * packet.time_base
|
||||
>= MIN_SEGMENT_DURATION
|
||||
):
|
||||
# Flush segment (also flushes the stub part segment)
|
||||
self.flush(segment_duration, packet)
|
||||
self.flush(packet, last_part=True)
|
||||
# Reinitialize
|
||||
self.reset(packet.dts)
|
||||
|
||||
@ -143,8 +140,7 @@ class SegmentBuffer:
|
||||
|
||||
def check_flush_part(self, packet: av.Packet) -> None:
|
||||
"""Check for and mark a part segment boundary and record its duration."""
|
||||
byte_position = self._memory_file.tell()
|
||||
if self._segment_last_write_pos == byte_position:
|
||||
if self._memory_file_pos == self._memory_file.tell():
|
||||
return
|
||||
if self._segment is None:
|
||||
# We have our first non-zero byte position. This means the init has just
|
||||
@ -154,43 +150,43 @@ class SegmentBuffer:
|
||||
stream_id=self._stream_id,
|
||||
init=self._memory_file.getvalue(),
|
||||
)
|
||||
self._segment_last_write_pos = byte_position
|
||||
self._memory_file_pos = self._memory_file.tell()
|
||||
self._part_start_dts = self._segment_start_dts
|
||||
# Fetch the latest StreamOutputs, which may have changed since the
|
||||
# worker started.
|
||||
for stream_output in self._outputs_callback().values():
|
||||
stream_output.put(self._segment)
|
||||
else: # These are the ends of the part segments
|
||||
self._segment.parts.append(
|
||||
Part(
|
||||
duration=float(
|
||||
(packet.dts - self._part_start_dts) * packet.time_base
|
||||
),
|
||||
has_keyframe=self._part_has_keyframe,
|
||||
data=self._memory_file.getbuffer()[
|
||||
self._segment_last_write_pos : byte_position
|
||||
].tobytes(),
|
||||
)
|
||||
)
|
||||
self._segment_last_write_pos = byte_position
|
||||
self._part_start_dts = packet.dts
|
||||
self._part_has_keyframe = False
|
||||
self.flush(packet, last_part=False)
|
||||
|
||||
def flush(self, duration: Fraction, packet: av.Packet) -> None:
|
||||
"""Create a segment from the buffered packets and write to output."""
|
||||
self._av_output.close()
|
||||
def flush(self, packet: av.Packet, last_part: bool) -> None:
|
||||
"""Output a part from the most recent bytes in the memory_file.
|
||||
|
||||
If last_part is True, also close the segment, give it a duration,
|
||||
and clean up the av_output and memory_file.
|
||||
"""
|
||||
if last_part:
|
||||
# Closing the av_output will write the remaining buffered data to the
|
||||
# memory_file as a new moof/mdat.
|
||||
self._av_output.close()
|
||||
assert self._segment
|
||||
self._segment.duration = float(duration)
|
||||
# Also flush the part segment (need to close the output above before this)
|
||||
self._memory_file.seek(self._memory_file_pos)
|
||||
self._segment.parts.append(
|
||||
Part(
|
||||
duration=float((packet.dts - self._part_start_dts) * packet.time_base),
|
||||
has_keyframe=self._part_has_keyframe,
|
||||
data=self._memory_file.getbuffer()[
|
||||
self._segment_last_write_pos :
|
||||
].tobytes(),
|
||||
data=self._memory_file.read(),
|
||||
)
|
||||
)
|
||||
self._memory_file.close() # We don't need the BytesIO object anymore
|
||||
if last_part:
|
||||
self._segment.duration = float(
|
||||
(packet.dts - self._segment_start_dts) * packet.time_base
|
||||
)
|
||||
self._memory_file.close() # We don't need the BytesIO object anymore
|
||||
else:
|
||||
self._memory_file_pos = self._memory_file.tell()
|
||||
self._part_start_dts = packet.dts
|
||||
self._part_has_keyframe = False
|
||||
|
||||
def discontinuity(self) -> None:
|
||||
"""Mark the stream as having been restarted."""
|
||||
|
@ -264,7 +264,6 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
|
||||
stream = create_stream(hass, STREAM_SOURCE, {})
|
||||
stream_worker_sync.pause()
|
||||
hls = stream.add_provider(HLS_PROVIDER)
|
||||
|
||||
for i in range(2):
|
||||
segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME)
|
||||
hls.put(segment)
|
||||
|
Loading…
x
Reference in New Issue
Block a user