mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 06:07:17 +00:00
Clean up stream refactor (#51951)
* Clean up target_duration method * Consolidate Part creation in one place * Use BytesIO.read instead of memoryview access * Change flush() signature
This commit is contained in:
parent
3836d46dff
commit
e8b5790846
@ -14,7 +14,7 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
|||||||
from homeassistant.helpers.event import async_call_later
|
from homeassistant.helpers.event import async_call_later
|
||||||
from homeassistant.util.decorator import Registry
|
from homeassistant.util.decorator import Registry
|
||||||
|
|
||||||
from .const import ATTR_STREAMS, DOMAIN
|
from .const import ATTR_STREAMS, DOMAIN, TARGET_SEGMENT_DURATION
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from . import Stream
|
from . import Stream
|
||||||
@ -28,6 +28,7 @@ class Part:
|
|||||||
|
|
||||||
duration: float = attr.ib()
|
duration: float = attr.ib()
|
||||||
has_keyframe: bool = attr.ib()
|
has_keyframe: bool = attr.ib()
|
||||||
|
# video data (moof+mdat)
|
||||||
data: bytes = attr.ib()
|
data: bytes = attr.ib()
|
||||||
|
|
||||||
|
|
||||||
@ -50,7 +51,7 @@ class Segment:
|
|||||||
return self.duration > 0
|
return self.duration > 0
|
||||||
|
|
||||||
def get_bytes_without_init(self) -> bytes:
|
def get_bytes_without_init(self) -> bytes:
|
||||||
"""Return reconstructed data for entire segment as bytes."""
|
"""Return reconstructed data for all parts as bytes, without init."""
|
||||||
return b"".join([part.data for part in self.parts])
|
return b"".join([part.data for part in self.parts])
|
||||||
|
|
||||||
|
|
||||||
@ -141,17 +142,16 @@ class StreamOutput:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def target_duration(self) -> int:
|
def target_duration(self) -> float:
|
||||||
"""Return the max duration of any given segment in seconds."""
|
"""Return the max duration of any given segment in seconds."""
|
||||||
segment_length = len(self._segments)
|
if not (durations := [s.duration for s in self._segments if s.complete]):
|
||||||
if not segment_length:
|
return TARGET_SEGMENT_DURATION
|
||||||
return 1
|
return max(durations)
|
||||||
durations = [s.duration for s in self._segments]
|
|
||||||
return round(max(durations)) or 1
|
|
||||||
|
|
||||||
def get_segment(self, sequence: int) -> Segment | None:
|
def get_segment(self, sequence: int) -> Segment | None:
|
||||||
"""Retrieve a specific segment."""
|
"""Retrieve a specific segment."""
|
||||||
for segment in self._segments:
|
# Most hits will come in the most recent segments, so iterate reversed
|
||||||
|
for segment in reversed(self._segments):
|
||||||
if segment.sequence == sequence:
|
if segment.sequence == sequence:
|
||||||
return segment
|
return segment
|
||||||
return None
|
return None
|
||||||
|
@ -5,7 +5,7 @@ from collections.abc import Generator
|
|||||||
|
|
||||||
|
|
||||||
def find_box(
|
def find_box(
|
||||||
mp4_bytes: bytes | memoryview, target_type: bytes, box_start: int = 0
|
mp4_bytes: bytes, target_type: bytes, box_start: int = 0
|
||||||
) -> Generator[int, None, None]:
|
) -> Generator[int, None, None]:
|
||||||
"""Find location of first box (or sub_box if box_start provided) of given type."""
|
"""Find location of first box (or sub_box if box_start provided) of given type."""
|
||||||
if box_start == 0:
|
if box_start == 0:
|
||||||
|
@ -173,8 +173,9 @@ class HlsInitView(StreamView):
|
|||||||
track = stream.add_provider(HLS_PROVIDER)
|
track = stream.add_provider(HLS_PROVIDER)
|
||||||
if not (segments := track.get_segments()):
|
if not (segments := track.get_segments()):
|
||||||
return web.HTTPNotFound()
|
return web.HTTPNotFound()
|
||||||
headers = {"Content-Type": "video/mp4"}
|
return web.Response(
|
||||||
return web.Response(body=segments[0].init, headers=headers)
|
body=segments[0].init, headers={"Content-Type": "video/mp4"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HlsSegmentView(StreamView):
|
class HlsSegmentView(StreamView):
|
||||||
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Iterator, Mapping
|
from collections.abc import Iterator, Mapping
|
||||||
from fractions import Fraction
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import logging
|
import logging
|
||||||
from threading import Event
|
from threading import Event
|
||||||
@ -49,7 +48,8 @@ class SegmentBuffer:
|
|||||||
self._output_video_stream: av.video.VideoStream = None
|
self._output_video_stream: av.video.VideoStream = None
|
||||||
self._output_audio_stream: Any | None = None # av.audio.AudioStream | None
|
self._output_audio_stream: Any | None = None # av.audio.AudioStream | None
|
||||||
self._segment: Segment | None = None
|
self._segment: Segment | None = None
|
||||||
self._segment_last_write_pos: int = cast(int, None)
|
# the following 3 member variables are used for Part formation
|
||||||
|
self._memory_file_pos: int = cast(int, None)
|
||||||
self._part_start_dts: int = cast(int, None)
|
self._part_start_dts: int = cast(int, None)
|
||||||
self._part_has_keyframe = False
|
self._part_has_keyframe = False
|
||||||
|
|
||||||
@ -93,10 +93,10 @@ class SegmentBuffer:
|
|||||||
"""Initialize a new stream segment."""
|
"""Initialize a new stream segment."""
|
||||||
# Keep track of the number of segments we've processed
|
# Keep track of the number of segments we've processed
|
||||||
self._sequence += 1
|
self._sequence += 1
|
||||||
self._segment_start_dts = self._part_start_dts = video_dts
|
self._segment_start_dts = video_dts
|
||||||
self._segment = None
|
self._segment = None
|
||||||
self._segment_last_write_pos = 0
|
|
||||||
self._memory_file = BytesIO()
|
self._memory_file = BytesIO()
|
||||||
|
self._memory_file_pos = 0
|
||||||
self._av_output = self.make_new_av(
|
self._av_output = self.make_new_av(
|
||||||
memory_file=self._memory_file,
|
memory_file=self._memory_file,
|
||||||
sequence=self._sequence,
|
sequence=self._sequence,
|
||||||
@ -120,14 +120,11 @@ class SegmentBuffer:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
packet.is_keyframe
|
packet.is_keyframe
|
||||||
and (
|
and (packet.dts - self._segment_start_dts) * packet.time_base
|
||||||
segment_duration := (packet.dts - self._segment_start_dts)
|
|
||||||
* packet.time_base
|
|
||||||
)
|
|
||||||
>= MIN_SEGMENT_DURATION
|
>= MIN_SEGMENT_DURATION
|
||||||
):
|
):
|
||||||
# Flush segment (also flushes the stub part segment)
|
# Flush segment (also flushes the stub part segment)
|
||||||
self.flush(segment_duration, packet)
|
self.flush(packet, last_part=True)
|
||||||
# Reinitialize
|
# Reinitialize
|
||||||
self.reset(packet.dts)
|
self.reset(packet.dts)
|
||||||
|
|
||||||
@ -143,8 +140,7 @@ class SegmentBuffer:
|
|||||||
|
|
||||||
def check_flush_part(self, packet: av.Packet) -> None:
|
def check_flush_part(self, packet: av.Packet) -> None:
|
||||||
"""Check for and mark a part segment boundary and record its duration."""
|
"""Check for and mark a part segment boundary and record its duration."""
|
||||||
byte_position = self._memory_file.tell()
|
if self._memory_file_pos == self._memory_file.tell():
|
||||||
if self._segment_last_write_pos == byte_position:
|
|
||||||
return
|
return
|
||||||
if self._segment is None:
|
if self._segment is None:
|
||||||
# We have our first non-zero byte position. This means the init has just
|
# We have our first non-zero byte position. This means the init has just
|
||||||
@ -154,43 +150,43 @@ class SegmentBuffer:
|
|||||||
stream_id=self._stream_id,
|
stream_id=self._stream_id,
|
||||||
init=self._memory_file.getvalue(),
|
init=self._memory_file.getvalue(),
|
||||||
)
|
)
|
||||||
self._segment_last_write_pos = byte_position
|
self._memory_file_pos = self._memory_file.tell()
|
||||||
|
self._part_start_dts = self._segment_start_dts
|
||||||
# Fetch the latest StreamOutputs, which may have changed since the
|
# Fetch the latest StreamOutputs, which may have changed since the
|
||||||
# worker started.
|
# worker started.
|
||||||
for stream_output in self._outputs_callback().values():
|
for stream_output in self._outputs_callback().values():
|
||||||
stream_output.put(self._segment)
|
stream_output.put(self._segment)
|
||||||
else: # These are the ends of the part segments
|
else: # These are the ends of the part segments
|
||||||
self._segment.parts.append(
|
self.flush(packet, last_part=False)
|
||||||
Part(
|
|
||||||
duration=float(
|
|
||||||
(packet.dts - self._part_start_dts) * packet.time_base
|
|
||||||
),
|
|
||||||
has_keyframe=self._part_has_keyframe,
|
|
||||||
data=self._memory_file.getbuffer()[
|
|
||||||
self._segment_last_write_pos : byte_position
|
|
||||||
].tobytes(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self._segment_last_write_pos = byte_position
|
|
||||||
self._part_start_dts = packet.dts
|
|
||||||
self._part_has_keyframe = False
|
|
||||||
|
|
||||||
def flush(self, duration: Fraction, packet: av.Packet) -> None:
|
def flush(self, packet: av.Packet, last_part: bool) -> None:
|
||||||
"""Create a segment from the buffered packets and write to output."""
|
"""Output a part from the most recent bytes in the memory_file.
|
||||||
|
|
||||||
|
If last_part is True, also close the segment, give it a duration,
|
||||||
|
and clean up the av_output and memory_file.
|
||||||
|
"""
|
||||||
|
if last_part:
|
||||||
|
# Closing the av_output will write the remaining buffered data to the
|
||||||
|
# memory_file as a new moof/mdat.
|
||||||
self._av_output.close()
|
self._av_output.close()
|
||||||
assert self._segment
|
assert self._segment
|
||||||
self._segment.duration = float(duration)
|
self._memory_file.seek(self._memory_file_pos)
|
||||||
# Also flush the part segment (need to close the output above before this)
|
|
||||||
self._segment.parts.append(
|
self._segment.parts.append(
|
||||||
Part(
|
Part(
|
||||||
duration=float((packet.dts - self._part_start_dts) * packet.time_base),
|
duration=float((packet.dts - self._part_start_dts) * packet.time_base),
|
||||||
has_keyframe=self._part_has_keyframe,
|
has_keyframe=self._part_has_keyframe,
|
||||||
data=self._memory_file.getbuffer()[
|
data=self._memory_file.read(),
|
||||||
self._segment_last_write_pos :
|
|
||||||
].tobytes(),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if last_part:
|
||||||
|
self._segment.duration = float(
|
||||||
|
(packet.dts - self._segment_start_dts) * packet.time_base
|
||||||
|
)
|
||||||
self._memory_file.close() # We don't need the BytesIO object anymore
|
self._memory_file.close() # We don't need the BytesIO object anymore
|
||||||
|
else:
|
||||||
|
self._memory_file_pos = self._memory_file.tell()
|
||||||
|
self._part_start_dts = packet.dts
|
||||||
|
self._part_has_keyframe = False
|
||||||
|
|
||||||
def discontinuity(self) -> None:
|
def discontinuity(self) -> None:
|
||||||
"""Mark the stream as having been restarted."""
|
"""Mark the stream as having been restarted."""
|
||||||
|
@ -264,7 +264,6 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
|
|||||||
stream = create_stream(hass, STREAM_SOURCE, {})
|
stream = create_stream(hass, STREAM_SOURCE, {})
|
||||||
stream_worker_sync.pause()
|
stream_worker_sync.pause()
|
||||||
hls = stream.add_provider(HLS_PROVIDER)
|
hls = stream.add_provider(HLS_PROVIDER)
|
||||||
|
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME)
|
segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME)
|
||||||
hls.put(segment)
|
hls.put(segment)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user