Centralize keepalive logic in Stream class (#45850)

* Remove dependencies on keepalive from StremaOutput and stream_worker

Pull logic from StreamOutput and stream_worker into the Stream
class, unifying keepalive and idle timeout logic. This prepares
for future changes to preserve hls state across stream url changes.
This commit is contained in:
Allen Porter 2021-02-08 07:19:41 -08:00 committed by GitHub
parent e20a814926
commit dca6a93898
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 142 additions and 139 deletions

View File

@ -2,6 +2,7 @@
import logging import logging
import secrets import secrets
import threading import threading
import time
from types import MappingProxyType from types import MappingProxyType
import voluptuous as vol import voluptuous as vol
@ -20,9 +21,12 @@ from .const import (
CONF_STREAM_SOURCE, CONF_STREAM_SOURCE,
DOMAIN, DOMAIN,
MAX_SEGMENTS, MAX_SEGMENTS,
OUTPUT_IDLE_TIMEOUT,
SERVICE_RECORD, SERVICE_RECORD,
STREAM_RESTART_INCREMENT,
STREAM_RESTART_RESET_TIME,
) )
from .core import PROVIDERS from .core import PROVIDERS, IdleTimer
from .hls import async_setup_hls from .hls import async_setup_hls
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -142,18 +146,27 @@ class Stream:
# without concern about self._outputs being modified from another thread. # without concern about self._outputs being modified from another thread.
return MappingProxyType(self._outputs.copy()) return MappingProxyType(self._outputs.copy())
def add_provider(self, fmt): def add_provider(self, fmt, timeout=OUTPUT_IDLE_TIMEOUT):
"""Add provider output stream.""" """Add provider output stream."""
if not self._outputs.get(fmt): if not self._outputs.get(fmt):
provider = PROVIDERS[fmt](self)
@callback
def idle_callback():
if not self.keepalive and fmt in self._outputs:
self.remove_provider(self._outputs[fmt])
self.check_idle()
provider = PROVIDERS[fmt](
self.hass, IdleTimer(self.hass, timeout, idle_callback)
)
self._outputs[fmt] = provider self._outputs[fmt] = provider
return self._outputs[fmt] return self._outputs[fmt]
def remove_provider(self, provider): def remove_provider(self, provider):
"""Remove provider output stream.""" """Remove provider output stream."""
if provider.name in self._outputs: if provider.name in self._outputs:
self._outputs[provider.name].cleanup()
del self._outputs[provider.name] del self._outputs[provider.name]
self.check_idle()
if not self._outputs: if not self._outputs:
self.stop() self.stop()
@ -165,10 +178,6 @@ class Stream:
def start(self): def start(self):
"""Start a stream.""" """Start a stream."""
# Keep import here so that we can import stream integration without installing reqs
# pylint: disable=import-outside-toplevel
from .worker import stream_worker
if self._thread is None or not self._thread.is_alive(): if self._thread is None or not self._thread.is_alive():
if self._thread is not None: if self._thread is not None:
# The thread must have crashed/exited. Join to clean up the # The thread must have crashed/exited. Join to clean up the
@ -177,12 +186,48 @@ class Stream:
self._thread_quit = threading.Event() self._thread_quit = threading.Event()
self._thread = threading.Thread( self._thread = threading.Thread(
name="stream_worker", name="stream_worker",
target=stream_worker, target=self._run_worker,
args=(self.hass, self, self._thread_quit),
) )
self._thread.start() self._thread.start()
_LOGGER.info("Started stream: %s", self.source) _LOGGER.info("Started stream: %s", self.source)
def _run_worker(self):
"""Handle consuming streams and restart keepalive streams."""
# Keep import here so that we can import stream integration without installing reqs
# pylint: disable=import-outside-toplevel
from .worker import stream_worker
wait_timeout = 0
while not self._thread_quit.wait(timeout=wait_timeout):
start_time = time.time()
stream_worker(self.hass, self, self._thread_quit)
if not self.keepalive or self._thread_quit.is_set():
break
# To avoid excessive restarts, wait before restarting
# As the required recovery time may be different for different setups, start
# with trying a short wait_timeout and increase it on each reconnection attempt.
# Reset the wait_timeout after the worker has been up for several minutes
if time.time() - start_time > STREAM_RESTART_RESET_TIME:
wait_timeout = 0
wait_timeout += STREAM_RESTART_INCREMENT
_LOGGER.debug(
"Restarting stream worker in %d seconds: %s",
wait_timeout,
self.source,
)
self._worker_finished()
def _worker_finished(self):
"""Schedule cleanup of all outputs."""
@callback
def remove_outputs():
for provider in self.outputs.values():
self.remove_provider(provider)
self.hass.loop.call_soon_threadsafe(remove_outputs)
def stop(self): def stop(self):
"""Remove outputs and access token.""" """Remove outputs and access token."""
self._outputs = {} self._outputs = {}
@ -223,9 +268,8 @@ async def async_handle_record_service(hass, call):
if recorder: if recorder:
raise HomeAssistantError(f"Stream already recording to {recorder.video_path}!") raise HomeAssistantError(f"Stream already recording to {recorder.video_path}!")
recorder = stream.add_provider("recorder") recorder = stream.add_provider("recorder", timeout=duration)
recorder.video_path = video_path recorder.video_path = video_path
recorder.timeout = duration
stream.start() stream.start()

View File

@ -15,6 +15,8 @@ OUTPUT_FORMATS = ["hls"]
FORMAT_CONTENT_TYPE = {"hls": "application/vnd.apple.mpegurl"} FORMAT_CONTENT_TYPE = {"hls": "application/vnd.apple.mpegurl"}
OUTPUT_IDLE_TIMEOUT = 300 # Idle timeout due to inactivity
MAX_SEGMENTS = 3 # Max number of segments to keep around MAX_SEGMENTS = 3 # Max number of segments to keep around
MIN_SEGMENT_DURATION = 1.5 # Each segment is at least this many seconds MIN_SEGMENT_DURATION = 1.5 # Each segment is at least this many seconds

View File

@ -8,7 +8,7 @@ from aiohttp import web
import attr import attr
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
@ -36,24 +36,69 @@ class Segment:
duration: float = attr.ib() duration: float = attr.ib()
class IdleTimer:
"""Invoke a callback after an inactivity timeout.
The IdleTimer invokes the callback after some timeout has passed. The awake() method
resets the internal alarm, extending the inactivity time.
"""
def __init__(
self, hass: HomeAssistant, timeout: int, idle_callback: Callable[[], None]
):
"""Initialize IdleTimer."""
self._hass = hass
self._timeout = timeout
self._callback = idle_callback
self._unsub = None
self.idle = False
def start(self):
"""Start the idle timer if not already started."""
self.idle = False
if self._unsub is None:
self._unsub = async_call_later(self._hass, self._timeout, self.fire)
def awake(self):
"""Keep the idle time alive by resetting the timeout."""
self.idle = False
# Reset idle timeout
self.clear()
self._unsub = async_call_later(self._hass, self._timeout, self.fire)
def clear(self):
"""Clear and disable the timer."""
if self._unsub is not None:
self._unsub()
def fire(self, _now=None):
"""Invoke the idle timeout callback, called when the alarm fires."""
self.idle = True
self._unsub = None
self._callback()
class StreamOutput: class StreamOutput:
"""Represents a stream output.""" """Represents a stream output."""
def __init__(self, stream, timeout: int = 300) -> None: def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None:
"""Initialize a stream output.""" """Initialize a stream output."""
self.idle = False self._hass = hass
self.timeout = timeout self._idle_timer = idle_timer
self._stream = stream
self._cursor = None self._cursor = None
self._event = asyncio.Event() self._event = asyncio.Event()
self._segments = deque(maxlen=MAX_SEGMENTS) self._segments = deque(maxlen=MAX_SEGMENTS)
self._unsub = None
@property @property
def name(self) -> str: def name(self) -> str:
"""Return provider name.""" """Return provider name."""
return None return None
@property
def idle(self) -> bool:
"""Return True if the output is idle."""
return self._idle_timer.idle
@property @property
def format(self) -> str: def format(self) -> str:
"""Return container format.""" """Return container format."""
@ -90,11 +135,7 @@ class StreamOutput:
def get_segment(self, sequence: int = None) -> Any: def get_segment(self, sequence: int = None) -> Any:
"""Retrieve a specific segment, or the whole list.""" """Retrieve a specific segment, or the whole list."""
self.idle = False self._idle_timer.awake()
# Reset idle timeout
if self._unsub is not None:
self._unsub()
self._unsub = async_call_later(self._stream.hass, self.timeout, self._timeout)
if not sequence: if not sequence:
return self._segments return self._segments
@ -119,43 +160,22 @@ class StreamOutput:
def put(self, segment: Segment) -> None: def put(self, segment: Segment) -> None:
"""Store output.""" """Store output."""
self._stream.hass.loop.call_soon_threadsafe(self._async_put, segment) self._hass.loop.call_soon_threadsafe(self._async_put, segment)
@callback @callback
def _async_put(self, segment: Segment) -> None: def _async_put(self, segment: Segment) -> None:
"""Store output from event loop.""" """Store output from event loop."""
# Start idle timeout when we start receiving data # Start idle timeout when we start receiving data
if self._unsub is None: self._idle_timer.start()
self._unsub = async_call_later(
self._stream.hass, self.timeout, self._timeout
)
if segment is None:
self._event.set()
# Cleanup provider
if self._unsub is not None:
self._unsub()
self.cleanup()
return
self._segments.append(segment) self._segments.append(segment)
self._event.set() self._event.set()
self._event.clear() self._event.clear()
@callback
def _timeout(self, _now=None):
"""Handle stream timeout."""
self._unsub = None
if self._stream.keepalive:
self.idle = True
self._stream.check_idle()
else:
self.cleanup()
def cleanup(self): def cleanup(self):
"""Handle cleanup.""" """Handle cleanup."""
self._event.set()
self._idle_timer.clear()
self._segments = deque(maxlen=MAX_SEGMENTS) self._segments = deque(maxlen=MAX_SEGMENTS)
self._stream.remove_provider(self)
class StreamView(HomeAssistantView): class StreamView(HomeAssistantView):

View File

@ -6,9 +6,9 @@ from typing import List
import av import av
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from .core import PROVIDERS, Segment, StreamOutput from .core import PROVIDERS, IdleTimer, Segment, StreamOutput
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -72,9 +72,9 @@ def recorder_save_worker(file_out: str, segments: List[Segment], container_forma
class RecorderOutput(StreamOutput): class RecorderOutput(StreamOutput):
"""Represents HLS Output formats.""" """Represents HLS Output formats."""
def __init__(self, stream, timeout: int = 30) -> None: def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None:
"""Initialize recorder output.""" """Initialize recorder output."""
super().__init__(stream, timeout) super().__init__(hass, idle_timer)
self.video_path = None self.video_path = None
self._segments = [] self._segments = []
@ -104,12 +104,6 @@ class RecorderOutput(StreamOutput):
segments = [s for s in segments if s.sequence not in own_segments] segments = [s for s in segments if s.sequence not in own_segments]
self._segments = segments + self._segments self._segments = segments + self._segments
@callback
def _timeout(self, _now=None):
"""Handle recorder timeout."""
self._unsub = None
self.cleanup()
def cleanup(self): def cleanup(self):
"""Write recording and clean up.""" """Write recording and clean up."""
_LOGGER.debug("Starting recorder worker thread") _LOGGER.debug("Starting recorder worker thread")
@ -120,5 +114,5 @@ class RecorderOutput(StreamOutput):
) )
thread.start() thread.start()
super().cleanup()
self._segments = [] self._segments = []
self._stream.remove_provider(self)

View File

@ -2,7 +2,6 @@
from collections import deque from collections import deque
import io import io
import logging import logging
import time
import av import av
@ -11,8 +10,6 @@ from .const import (
MAX_TIMESTAMP_GAP, MAX_TIMESTAMP_GAP,
MIN_SEGMENT_DURATION, MIN_SEGMENT_DURATION,
PACKETS_TO_WAIT_FOR_AUDIO, PACKETS_TO_WAIT_FOR_AUDIO,
STREAM_RESTART_INCREMENT,
STREAM_RESTART_RESET_TIME,
STREAM_TIMEOUT, STREAM_TIMEOUT,
) )
from .core import Segment, StreamBuffer from .core import Segment, StreamBuffer
@ -47,32 +44,6 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence):
def stream_worker(hass, stream, quit_event): def stream_worker(hass, stream, quit_event):
"""Handle consuming streams and restart keepalive streams."""
wait_timeout = 0
while not quit_event.wait(timeout=wait_timeout):
start_time = time.time()
try:
_stream_worker_internal(hass, stream, quit_event)
except av.error.FFmpegError: # pylint: disable=c-extension-no-member
_LOGGER.exception("Stream connection failed: %s", stream.source)
if not stream.keepalive or quit_event.is_set():
break
# To avoid excessive restarts, wait before restarting
# As the required recovery time may be different for different setups, start
# with trying a short wait_timeout and increase it on each reconnection attempt.
# Reset the wait_timeout after the worker has been up for several minutes
if time.time() - start_time > STREAM_RESTART_RESET_TIME:
wait_timeout = 0
wait_timeout += STREAM_RESTART_INCREMENT
_LOGGER.debug(
"Restarting stream worker in %d seconds: %s",
wait_timeout,
stream.source,
)
def _stream_worker_internal(hass, stream, quit_event):
"""Handle consuming streams.""" """Handle consuming streams."""
try: try:
@ -183,7 +154,6 @@ def _stream_worker_internal(hass, stream, quit_event):
_LOGGER.error( _LOGGER.error(
"Error demuxing stream while finding first packet: %s", str(ex) "Error demuxing stream while finding first packet: %s", str(ex)
) )
finalize_stream()
return False return False
return True return True
@ -220,12 +190,6 @@ def _stream_worker_internal(hass, stream, quit_event):
packet.stream = output_streams[audio_stream] packet.stream = output_streams[audio_stream]
buffer.output.mux(packet) buffer.output.mux(packet)
def finalize_stream():
if not stream.keepalive:
# End of stream, clear listeners and stop thread
for fmt in stream.outputs:
stream.outputs[fmt].put(None)
if not peek_first_pts(): if not peek_first_pts():
container.close() container.close()
return return
@ -249,7 +213,6 @@ def _stream_worker_internal(hass, stream, quit_event):
missing_dts = 0 missing_dts = 0
except (av.AVError, StopIteration) as ex: except (av.AVError, StopIteration) as ex:
_LOGGER.error("Error demuxing stream: %s", str(ex)) _LOGGER.error("Error demuxing stream: %s", str(ex))
finalize_stream()
break break
# Discard packet if dts is not monotonic # Discard packet if dts is not monotonic
@ -263,7 +226,6 @@ def _stream_worker_internal(hass, stream, quit_event):
last_dts[packet.stream], last_dts[packet.stream],
packet.dts, packet.dts,
) )
finalize_stream()
break break
continue continue

View File

@ -9,14 +9,13 @@ nothing for the test to verify. The solution is the WorkerSync class that
allows the tests to pause the worker thread before finalizing the stream allows the tests to pause the worker thread before finalizing the stream
so that it can inspect the output. so that it can inspect the output.
""" """
import logging import logging
import threading import threading
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from homeassistant.components.stream.core import Segment, StreamOutput from homeassistant.components.stream import Stream
class WorkerSync: class WorkerSync:
@ -25,7 +24,7 @@ class WorkerSync:
def __init__(self): def __init__(self):
"""Initialize WorkerSync.""" """Initialize WorkerSync."""
self._event = None self._event = None
self._put_original = StreamOutput.put self._original = Stream._worker_finished
def pause(self): def pause(self):
"""Pause the worker before it finalizes the stream.""" """Pause the worker before it finalizes the stream."""
@ -35,17 +34,16 @@ class WorkerSync:
"""Allow the worker thread to finalize the stream.""" """Allow the worker thread to finalize the stream."""
self._event.set() self._event.set()
def blocking_put(self, stream_output: StreamOutput, segment: Segment): def blocking_finish(self, stream: Stream):
"""Proxy StreamOutput.put, intercepted for test to pause worker.""" """Intercept call to pause stream worker."""
if segment is None and self._event:
# Worker is ending the stream, which clears all output buffers. # Worker is ending the stream, which clears all output buffers.
# Block the worker thread until the test has a chance to verify # Block the worker thread until the test has a chance to verify
# the segments under test. # the segments under test.
logging.error("blocking worker") logging.debug("blocking worker")
self._event.wait() self._event.wait()
# Forward to actual StreamOutput.put # Forward to actual implementation
self._put_original(stream_output, segment) self._original(stream)
@pytest.fixture() @pytest.fixture()
@ -53,8 +51,8 @@ def stream_worker_sync(hass):
"""Patch StreamOutput to allow test to synchronize worker stream end.""" """Patch StreamOutput to allow test to synchronize worker stream end."""
sync = WorkerSync() sync = WorkerSync()
with patch( with patch(
"homeassistant.components.stream.core.StreamOutput.put", "homeassistant.components.stream.Stream._worker_finished",
side_effect=sync.blocking_put, side_effect=sync.blocking_finish,
autospec=True, autospec=True,
): ):
yield sync yield sync

View File

@ -98,6 +98,7 @@ async def test_stream_timeout(hass, hass_client, stream_worker_sync):
# Wait 5 minutes # Wait 5 minutes
future = dt_util.utcnow() + timedelta(minutes=5) future = dt_util.utcnow() + timedelta(minutes=5)
async_fire_time_changed(hass, future) async_fire_time_changed(hass, future)
await hass.async_block_till_done()
# Ensure playlist not accessible # Ensure playlist not accessible
fail_response = await http_client.get(parsed_url.path) fail_response = await http_client.get(parsed_url.path)
@ -155,9 +156,9 @@ async def test_stream_keepalive(hass):
return cur_time return cur_time
with patch("av.open") as av_open, patch( with patch("av.open") as av_open, patch(
"homeassistant.components.stream.worker.time" "homeassistant.components.stream.time"
) as mock_time, patch( ) as mock_time, patch(
"homeassistant.components.stream.worker.STREAM_RESTART_INCREMENT", 0 "homeassistant.components.stream.STREAM_RESTART_INCREMENT", 0
): ):
av_open.side_effect = av.error.InvalidDataError(-2, "error") av_open.side_effect = av.error.InvalidDataError(-2, "error")
mock_time.time.side_effect = time_side_effect mock_time.time.side_effect = time_side_effect

View File

@ -80,5 +80,7 @@ async def test_record_service_lookback(hass):
await hass.services.async_call(DOMAIN, SERVICE_RECORD, data, blocking=True) await hass.services.async_call(DOMAIN, SERVICE_RECORD, data, blocking=True)
assert stream_mock.called assert stream_mock.called
stream_mock.return_value.add_provider.assert_called_once_with("recorder") stream_mock.return_value.add_provider.assert_called_once_with(
"recorder", timeout=30
)
assert hls_mock.recv.called assert hls_mock.recv.called

View File

@ -106,13 +106,11 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync):
stream_worker_sync.pause() stream_worker_sync.pause()
with patch( with patch("homeassistant.components.stream.IdleTimer.fire") as mock_timeout:
"homeassistant.components.stream.recorder.RecorderOutput.cleanup"
) as mock_cleanup:
# Setup demo track # Setup demo track
source = generate_h264_video() source = generate_h264_video()
stream = preload_stream(hass, source) stream = preload_stream(hass, source)
recorder = stream.add_provider("recorder") recorder = stream.add_provider("recorder", timeout=30)
stream.start() stream.start()
await recorder.recv() await recorder.recv()
@ -122,7 +120,7 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync):
async_fire_time_changed(hass, future) async_fire_time_changed(hass, future)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_cleanup.called assert mock_timeout.called
stream_worker_sync.resume() stream_worker_sync.resume()
stream.stop() stream.stop()

View File

@ -132,7 +132,6 @@ class FakePyAvBuffer:
self.segments = [] self.segments = []
self.audio_packets = [] self.audio_packets = []
self.video_packets = [] self.video_packets = []
self.finished = False
def add_stream(self, template=None): def add_stream(self, template=None):
"""Create an output buffer that captures packets for test to examine.""" """Create an output buffer that captures packets for test to examine."""
@ -162,10 +161,6 @@ class FakePyAvBuffer:
def capture_output_segment(self, segment): def capture_output_segment(self, segment):
"""Capture the output segment for tests to inspect.""" """Capture the output segment for tests to inspect."""
assert not self.finished
if segment is None:
self.finished = True
else:
self.segments.append(segment) self.segments.append(segment)
@ -223,7 +218,6 @@ async def test_stream_worker_success(hass):
decoded_stream = await async_decode_stream( decoded_stream = await async_decode_stream(
hass, PacketSequence(TEST_SEQUENCE_LENGTH) hass, PacketSequence(TEST_SEQUENCE_LENGTH)
) )
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
# Check number of segments. A segment is only formed when a packet from the next # Check number of segments. A segment is only formed when a packet from the next
# segment arrives, hence the subtraction of one from the sequence length. # segment arrives, hence the subtraction of one from the sequence length.
@ -243,7 +237,6 @@ async def test_skip_out_of_order_packet(hass):
packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9090 packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9090
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
# Check sequence numbers # Check sequence numbers
assert all([segments[i].sequence == i + 1 for i in range(len(segments))]) assert all([segments[i].sequence == i + 1 for i in range(len(segments))])
@ -279,7 +272,6 @@ async def test_discard_old_packets(hass):
packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = 9090 packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = 9090
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
# Check number of segments # Check number of segments
assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET) assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET)
@ -299,7 +291,6 @@ async def test_packet_overflow(hass):
packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000 packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
# Check number of segments # Check number of segments
assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET) assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET)
@ -321,7 +312,6 @@ async def test_skip_initial_bad_packets(hass):
packets[i].dts = None packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
# Check number of segments # Check number of segments
assert len(segments) == int( assert len(segments) == int(
@ -345,7 +335,6 @@ async def test_too_many_initial_bad_packets_fails(hass):
packets[i].dts = None packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
assert len(segments) == 0 assert len(segments) == 0
assert len(decoded_stream.video_packets) == 0 assert len(decoded_stream.video_packets) == 0
@ -363,7 +352,6 @@ async def test_skip_missing_dts(hass):
packets[i].dts = None packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
# Check sequence numbers # Check sequence numbers
assert all([segments[i].sequence == i + 1 for i in range(len(segments))]) assert all([segments[i].sequence == i + 1 for i in range(len(segments))])
@ -387,7 +375,6 @@ async def test_too_many_bad_packets(hass):
packets[i].dts = None packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
assert len(segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET) assert len(segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == bad_packet_start assert len(decoded_stream.video_packets) == bad_packet_start
@ -402,7 +389,6 @@ async def test_no_video_stream(hass):
hass, PacketSequence(TEST_SEQUENCE_LENGTH), py_av=py_av hass, PacketSequence(TEST_SEQUENCE_LENGTH), py_av=py_av
) )
# Note: This failure scenario does not output an end of stream # Note: This failure scenario does not output an end of stream
assert not decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
assert len(segments) == 0 assert len(segments) == 0
assert len(decoded_stream.video_packets) == 0 assert len(decoded_stream.video_packets) == 0
@ -417,7 +403,6 @@ async def test_audio_packets_not_found(hass):
packets = PacketSequence(num_packets) # Contains only video packets packets = PacketSequence(num_packets) # Contains only video packets
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
assert len(segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET) assert len(segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == num_packets assert len(decoded_stream.video_packets) == num_packets
@ -439,7 +424,6 @@ async def test_audio_is_first_packet(hass):
packets[2].pts = packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE packets[2].pts = packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
# The audio packets are segmented with the video packets # The audio packets are segmented with the video packets
assert len(segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET) assert len(segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET)
@ -458,7 +442,6 @@ async def test_audio_packets_found(hass):
packets[1].pts = packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE packets[1].pts = packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
# The audio packet above is buffered with the video packet # The audio packet above is buffered with the video packet
assert len(segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET) assert len(segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET)
@ -477,7 +460,6 @@ async def test_pts_out_of_order(hass):
packets[i].is_keyframe = False packets[i].is_keyframe = False
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, iter(packets))
assert decoded_stream.finished
segments = decoded_stream.segments segments = decoded_stream.segments
# Check number of segments # Check number of segments
assert len(segments) == int((TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET) assert len(segments) == int((TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET)