diff --git a/homeassistant/components/stream/core.py b/homeassistant/components/stream/core.py index 0d29474858f..695f1d05ac3 100644 --- a/homeassistant/components/stream/core.py +++ b/homeassistant/components/stream/core.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio from collections import deque -import io from typing import Callable from aiohttp import web @@ -19,12 +18,15 @@ from .const import ATTR_STREAMS, DOMAIN PROVIDERS = Registry() -@attr.s +@attr.s(slots=True) class Segment: """Represent a segment.""" sequence: int = attr.ib() - segment: io.BytesIO = attr.ib() + # the init of the mp4 + init: bytes = attr.ib() + # the video data (moof + mddat)s of the mp4 + moof_data: bytes = attr.ib() duration: float = attr.ib() # For detecting discontinuities across stream restarts stream_id: int = attr.ib(default=0) diff --git a/homeassistant/components/stream/fmp4utils.py b/homeassistant/components/stream/fmp4utils.py index ad5b100ce77..511bbc0939a 100644 --- a/homeassistant/components/stream/fmp4utils.py +++ b/homeassistant/components/stream/fmp4utils.py @@ -2,67 +2,59 @@ from __future__ import annotations from collections.abc import Generator -import io def find_box( - segment: io.BytesIO, target_type: bytes, box_start: int = 0 + mp4_bytes: bytes | memoryview, target_type: bytes, box_start: int = 0 ) -> Generator[int, None, None]: """Find location of first box (or sub_box if box_start provided) of given type.""" if box_start == 0: - box_end = segment.seek(0, io.SEEK_END) - segment.seek(0) index = 0 + box_end = len(mp4_bytes) else: - segment.seek(box_start) - box_end = box_start + int.from_bytes(segment.read(4), byteorder="big") + box_end = box_start + int.from_bytes( + mp4_bytes[box_start : box_start + 4], byteorder="big" + ) index = box_start + 8 while 1: if index > box_end - 8: # End of box, not found break - segment.seek(index) - box_header = segment.read(8) + box_header = mp4_bytes[index : index + 8] if box_header[4:8] == target_type: yield index - segment.seek(index) index += int.from_bytes(box_header[0:4], byteorder="big") -def get_init(segment: io.BytesIO) -> bytes: - """Get init section from fragmented mp4.""" - moof_location = next(find_box(segment, b"moof")) - segment.seek(0) - return segment.read(moof_location) +def get_init_and_moof_data(segment: memoryview) -> tuple[bytes, bytes]: + """Get the init and moof data from a segment.""" + moof_location = next(find_box(segment, b"moof"), 0) + mfra_location = next(find_box(segment, b"mfra"), len(segment)) + return ( + segment[:moof_location].tobytes(), + segment[moof_location:mfra_location].tobytes(), + ) -def get_m4s(segment: io.BytesIO, sequence: int) -> bytes: - """Get m4s section from fragmented mp4.""" - moof_location = next(find_box(segment, b"moof")) - mfra_location = next(find_box(segment, b"mfra")) - segment.seek(moof_location) - return segment.read(mfra_location - moof_location) - - -def get_codec_string(segment: io.BytesIO) -> str: +def get_codec_string(mp4_bytes: bytes) -> str: """Get RFC 6381 codec string.""" codecs = [] # Find moov - moov_location = next(find_box(segment, b"moov")) + moov_location = next(find_box(mp4_bytes, b"moov")) # Find tracks - for trak_location in find_box(segment, b"trak", moov_location): + for trak_location in find_box(mp4_bytes, b"trak", moov_location): # Drill down to media info - mdia_location = next(find_box(segment, b"mdia", trak_location)) - minf_location = next(find_box(segment, b"minf", mdia_location)) - stbl_location = next(find_box(segment, b"stbl", minf_location)) - stsd_location = next(find_box(segment, b"stsd", stbl_location)) + mdia_location = next(find_box(mp4_bytes, b"mdia", trak_location)) + minf_location = next(find_box(mp4_bytes, b"minf", mdia_location)) + stbl_location = next(find_box(mp4_bytes, b"stbl", minf_location)) + stsd_location = next(find_box(mp4_bytes, b"stsd", stbl_location)) # Get stsd box - segment.seek(stsd_location) - stsd_length = int.from_bytes(segment.read(4), byteorder="big") - segment.seek(stsd_location) - stsd_box = segment.read(stsd_length) + stsd_length = int.from_bytes( + mp4_bytes[stsd_location : stsd_location + 4], byteorder="big" + ) + stsd_box = mp4_bytes[stsd_location : stsd_location + stsd_length] # Base Codec codec = stsd_box[20:24].decode("utf-8") diff --git a/homeassistant/components/stream/hls.py b/homeassistant/components/stream/hls.py index 42f7f2dbfa3..941f4407423 100644 --- a/homeassistant/components/stream/hls.py +++ b/homeassistant/components/stream/hls.py @@ -1,13 +1,11 @@ """Provide functionality to stream HLS.""" -import io - from aiohttp import web from homeassistant.core import callback from .const import FORMAT_CONTENT_TYPE, MAX_SEGMENTS, NUM_PLAYLIST_SEGMENTS from .core import PROVIDERS, HomeAssistant, IdleTimer, StreamOutput, StreamView -from .fmp4utils import get_codec_string, get_init, get_m4s +from .fmp4utils import get_codec_string @callback @@ -35,9 +33,9 @@ class HlsMasterPlaylistView(StreamView): # hls spec already allows for 25% variation segment = track.get_segment(track.segments[-1]) bandwidth = round( - segment.segment.seek(0, io.SEEK_END) * 8 / segment.duration * 1.2 + (len(segment.init) + len(segment.moof_data)) * 8 / segment.duration * 1.2 ) - codecs = get_codec_string(segment.segment) + codecs = get_codec_string(segment.init) lines = [ "#EXTM3U", f'#EXT-X-STREAM-INF:BANDWIDTH={bandwidth},CODECS="{codecs}"', @@ -129,7 +127,7 @@ class HlsInitView(StreamView): if not segments: return web.HTTPNotFound() headers = {"Content-Type": "video/mp4"} - return web.Response(body=get_init(segments[0].segment), headers=headers) + return web.Response(body=segments[0].init, headers=headers) class HlsSegmentView(StreamView): @@ -147,7 +145,7 @@ class HlsSegmentView(StreamView): return web.HTTPNotFound() headers = {"Content-Type": "video/iso.segment"} return web.Response( - body=get_m4s(segment.segment, int(sequence)), + body=segment.moof_data, headers=headers, ) diff --git a/homeassistant/components/stream/recorder.py b/homeassistant/components/stream/recorder.py index 085a6448597..7d849375ece 100644 --- a/homeassistant/components/stream/recorder.py +++ b/homeassistant/components/stream/recorder.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections import deque +from io import BytesIO import logging import os import threading @@ -51,7 +52,11 @@ def recorder_save_worker(file_out: str, segments: deque[Segment]): last_sequence = segment.sequence # Open segment - source = av.open(segment.segment, "r", format=SEGMENT_CONTAINER_FORMAT) + source = av.open( + BytesIO(segment.init + segment.moof_data), + "r", + format=SEGMENT_CONTAINER_FORMAT, + ) source_v = source.streams.video[0] source_a = source.streams.audio[0] if len(source.streams.audio) > 0 else None diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index d6562cf93db..cb6d6a6a017 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -19,6 +19,7 @@ from .const import ( STREAM_TIMEOUT, ) from .core import Segment, StreamOutput +from .fmp4utils import get_init_and_moof_data _LOGGER = logging.getLogger(__name__) @@ -29,8 +30,6 @@ class SegmentBuffer: def __init__(self, outputs_callback) -> None: """Initialize SegmentBuffer.""" self._stream_id = 0 - self._video_stream = None - self._audio_stream = None self._outputs_callback = outputs_callback self._outputs: list[StreamOutput] = [] self._sequence = 0 @@ -41,10 +40,11 @@ class SegmentBuffer: self._input_audio_stream = None # av.audio.AudioStream | None self._output_video_stream: av.video.VideoStream = None self._output_audio_stream = None # av.audio.AudioStream | None + self._segment: Segment = cast(Segment, None) @staticmethod def make_new_av( - memory_file, sequence: int, input_vstream: av.video.VideoStream + memory_file: BytesIO, sequence: int, input_vstream: av.video.VideoStream ) -> av.container.OutputContainer: """Make a new av OutputContainer.""" return av.open( @@ -120,7 +120,13 @@ class SegmentBuffer: def flush(self, duration): """Create a segment from the buffered packets and write to output.""" self._av_output.close() - segment = Segment(self._sequence, self._memory_file, duration, self._stream_id) + segment = Segment( + self._sequence, + *get_init_and_moof_data(self._memory_file.getbuffer()), + duration, + self._stream_id, + ) + self._memory_file.close() for stream_output in self._outputs: stream_output.put(segment) @@ -134,6 +140,7 @@ class SegmentBuffer: def close(self): """Close stream buffer.""" self._av_output.close() + self._memory_file.close() def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901 diff --git a/tests/components/stream/test_hls.py b/tests/components/stream/test_hls.py index ab0c21efdfb..f9b96a662d9 100644 --- a/tests/components/stream/test_hls.py +++ b/tests/components/stream/test_hls.py @@ -1,6 +1,5 @@ """The tests for hls streams.""" from datetime import timedelta -import io from unittest.mock import patch from urllib.parse import urlparse @@ -18,7 +17,8 @@ from tests.common import async_fire_time_changed from tests.components.stream.common import generate_h264_video STREAM_SOURCE = "some-stream-source" -SEQUENCE_BYTES = io.BytesIO(b"some-bytes") +INIT_BYTES = b"init" +MOOF_BYTES = b"some-bytes" DURATION = 10 TEST_TIMEOUT = 5.0 # Lower than 9s home assistant timeout MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever @@ -248,7 +248,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync): stream_worker_sync.pause() hls = stream.add_provider("hls") - hls.put(Segment(1, SEQUENCE_BYTES, DURATION)) + hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, DURATION)) await hass.async_block_till_done() hls_client = await hls_stream(stream) @@ -257,7 +257,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync): assert resp.status == 200 assert await resp.text() == make_playlist(sequence=1, segments=[make_segment(1)]) - hls.put(Segment(2, SEQUENCE_BYTES, DURATION)) + hls.put(Segment(2, INIT_BYTES, MOOF_BYTES, DURATION)) await hass.async_block_till_done() resp = await hls_client.get("/playlist.m3u8") assert resp.status == 200 @@ -281,7 +281,7 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync): # Produce enough segments to overfill the output buffer by one for sequence in range(1, MAX_SEGMENTS + 2): - hls.put(Segment(sequence, SEQUENCE_BYTES, DURATION)) + hls.put(Segment(sequence, INIT_BYTES, MOOF_BYTES, DURATION)) await hass.async_block_till_done() resp = await hls_client.get("/playlist.m3u8") @@ -297,18 +297,14 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync): segments=segments, ) - # Fetch the actual segments with a fake byte payload - with patch( - "homeassistant.components.stream.hls.get_m4s", return_value=b"fake-payload" - ): - # The segment that fell off the buffer is not accessible - segment_response = await hls_client.get("/segment/1.m4s") - assert segment_response.status == 404 + # The segment that fell off the buffer is not accessible + segment_response = await hls_client.get("/segment/1.m4s") + assert segment_response.status == 404 - # However all segments in the buffer are accessible, even those that were not in the playlist. - for sequence in range(2, MAX_SEGMENTS + 2): - segment_response = await hls_client.get(f"/segment/{sequence}.m4s") - assert segment_response.status == 200 + # However all segments in the buffer are accessible, even those that were not in the playlist. + for sequence in range(2, MAX_SEGMENTS + 2): + segment_response = await hls_client.get(f"/segment/{sequence}.m4s") + assert segment_response.status == 200 stream_worker_sync.resume() stream.stop() @@ -322,9 +318,9 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s stream_worker_sync.pause() hls = stream.add_provider("hls") - hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0)) - hls.put(Segment(2, SEQUENCE_BYTES, DURATION, stream_id=0)) - hls.put(Segment(3, SEQUENCE_BYTES, DURATION, stream_id=1)) + hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=0)) + hls.put(Segment(2, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=0)) + hls.put(Segment(3, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=1)) await hass.async_block_till_done() hls_client = await hls_stream(stream) @@ -354,11 +350,11 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy hls_client = await hls_stream(stream) - hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0)) + hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=0)) # Produce enough segments to overfill the output buffer by one for sequence in range(1, MAX_SEGMENTS + 2): - hls.put(Segment(sequence, SEQUENCE_BYTES, DURATION, stream_id=1)) + hls.put(Segment(sequence, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=1)) await hass.async_block_till_done() resp = await hls_client.get("/playlist.m3u8") diff --git a/tests/components/stream/test_recorder.py b/tests/components/stream/test_recorder.py index 5ee055754b9..9097d03a7a9 100644 --- a/tests/components/stream/test_recorder.py +++ b/tests/components/stream/test_recorder.py @@ -1,10 +1,13 @@ """The tests for hls streams.""" +from __future__ import annotations + import asyncio +from collections import deque from datetime import timedelta +from io import BytesIO import logging import os import threading -from typing import Deque from unittest.mock import patch import async_timeout @@ -13,6 +16,7 @@ import pytest from homeassistant.components.stream import create_stream from homeassistant.components.stream.core import Segment +from homeassistant.components.stream.fmp4utils import get_init_and_moof_data from homeassistant.components.stream.recorder import recorder_save_worker from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import async_setup_component @@ -37,8 +41,9 @@ class SaveRecordWorkerSync: """Initialize SaveRecordWorkerSync.""" self.reset() self._segments = None + self._save_thread = None - def recorder_save_worker(self, file_out: str, segments: Deque[Segment]): + def recorder_save_worker(self, file_out: str, segments: deque[Segment]): """Mock method for patch.""" logging.debug("recorder_save_worker thread started") assert self._save_thread is None @@ -180,7 +185,9 @@ async def test_recorder_save(tmpdir): filename = f"{tmpdir}/test.mp4" # Run - recorder_save_worker(filename, [Segment(1, source, 4)]) + recorder_save_worker( + filename, [Segment(1, *get_init_and_moof_data(source.getbuffer()), 4)] + ) # Assert assert os.path.exists(filename) @@ -193,13 +200,20 @@ async def test_recorder_discontinuity(tmpdir): filename = f"{tmpdir}/test.mp4" # Run - recorder_save_worker(filename, [Segment(1, source, 4, 0), Segment(2, source, 4, 1)]) + init, moof_data = get_init_and_moof_data(source.getbuffer()) + recorder_save_worker( + filename, + [ + Segment(1, init, moof_data, 4, 0), + Segment(2, init, moof_data, 4, 1), + ], + ) # Assert assert os.path.exists(filename) -async def test_recorder_no_segements(tmpdir): +async def test_recorder_no_segments(tmpdir): """Test recorder behavior with a stream failure which causes no segments.""" # Setup filename = f"{tmpdir}/test.mp4" @@ -247,7 +261,9 @@ async def test_record_stream_audio( last_segment = segment stream_worker_sync.resume() - result = av.open(last_segment.segment, "r", format="mp4") + result = av.open( + BytesIO(last_segment.init + last_segment.moof_data), "r", format="mp4" + ) assert len(result.streams.audio) == expected_audio_streams result.close()