Refactor stream to use bytes (#51066)

* Refactor stream to use bytes
This commit is contained in:
uvjustin 2021-05-26 16:19:09 +08:00 committed by GitHub
parent 58586d5e1f
commit c6f108f7c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 91 additions and 75 deletions

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
from collections import deque
import io
from typing import Callable
from aiohttp import web
@ -19,12 +18,15 @@ from .const import ATTR_STREAMS, DOMAIN
PROVIDERS = Registry()
@attr.s
@attr.s(slots=True)
class Segment:
"""Represent a segment."""
sequence: int = attr.ib()
segment: io.BytesIO = attr.ib()
# the init of the mp4
init: bytes = attr.ib()
# the video data (moof + mdat boxes) of the mp4
moof_data: bytes = attr.ib()
duration: float = attr.ib()
# For detecting discontinuities across stream restarts
stream_id: int = attr.ib(default=0)

View File

@ -2,67 +2,59 @@
from __future__ import annotations
from collections.abc import Generator
import io
def find_box(
    mp4_bytes: bytes | memoryview, target_type: bytes, box_start: int = 0
) -> Generator[int, None, None]:
    """Find location of first box (or sub_box if box_start provided) of given type.

    Walks the sequence of boxes, each introduced by an 8 byte header
    (4 byte big-endian size followed by a 4 byte type), and yields the byte
    offset of every box whose type equals ``target_type``.

    mp4_bytes: the mp4 data to scan.
    target_type: the 4 byte box type to look for, e.g. b"moof".
    box_start: 0 to scan the top level boxes of the whole buffer; otherwise
        the offset of a parent box whose direct children are scanned.
    """
    if box_start == 0:
        index = 0
        box_end = len(mp4_bytes)
    else:
        box_end = box_start + int.from_bytes(
            mp4_bytes[box_start : box_start + 4], byteorder="big"
        )
        index = box_start + 8
    # A parent may declare a size that runs past the available data
    # (truncated input); never look for headers beyond the buffer.
    box_end = min(box_end, len(mp4_bytes))
    while index <= box_end - 8:
        box_header = mp4_bytes[index : index + 8]
        if box_header[4:8] == target_type:
            yield index
        box_size = int.from_bytes(box_header[0:4], byteorder="big")
        # A valid box is at least as large as its own 8 byte header. Sizes of
        # 0 ("extends to end of file") and 1 (64-bit size follows) cannot be
        # stepped over with this header layout, and advancing by 0 would
        # loop forever, so stop the walk here.
        if box_size < 8:
            break
        index += box_size
def get_init(segment: io.BytesIO) -> bytes:
"""Get init section from fragmented mp4."""
moof_location = next(find_box(segment, b"moof"))
segment.seek(0)
return segment.read(moof_location)
def get_init_and_moof_data(segment: memoryview) -> tuple[bytes, bytes]:
    """Split a segment into its init section and its moof data.

    The init section is everything before the first moof box; the moof data
    runs from the first moof box up to the first mfra box. Both parts are
    returned as independent bytes copies of the given memoryview.
    """
    # Without a moof box the split point is 0, so the init section is empty.
    moof_start = next(find_box(segment, b"moof"), 0)
    # Without an mfra box the moof data extends to the end of the segment.
    mfra_start = next(find_box(segment, b"mfra"), len(segment))
    init = segment[:moof_start].tobytes()
    moof_data = segment[moof_start:mfra_start].tobytes()
    return init, moof_data
def get_m4s(segment: io.BytesIO, sequence: int) -> bytes:
"""Get m4s section from fragmented mp4."""
moof_location = next(find_box(segment, b"moof"))
mfra_location = next(find_box(segment, b"mfra"))
segment.seek(moof_location)
return segment.read(mfra_location - moof_location)
def get_codec_string(segment: io.BytesIO) -> str:
def get_codec_string(mp4_bytes: bytes) -> str:
"""Get RFC 6381 codec string."""
codecs = []
# Find moov
moov_location = next(find_box(segment, b"moov"))
moov_location = next(find_box(mp4_bytes, b"moov"))
# Find tracks
for trak_location in find_box(segment, b"trak", moov_location):
for trak_location in find_box(mp4_bytes, b"trak", moov_location):
# Drill down to media info
mdia_location = next(find_box(segment, b"mdia", trak_location))
minf_location = next(find_box(segment, b"minf", mdia_location))
stbl_location = next(find_box(segment, b"stbl", minf_location))
stsd_location = next(find_box(segment, b"stsd", stbl_location))
mdia_location = next(find_box(mp4_bytes, b"mdia", trak_location))
minf_location = next(find_box(mp4_bytes, b"minf", mdia_location))
stbl_location = next(find_box(mp4_bytes, b"stbl", minf_location))
stsd_location = next(find_box(mp4_bytes, b"stsd", stbl_location))
# Get stsd box
segment.seek(stsd_location)
stsd_length = int.from_bytes(segment.read(4), byteorder="big")
segment.seek(stsd_location)
stsd_box = segment.read(stsd_length)
stsd_length = int.from_bytes(
mp4_bytes[stsd_location : stsd_location + 4], byteorder="big"
)
stsd_box = mp4_bytes[stsd_location : stsd_location + stsd_length]
# Base Codec
codec = stsd_box[20:24].decode("utf-8")

View File

@ -1,13 +1,11 @@
"""Provide functionality to stream HLS."""
import io
from aiohttp import web
from homeassistant.core import callback
from .const import FORMAT_CONTENT_TYPE, MAX_SEGMENTS, NUM_PLAYLIST_SEGMENTS
from .core import PROVIDERS, HomeAssistant, IdleTimer, StreamOutput, StreamView
from .fmp4utils import get_codec_string, get_init, get_m4s
from .fmp4utils import get_codec_string
@callback
@ -35,9 +33,9 @@ class HlsMasterPlaylistView(StreamView):
# hls spec already allows for 25% variation
segment = track.get_segment(track.segments[-1])
bandwidth = round(
segment.segment.seek(0, io.SEEK_END) * 8 / segment.duration * 1.2
(len(segment.init) + len(segment.moof_data)) * 8 / segment.duration * 1.2
)
codecs = get_codec_string(segment.segment)
codecs = get_codec_string(segment.init)
lines = [
"#EXTM3U",
f'#EXT-X-STREAM-INF:BANDWIDTH={bandwidth},CODECS="{codecs}"',
@ -129,7 +127,7 @@ class HlsInitView(StreamView):
if not segments:
return web.HTTPNotFound()
headers = {"Content-Type": "video/mp4"}
return web.Response(body=get_init(segments[0].segment), headers=headers)
return web.Response(body=segments[0].init, headers=headers)
class HlsSegmentView(StreamView):
@ -147,7 +145,7 @@ class HlsSegmentView(StreamView):
return web.HTTPNotFound()
headers = {"Content-Type": "video/iso.segment"}
return web.Response(
body=get_m4s(segment.segment, int(sequence)),
body=segment.moof_data,
headers=headers,
)

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from collections import deque
from io import BytesIO
import logging
import os
import threading
@ -51,7 +52,11 @@ def recorder_save_worker(file_out: str, segments: deque[Segment]):
last_sequence = segment.sequence
# Open segment
source = av.open(segment.segment, "r", format=SEGMENT_CONTAINER_FORMAT)
source = av.open(
BytesIO(segment.init + segment.moof_data),
"r",
format=SEGMENT_CONTAINER_FORMAT,
)
source_v = source.streams.video[0]
source_a = source.streams.audio[0] if len(source.streams.audio) > 0 else None

View File

@ -19,6 +19,7 @@ from .const import (
STREAM_TIMEOUT,
)
from .core import Segment, StreamOutput
from .fmp4utils import get_init_and_moof_data
_LOGGER = logging.getLogger(__name__)
@ -29,8 +30,6 @@ class SegmentBuffer:
def __init__(self, outputs_callback) -> None:
"""Initialize SegmentBuffer."""
self._stream_id = 0
self._video_stream = None
self._audio_stream = None
self._outputs_callback = outputs_callback
self._outputs: list[StreamOutput] = []
self._sequence = 0
@ -41,10 +40,11 @@ class SegmentBuffer:
self._input_audio_stream = None # av.audio.AudioStream | None
self._output_video_stream: av.video.VideoStream = None
self._output_audio_stream = None # av.audio.AudioStream | None
self._segment: Segment = cast(Segment, None)
@staticmethod
def make_new_av(
memory_file, sequence: int, input_vstream: av.video.VideoStream
memory_file: BytesIO, sequence: int, input_vstream: av.video.VideoStream
) -> av.container.OutputContainer:
"""Make a new av OutputContainer."""
return av.open(
@ -120,7 +120,13 @@ class SegmentBuffer:
def flush(self, duration):
"""Create a segment from the buffered packets and write to output."""
self._av_output.close()
segment = Segment(self._sequence, self._memory_file, duration, self._stream_id)
segment = Segment(
self._sequence,
*get_init_and_moof_data(self._memory_file.getbuffer()),
duration,
self._stream_id,
)
self._memory_file.close()
for stream_output in self._outputs:
stream_output.put(segment)
@ -134,6 +140,7 @@ class SegmentBuffer:
def close(self):
"""Close stream buffer."""
self._av_output.close()
self._memory_file.close()
def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901

View File

@ -1,6 +1,5 @@
"""The tests for hls streams."""
from datetime import timedelta
import io
from unittest.mock import patch
from urllib.parse import urlparse
@ -18,7 +17,8 @@ from tests.common import async_fire_time_changed
from tests.components.stream.common import generate_h264_video
STREAM_SOURCE = "some-stream-source"
SEQUENCE_BYTES = io.BytesIO(b"some-bytes")
INIT_BYTES = b"init"
MOOF_BYTES = b"some-bytes"
DURATION = 10
TEST_TIMEOUT = 5.0 # Lower than 9s home assistant timeout
MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever
@ -248,7 +248,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
stream_worker_sync.pause()
hls = stream.add_provider("hls")
hls.put(Segment(1, SEQUENCE_BYTES, DURATION))
hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, DURATION))
await hass.async_block_till_done()
hls_client = await hls_stream(stream)
@ -257,7 +257,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
assert resp.status == 200
assert await resp.text() == make_playlist(sequence=1, segments=[make_segment(1)])
hls.put(Segment(2, SEQUENCE_BYTES, DURATION))
hls.put(Segment(2, INIT_BYTES, MOOF_BYTES, DURATION))
await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200
@ -281,7 +281,7 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
# Produce enough segments to overfill the output buffer by one
for sequence in range(1, MAX_SEGMENTS + 2):
hls.put(Segment(sequence, SEQUENCE_BYTES, DURATION))
hls.put(Segment(sequence, INIT_BYTES, MOOF_BYTES, DURATION))
await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8")
@ -297,10 +297,6 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
segments=segments,
)
# Fetch the actual segments with a fake byte payload
with patch(
"homeassistant.components.stream.hls.get_m4s", return_value=b"fake-payload"
):
# The segment that fell off the buffer is not accessible
segment_response = await hls_client.get("/segment/1.m4s")
assert segment_response.status == 404
@ -322,9 +318,9 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s
stream_worker_sync.pause()
hls = stream.add_provider("hls")
hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0))
hls.put(Segment(2, SEQUENCE_BYTES, DURATION, stream_id=0))
hls.put(Segment(3, SEQUENCE_BYTES, DURATION, stream_id=1))
hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=0))
hls.put(Segment(2, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=0))
hls.put(Segment(3, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=1))
await hass.async_block_till_done()
hls_client = await hls_stream(stream)
@ -354,11 +350,11 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy
hls_client = await hls_stream(stream)
hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0))
hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=0))
# Produce enough segments to overfill the output buffer by one
for sequence in range(1, MAX_SEGMENTS + 2):
hls.put(Segment(sequence, SEQUENCE_BYTES, DURATION, stream_id=1))
hls.put(Segment(sequence, INIT_BYTES, MOOF_BYTES, DURATION, stream_id=1))
await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8")

View File

@ -1,10 +1,13 @@
"""The tests for hls streams."""
from __future__ import annotations
import asyncio
from collections import deque
from datetime import timedelta
from io import BytesIO
import logging
import os
import threading
from typing import Deque
from unittest.mock import patch
import async_timeout
@ -13,6 +16,7 @@ import pytest
from homeassistant.components.stream import create_stream
from homeassistant.components.stream.core import Segment
from homeassistant.components.stream.fmp4utils import get_init_and_moof_data
from homeassistant.components.stream.recorder import recorder_save_worker
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
@ -37,8 +41,9 @@ class SaveRecordWorkerSync:
"""Initialize SaveRecordWorkerSync."""
self.reset()
self._segments = None
self._save_thread = None
def recorder_save_worker(self, file_out: str, segments: Deque[Segment]):
def recorder_save_worker(self, file_out: str, segments: deque[Segment]):
"""Mock method for patch."""
logging.debug("recorder_save_worker thread started")
assert self._save_thread is None
@ -180,7 +185,9 @@ async def test_recorder_save(tmpdir):
filename = f"{tmpdir}/test.mp4"
# Run
recorder_save_worker(filename, [Segment(1, source, 4)])
recorder_save_worker(
filename, [Segment(1, *get_init_and_moof_data(source.getbuffer()), 4)]
)
# Assert
assert os.path.exists(filename)
@ -193,13 +200,20 @@ async def test_recorder_discontinuity(tmpdir):
filename = f"{tmpdir}/test.mp4"
# Run
recorder_save_worker(filename, [Segment(1, source, 4, 0), Segment(2, source, 4, 1)])
init, moof_data = get_init_and_moof_data(source.getbuffer())
recorder_save_worker(
filename,
[
Segment(1, init, moof_data, 4, 0),
Segment(2, init, moof_data, 4, 1),
],
)
# Assert
assert os.path.exists(filename)
async def test_recorder_no_segements(tmpdir):
async def test_recorder_no_segments(tmpdir):
"""Test recorder behavior with a stream failure which causes no segments."""
# Setup
filename = f"{tmpdir}/test.mp4"
@ -247,7 +261,9 @@ async def test_record_stream_audio(
last_segment = segment
stream_worker_sync.resume()
result = av.open(last_segment.segment, "r", format="mp4")
result = av.open(
BytesIO(last_segment.init + last_segment.moof_data), "r", format="mp4"
)
assert len(result.streams.audio) == expected_audio_streams
result.close()