Refactor stream to use bytes (#51066)

* Refactor stream to use bytes
This commit is contained in:
uvjustin 2021-05-26 16:19:09 +08:00 committed by GitHub
parent 58586d5e1f
commit c6f108f7c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 91 additions and 75 deletions

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
import io
from typing import Callable from typing import Callable
from aiohttp import web from aiohttp import web
@ -19,12 +18,15 @@ from .const import ATTR_STREAMS, DOMAIN
PROVIDERS = Registry() PROVIDERS = Registry()
@attr.s @attr.s(slots=True)
class Segment: class Segment:
"""Represent a segment.""" """Represent a segment."""
sequence: int = attr.ib() sequence: int = attr.ib()
segment: io.BytesIO = attr.ib() # the init of the mp4
init: bytes = attr.ib()
# the video data (moof + mdat)s of the mp4
moof_data: bytes = attr.ib()
duration: float = attr.ib() duration: float = attr.ib()
# For detecting discontinuities across stream restarts # For detecting discontinuities across stream restarts
stream_id: int = attr.ib(default=0) stream_id: int = attr.ib(default=0)

View File

@ -2,67 +2,59 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Generator from collections.abc import Generator
import io
def find_box( def find_box(
segment: io.BytesIO, target_type: bytes, box_start: int = 0 mp4_bytes: bytes | memoryview, target_type: bytes, box_start: int = 0
) -> Generator[int, None, None]: ) -> Generator[int, None, None]:
"""Find location of first box (or sub_box if box_start provided) of given type.""" """Find location of first box (or sub_box if box_start provided) of given type."""
if box_start == 0: if box_start == 0:
box_end = segment.seek(0, io.SEEK_END)
segment.seek(0)
index = 0 index = 0
box_end = len(mp4_bytes)
else: else:
segment.seek(box_start) box_end = box_start + int.from_bytes(
box_end = box_start + int.from_bytes(segment.read(4), byteorder="big") mp4_bytes[box_start : box_start + 4], byteorder="big"
)
index = box_start + 8 index = box_start + 8
while 1: while 1:
if index > box_end - 8: # End of box, not found if index > box_end - 8: # End of box, not found
break break
segment.seek(index) box_header = mp4_bytes[index : index + 8]
box_header = segment.read(8)
if box_header[4:8] == target_type: if box_header[4:8] == target_type:
yield index yield index
segment.seek(index)
index += int.from_bytes(box_header[0:4], byteorder="big") index += int.from_bytes(box_header[0:4], byteorder="big")
def get_init(segment: io.BytesIO) -> bytes: def get_init_and_moof_data(segment: memoryview) -> tuple[bytes, bytes]:
"""Get init section from fragmented mp4.""" """Get the init and moof data from a segment."""
moof_location = next(find_box(segment, b"moof")) moof_location = next(find_box(segment, b"moof"), 0)
segment.seek(0) mfra_location = next(find_box(segment, b"mfra"), len(segment))
return segment.read(moof_location) return (
segment[:moof_location].tobytes(),
segment[moof_location:mfra_location].tobytes(),
)
def get_m4s(segment: io.BytesIO, sequence: int) -> bytes: def get_codec_string(mp4_bytes: bytes) -> str:
"""Get m4s section from fragmented mp4."""
moof_location = next(find_box(segment, b"moof"))
mfra_location = next(find_box(segment, b"mfra"))
segment.seek(moof_location)
return segment.read(mfra_location - moof_location)
def get_codec_string(segment: io.BytesIO) -> str:
"""Get RFC 6381 codec string.""" """Get RFC 6381 codec string."""
codecs = [] codecs = []
# Find moov # Find moov
moov_location = next(find_box(segment, b"moov")) moov_location = next(find_box(mp4_bytes, b"moov"))
# Find tracks # Find tracks
for trak_location in find_box(segment, b"trak", moov_location): for trak_location in find_box(mp4_bytes, b"trak", moov_location):
# Drill down to media info # Drill down to media info
mdia_location = next(find_box(segment, b"mdia", trak_location)) mdia_location = next(find_box(mp4_bytes, b"mdia", trak_location))
minf_location = next(find_box(segment, b"minf", mdia_location)) minf_location = next(find_box(mp4_bytes, b"minf", mdia_location))
stbl_location = next(find_box(segment, b"stbl", minf_location)) stbl_location = next(find_box(mp4_bytes, b"stbl", minf_location))
stsd_location = next(find_box(segment, b"stsd", stbl_location)) stsd_location = next(find_box(mp4_bytes, b"stsd", stbl_location))
# Get stsd box # Get stsd box
segment.seek(stsd_location) stsd_length = int.from_bytes(
stsd_length = int.from_bytes(segment.read(4), byteorder="big") mp4_bytes[stsd_location : stsd_location + 4], byteorder="big"
segment.seek(stsd_location) )
stsd_box = segment.read(stsd_length) stsd_box = mp4_bytes[stsd_location : stsd_location + stsd_length]
# Base Codec # Base Codec
codec = stsd_box[20:24].decode("utf-8") codec = stsd_box[20:24].decode("utf-8")

View File

@ -1,13 +1,11 @@
"""Provide functionality to stream HLS.""" """Provide functionality to stream HLS."""
import io
from aiohttp import web from aiohttp import web
from homeassistant.core import callback from homeassistant.core import callback
from .const import FORMAT_CONTENT_TYPE, MAX_SEGMENTS, NUM_PLAYLIST_SEGMENTS from .const import FORMAT_CONTENT_TYPE, MAX_SEGMENTS, NUM_PLAYLIST_SEGMENTS
from .core import PROVIDERS, HomeAssistant, IdleTimer, StreamOutput, StreamView from .core import PROVIDERS, HomeAssistant, IdleTimer, StreamOutput, StreamView
from .fmp4utils import get_codec_string, get_init, get_m4s from .fmp4utils import get_codec_string
@callback @callback
@ -35,9 +33,9 @@ class HlsMasterPlaylistView(StreamView):
# hls spec already allows for 25% variation # hls spec already allows for 25% variation
segment = track.get_segment(track.segments[-1]) segment = track.get_segment(track.segments[-1])
bandwidth = round( bandwidth = round(
segment.segment.seek(0, io.SEEK_END) * 8 / segment.duration * 1.2 (len(segment.init) + len(segment.moof_data)) * 8 / segment.duration * 1.2
) )
codecs = get_codec_string(segment.segment) codecs = get_codec_string(segment.init)
lines = [ lines = [
"#EXTM3U", "#EXTM3U",
f'#EXT-X-STREAM-INF:BANDWIDTH={bandwidth},CODECS="{codecs}"', f'#EXT-X-STREAM-INF:BANDWIDTH={bandwidth},CODECS="{codecs}"',
@ -129,7 +127,7 @@ class HlsInitView(StreamView):
if not segments: if not segments:
return web.HTTPNotFound() return web.HTTPNotFound()
headers = {"Content-Type": "video/mp4"} headers = {"Content-Type": "video/mp4"}
return web.Response(body=get_init(segments[0].segment), headers=headers) return web.Response(body=segments[0].init, headers=headers)
class HlsSegmentView(StreamView): class HlsSegmentView(StreamView):
@ -147,7 +145,7 @@ class HlsSegmentView(StreamView):
return web.HTTPNotFound() return web.HTTPNotFound()
headers = {"Content-Type": "video/iso.segment"} headers = {"Content-Type": "video/iso.segment"}
return web.Response( return web.Response(
body=get_m4s(segment.segment, int(sequence)), body=segment.moof_data,
headers=headers, headers=headers,
) )

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections import deque from collections import deque
from io import BytesIO
import logging import logging
import os import os
import threading import threading
@ -51,7 +52,11 @@ def recorder_save_worker(file_out: str, segments: deque[Segment]):
last_sequence = segment.sequence last_sequence = segment.sequence
# Open segment # Open segment
source = av.open(segment.segment, "r", format=SEGMENT_CONTAINER_FORMAT) source = av.open(
BytesIO(segment.init + segment.moof_data),
"r",
format=SEGMENT_CONTAINER_FORMAT,
)
source_v = source.streams.video[0] source_v = source.streams.video[0]
source_a = source.streams.audio[0] if len(source.streams.audio) > 0 else None source_a = source.streams.audio[0] if len(source.streams.audio) > 0 else None

View File

@ -19,6 +19,7 @@ from .const import (
STREAM_TIMEOUT, STREAM_TIMEOUT,
) )
from .core import Segment, StreamOutput from .core import Segment, StreamOutput
from .fmp4utils import get_init_and_moof_data
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -29,8 +30,6 @@ class SegmentBuffer:
def __init__(self, outputs_callback) -> None: def __init__(self, outputs_callback) -> None:
"""Initialize SegmentBuffer.""" """Initialize SegmentBuffer."""
self._stream_id = 0 self._stream_id = 0
self._video_stream = None
self._audio_stream = None
self._outputs_callback = outputs_callback self._outputs_callback = outputs_callback
self._outputs: list[StreamOutput] = [] self._outputs: list[StreamOutput] = []
self._sequence = 0 self._sequence = 0
@ -41,10 +40,11 @@ class SegmentBuffer:
self._input_audio_stream = None # av.audio.AudioStream | None self._input_audio_stream = None # av.audio.AudioStream | None
self._output_video_stream: av.video.VideoStream = None self._output_video_stream: av.video.VideoStream = None
self._output_audio_stream = None # av.audio.AudioStream | None self._output_audio_stream = None # av.audio.AudioStream | None
self._segment: Segment = cast(Segment, None)
@staticmethod @staticmethod
def make_new_av( def make_new_av(
memory_file, sequence: int, input_vstream: av.video.VideoStream memory_file: BytesIO, sequence: int, input_vstream: av.video.VideoStream
) -> av.container.OutputContainer: ) -> av.container.OutputContainer:
"""Make a new av OutputContainer.""" """Make a new av OutputContainer."""
return av.open( return av.open(
@ -120,7 +120,13 @@ class SegmentBuffer:
def flush(self, duration): def flush(self, duration):
"""Create a segment from the buffered packets and write to output.""" """Create a segment from the buffered packets and write to output."""
self._av_output.close() self._av_output.close()
segment = Segment(self._sequence, self._memory_file, duration, self._stream_id) segment = Segment(
self._sequence,
*get_init_and_moof_data(self._memory_file.getbuffer()),
duration,
self._stream_id,
)
self._memory_file.close()
for stream_output in self._outputs: for stream_output in self._outputs:
stream_output.put(segment) stream_output.put(segment)
@ -134,6 +140,7 @@ class SegmentBuffer:
def close(self): def close(self):
"""Close stream buffer.""" """Close stream buffer."""
self._av_output.close() self._av_output.close()
self._memory_file.close()
def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901 def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901

View File

@ -1,6 +1,5 @@
"""The tests for hls streams.""" """The tests for hls streams."""
from datetime import timedelta from datetime import timedelta
import io
from unittest.mock import patch from unittest.mock import patch
from urllib.parse import urlparse from urllib.parse import urlparse
@ -18,7 +17,8 @@ from tests.common import async_fire_time_changed
from tests.components.stream.common import generate_h264_video from tests.components.stream.common import generate_h264_video
STREAM_SOURCE = "some-stream-source" STREAM_SOURCE = "some-stream-source"
SEQUENCE_BYTES = io.BytesIO(b"some-bytes") INIT_BYTES = b"init"
MOOF_BYTES = b"some-bytes"
DURATION = 10 DURATION = 10
TEST_TIMEOUT = 5.0 # Lower than 9s home assistant timeout TEST_TIMEOUT = 5.0 # Lower than 9s home assistant timeout
MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever
@ -248,7 +248,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
stream_worker_sync.pause() stream_worker_sync.pause()
hls = stream.add_provider("hls") hls = stream.add_provider("hls")
hls.put(Segment(1, SEQUENCE_BYTES, DURATION)) hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, DURATION))
await hass.async_block_till_done() await hass.async_block_till_done()
hls_client = await hls_stream(stream) hls_client = await hls_stream(stream)
@ -257,7 +257,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
assert resp.status == 200 assert resp.status == 200
assert await resp.text() == make_playlist(sequence=1, segments=[make_segment(1)]) assert await resp.text() == make_playlist(sequence=1, segments=[make_segment(1)])
hls.put(Segment(2, SEQUENCE_BYTES, DURATION)) hls.put(Segment(2, INIT_BYTES, MOOF_BYTES, DURATION))
await hass.async_block_till_done() await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8") resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200 assert resp.status == 200
@ -281,7 +281,7 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
# Produce enough segments to overfill the output buffer by one # Produce enough segments to overfill the output buffer by one
for sequence in range(1, MAX_SEGMENTS + 2): for sequence in range(1, MAX_SEGMENTS + 2):
hls.put(Segment(sequence, SEQUENCE_BYTES, DURATION)) hls.put(Segment(sequence, INIT_BYTES, MOOF_BYTES, DURATION))
await hass.async_block_till_done() await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8") resp = await hls_client.get("/playlist.m3u8")
@ -297,18 +297,14 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
segments=segments, segments=segments,
) )
# Fetch the actual segments with a fake byte payload # The segment that fell off the buffer is not accessible
with patch( segment_response = await hls_client.get("/segment/1.m4s")
"homeassistant.components.stream.hls.get_m4s", return_value=b"fake-payload" assert segment_response.status == 404
):
# The segment that fell off the buffer is not accessible
segment_response = await hls_client.get("/segment/1.m4s")
assert segment_response.status == 404
# However all segments in the buffer are accessible, even those that were not in the playlist. # However all segments in the buffer are accessible, even those that were not in the playlist.
for sequence in range(2, MAX_SEGMENTS + 2): for sequence in range(2, MAX_SEGMENTS + 2):
segment_response = await hls_client.get(f"/segment/{sequence}.m4s") segment_response = await hls_client.get(f"/segment/{sequence}.m4s")
assert segment_response.status == 200 assert segment_response.status == 200
stream_worker_sync.resume() stream_worker_sync.resume()
stream.stop() stream.stop()
@ -322,9 +318,9 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s
stream_worker_sync.pause() stream_worker_sync.pause()
hls = stream.add_provider("hls") hls = stream.add_provider("hls")
hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0)) hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=0))
hls.put(Segment(2, SEQUENCE_BYTES, DURATION, stream_id=0)) hls.put(Segment(2, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=0))
hls.put(Segment(3, SEQUENCE_BYTES, DURATION, stream_id=1)) hls.put(Segment(3, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=1))
await hass.async_block_till_done() await hass.async_block_till_done()
hls_client = await hls_stream(stream) hls_client = await hls_stream(stream)
@ -354,11 +350,11 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy
hls_client = await hls_stream(stream) hls_client = await hls_stream(stream)
hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0)) hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=0))
# Produce enough segments to overfill the output buffer by one # Produce enough segments to overfill the output buffer by one
for sequence in range(1, MAX_SEGMENTS + 2): for sequence in range(1, MAX_SEGMENTS + 2):
hls.put(Segment(sequence, SEQUENCE_BYTES, DURATION, stream_id=1)) hls.put(Segment(sequence, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=1))
await hass.async_block_till_done() await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8") resp = await hls_client.get("/playlist.m3u8")

View File

@ -1,10 +1,13 @@
"""The tests for hls streams.""" """The tests for hls streams."""
from __future__ import annotations
import asyncio import asyncio
from collections import deque
from datetime import timedelta from datetime import timedelta
from io import BytesIO
import logging import logging
import os import os
import threading import threading
from typing import Deque
from unittest.mock import patch from unittest.mock import patch
import async_timeout import async_timeout
@ -13,6 +16,7 @@ import pytest
from homeassistant.components.stream import create_stream from homeassistant.components.stream import create_stream
from homeassistant.components.stream.core import Segment from homeassistant.components.stream.core import Segment
from homeassistant.components.stream.fmp4utils import get_init_and_moof_data
from homeassistant.components.stream.recorder import recorder_save_worker from homeassistant.components.stream.recorder import recorder_save_worker
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -37,8 +41,9 @@ class SaveRecordWorkerSync:
"""Initialize SaveRecordWorkerSync.""" """Initialize SaveRecordWorkerSync."""
self.reset() self.reset()
self._segments = None self._segments = None
self._save_thread = None
def recorder_save_worker(self, file_out: str, segments: Deque[Segment]): def recorder_save_worker(self, file_out: str, segments: deque[Segment]):
"""Mock method for patch.""" """Mock method for patch."""
logging.debug("recorder_save_worker thread started") logging.debug("recorder_save_worker thread started")
assert self._save_thread is None assert self._save_thread is None
@ -180,7 +185,9 @@ async def test_recorder_save(tmpdir):
filename = f"{tmpdir}/test.mp4" filename = f"{tmpdir}/test.mp4"
# Run # Run
recorder_save_worker(filename, [Segment(1, source, 4)]) recorder_save_worker(
filename, [Segment(1, *get_init_and_moof_data(source.getbuffer()), 4)]
)
# Assert # Assert
assert os.path.exists(filename) assert os.path.exists(filename)
@ -193,13 +200,20 @@ async def test_recorder_discontinuity(tmpdir):
filename = f"{tmpdir}/test.mp4" filename = f"{tmpdir}/test.mp4"
# Run # Run
recorder_save_worker(filename, [Segment(1, source, 4, 0), Segment(2, source, 4, 1)]) init, moof_data = get_init_and_moof_data(source.getbuffer())
recorder_save_worker(
filename,
[
Segment(1, init, moof_data, 4, 0),
Segment(2, init, moof_data, 4, 1),
],
)
# Assert # Assert
assert os.path.exists(filename) assert os.path.exists(filename)
async def test_recorder_no_segements(tmpdir): async def test_recorder_no_segments(tmpdir):
"""Test recorder behavior with a stream failure which causes no segments.""" """Test recorder behavior with a stream failure which causes no segments."""
# Setup # Setup
filename = f"{tmpdir}/test.mp4" filename = f"{tmpdir}/test.mp4"
@ -247,7 +261,9 @@ async def test_record_stream_audio(
last_segment = segment last_segment = segment
stream_worker_sync.resume() stream_worker_sync.resume()
result = av.open(last_segment.segment, "r", format="mp4") result = av.open(
BytesIO(last_segment.init + last_segment.moof_data), "r", format="mp4"
)
assert len(result.streams.audio) == expected_audio_streams assert len(result.streams.audio) == expected_audio_streams
result.close() result.close()