Add discontinuity support to HLS streams and fix nest expiring stream urls (#46683)

* Support HLS stream discontinuity.

* Clarify discontinuity comments

* Signal a stream discontinuity on restart due to stream error

* Apply suggestions from code review

Co-authored-by: uvjustin <46082645+uvjustin@users.noreply.github.com>

* Simplify stream discontinuity logic
This commit is contained in:
Allen Porter 2021-02-18 04:26:02 -08:00 committed by GitHub
parent 62cfe24ed4
commit 88d143a644
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 132 additions and 29 deletions

View File

@ -170,7 +170,7 @@ class Stream:
def update_source(self, new_source):
"""Restart the stream with a new stream source."""
_LOGGER.debug("Updating stream source %s", self.source)
_LOGGER.debug("Updating stream source %s", new_source)
self.source = new_source
self._fast_restart_once = True
self._thread_quit.set()
@ -179,12 +179,14 @@ class Stream:
"""Handle consuming streams and restart keepalive streams."""
# Keep import here so that we can import stream integration without installing reqs
# pylint: disable=import-outside-toplevel
from .worker import stream_worker
from .worker import SegmentBuffer, stream_worker
segment_buffer = SegmentBuffer(self.outputs)
wait_timeout = 0
while not self._thread_quit.wait(timeout=wait_timeout):
start_time = time.time()
stream_worker(self.source, self.options, self.outputs, self._thread_quit)
stream_worker(self.source, self.options, segment_buffer, self._thread_quit)
segment_buffer.discontinuity()
if not self.keepalive or self._thread_quit.is_set():
if self._fast_restart_once:
# The stream source is updated, restart without any delay.

View File

@ -30,6 +30,8 @@ class Segment:
sequence: int = attr.ib()
segment: io.BytesIO = attr.ib()
duration: float = attr.ib()
# For detecting discontinuities across stream restarts
stream_id: int = attr.ib(default=0)
class IdleTimer:

View File

@ -78,21 +78,27 @@ class HlsPlaylistView(StreamView):
@staticmethod
def render_playlist(track):
"""Render playlist."""
segments = track.segments[-NUM_PLAYLIST_SEGMENTS:]
segments = list(track.get_segment())[-NUM_PLAYLIST_SEGMENTS:]
if not segments:
return []
playlist = ["#EXT-X-MEDIA-SEQUENCE:{}".format(segments[0])]
playlist = [
"#EXT-X-MEDIA-SEQUENCE:{}".format(segments[0].sequence),
"#EXT-X-DISCONTINUITY-SEQUENCE:{}".format(segments[0].stream_id),
]
for sequence in segments:
segment = track.get_segment(sequence)
last_stream_id = segments[0].stream_id
for segment in segments:
if last_stream_id != segment.stream_id:
playlist.append("#EXT-X-DISCONTINUITY")
playlist.extend(
[
"#EXTINF:{:.04f},".format(float(segment.duration)),
f"./segment/{segment.sequence}.m4s",
]
)
last_stream_id = segment.stream_id
return playlist

View File

@ -49,16 +49,22 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence):
class SegmentBuffer:
"""Buffer for writing a sequence of packets to the output as a segment."""
def __init__(self, video_stream, audio_stream, outputs_callback) -> None:
def __init__(self, outputs_callback) -> None:
"""Initialize SegmentBuffer."""
self._video_stream = video_stream
self._audio_stream = audio_stream
self._stream_id = 0
self._video_stream = None
self._audio_stream = None
self._outputs_callback = outputs_callback
# tuple of StreamOutput, StreamBuffer
self._outputs = []
self._sequence = 0
self._segment_start_pts = None
def set_streams(self, video_stream, audio_stream):
"""Initialize output buffer with streams from container."""
self._video_stream = video_stream
self._audio_stream = audio_stream
def reset(self, video_pts):
"""Initialize a new stream segment."""
# Keep track of the number of segments we've processed
@ -103,7 +109,16 @@ class SegmentBuffer:
"""Create a segment from the buffered packets and write to output."""
for (buffer, stream_output) in self._outputs:
buffer.output.close()
stream_output.put(Segment(self._sequence, buffer.segment, duration))
stream_output.put(
Segment(self._sequence, buffer.segment, duration, self._stream_id)
)
def discontinuity(self):
"""Mark the stream as having been restarted."""
# Preserving sequence and stream_id here keep the HLS playlist logic
# simple to check for discontinuity at output time, and to determine
# the discontinuity sequence number.
self._stream_id += 1
def close(self):
"""Close all StreamBuffers."""
@ -111,7 +126,7 @@ class SegmentBuffer:
buffer.output.close()
def stream_worker(source, options, outputs_callback, quit_event):
def stream_worker(source, options, segment_buffer, quit_event):
"""Handle consuming streams."""
try:
@ -143,8 +158,6 @@ def stream_worker(source, options, outputs_callback, quit_event):
last_dts = {video_stream: float("-inf"), audio_stream: float("-inf")}
# Keep track of consecutive packets without a dts to detect end of stream.
missing_dts = 0
# Holds the buffers for each stream provider
segment_buffer = SegmentBuffer(video_stream, audio_stream, outputs_callback)
# The video pts at the beginning of the segment
segment_start_pts = None
# Because of problems 1 and 2 below, we need to store the first few packets and replay them
@ -225,6 +238,7 @@ def stream_worker(source, options, outputs_callback, quit_event):
container.close()
return
segment_buffer.set_streams(video_stream, audio_stream)
segment_buffer.reset(segment_start_pts)
while not quit_event.is_set():

View File

@ -51,7 +51,16 @@ def hls_stream(hass, hass_client):
return create_client_for_stream
def playlist_response(sequence, segments):
def make_segment(segment, discontinuity=False):
"""Create a playlist response for a segment."""
response = []
if discontinuity:
response.append("#EXT-X-DISCONTINUITY")
response.extend(["#EXTINF:10.0000,", f"./segment/{segment}.m4s"]),
return "\n".join(response)
def make_playlist(sequence, discontinuity_sequence=0, segments=[]):
"""Create a an hls playlist response for tests to assert on."""
response = [
"#EXTM3U",
@ -59,14 +68,9 @@ def playlist_response(sequence, segments):
"#EXT-X-TARGETDURATION:10",
'#EXT-X-MAP:URI="init.mp4"',
f"#EXT-X-MEDIA-SEQUENCE:{sequence}",
f"#EXT-X-DISCONTINUITY-SEQUENCE:{discontinuity_sequence}",
]
for segment in segments:
response.extend(
[
"#EXTINF:10.0000,",
f"./segment/{segment}.m4s",
]
)
response.extend(segments)
response.append("")
return "\n".join(response)
@ -289,13 +293,15 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200
assert await resp.text() == playlist_response(sequence=1, segments=[1])
assert await resp.text() == make_playlist(sequence=1, segments=[make_segment(1)])
hls.put(Segment(2, SEQUENCE_BYTES, DURATION))
await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200
assert await resp.text() == playlist_response(sequence=1, segments=[1, 2])
assert await resp.text() == make_playlist(
sequence=1, segments=[make_segment(1), make_segment(2)]
)
stream_worker_sync.resume()
stream.stop()
@ -321,8 +327,12 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
# Only NUM_PLAYLIST_SEGMENTS are returned in the playlist.
start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS
assert await resp.text() == playlist_response(
sequence=start, segments=range(start, MAX_SEGMENTS + 2)
segments = []
for sequence in range(start, MAX_SEGMENTS + 2):
segments.append(make_segment(sequence))
assert await resp.text() == make_playlist(
sequence=start,
segments=segments,
)
# Fetch the actual segments with a fake byte payload
@ -340,3 +350,70 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
stream_worker_sync.resume()
stream.stop()
async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_sync):
"""Test a discontinuity across segments in the stream with 3 segments."""
await async_setup_component(hass, "stream", {"stream": {}})
stream = create_stream(hass, STREAM_SOURCE)
stream_worker_sync.pause()
hls = stream.hls_output()
hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0))
hls.put(Segment(2, SEQUENCE_BYTES, DURATION, stream_id=0))
hls.put(Segment(3, SEQUENCE_BYTES, DURATION, stream_id=1))
await hass.async_block_till_done()
hls_client = await hls_stream(stream)
resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200
assert await resp.text() == make_playlist(
sequence=1,
segments=[
make_segment(1),
make_segment(2),
make_segment(3, discontinuity=True),
],
)
stream_worker_sync.resume()
stream.stop()
async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sync):
"""Test a discontinuity with more segments than the segment deque can hold."""
await async_setup_component(hass, "stream", {"stream": {}})
stream = create_stream(hass, STREAM_SOURCE)
stream_worker_sync.pause()
hls = stream.hls_output()
hls_client = await hls_stream(stream)
hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0))
# Produce enough segments to overfill the output buffer by one
for sequence in range(1, MAX_SEGMENTS + 2):
hls.put(Segment(sequence, SEQUENCE_BYTES, DURATION, stream_id=1))
await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200
# Only NUM_PLAYLIST_SEGMENTS are returned in the playlist causing the
# EXT-X-DISCONTINUITY tag to be omitted and EXT-X-DISCONTINUITY-SEQUENCE
# returned instead.
start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS
segments = []
for sequence in range(start, MAX_SEGMENTS + 2):
segments.append(make_segment(sequence))
assert await resp.text() == make_playlist(
sequence=start,
discontinuity_sequence=1,
segments=segments,
)
stream_worker_sync.resume()
stream.stop()

View File

@ -27,7 +27,7 @@ from homeassistant.components.stream.const import (
MIN_SEGMENT_DURATION,
PACKETS_TO_WAIT_FOR_AUDIO,
)
from homeassistant.components.stream.worker import stream_worker
from homeassistant.components.stream.worker import SegmentBuffer, stream_worker
STREAM_SOURCE = "some-stream-source"
# Formats here are arbitrary, not exercised by tests
@ -197,7 +197,8 @@ async def async_decode_stream(hass, packets, py_av=None):
"homeassistant.components.stream.core.StreamOutput.put",
side_effect=py_av.capture_buffer.capture_output_segment,
):
stream_worker(STREAM_SOURCE, {}, stream.outputs, threading.Event())
segment_buffer = SegmentBuffer(stream.outputs)
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
await hass.async_block_till_done()
return py_av.capture_buffer
@ -209,7 +210,8 @@ async def test_stream_open_fails(hass):
stream.hls_output()
with patch("av.open") as av_open:
av_open.side_effect = av.error.InvalidDataError(-2, "error")
stream_worker(STREAM_SOURCE, {}, stream.outputs, threading.Event())
segment_buffer = SegmentBuffer(stream.outputs)
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
await hass.async_block_till_done()
av_open.assert_called_once()