From 094f7d38ad5b30f0dbedeea4492b7a85f3c0c228 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 15 Aug 2021 21:02:37 -0700 Subject: [PATCH] Use buffer at stream start with unsupported audio (#54672) Add a test that reproduces the issue where resetting the iterator drops the already read packets. Fix a bug in replace_underlying_iterator because checking the self._next function turns out not to work since it points to a bound method so the "is not" check fails. --- homeassistant/components/stream/worker.py | 2 +- tests/components/stream/test_worker.py | 52 ++++++++++++++--------- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index 69def43b2a2..039163c6cf5 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -225,7 +225,7 @@ class PeekIterator(Iterator): def replace_underlying_iterator(self, new_iterator: Iterator) -> None: """Replace the underlying iterator while preserving the buffer.""" self._iterator = new_iterator - if self._next is not self._pop_buffer: + if not self._buffer: self._next = self._iterator.__next__ def _pop_buffer(self) -> av.Packet: diff --git a/tests/components/stream/test_worker.py b/tests/components/stream/test_worker.py index ffbeb44d79e..e62a190d7be 100644 --- a/tests/components/stream/test_worker.py +++ b/tests/components/stream/test_worker.py @@ -15,6 +15,7 @@ failure modes or corner cases like how out of order packets are handled. import fractions import io +import logging import math import threading from unittest.mock import patch @@ -52,7 +53,7 @@ SEGMENTS_PER_PACKET = PACKET_DURATION / SEGMENT_DURATION TIMEOUT = 15 -class FakePyAvStream: +class FakeAvInputStream: """A fake pyav Stream.""" def __init__(self, name, rate): @@ -66,9 +67,13 @@ class FakePyAvStream: self.codec = FakeCodec() + def __str__(self) -> str: + """Return a stream name for debugging.""" + return f"FakePyAvStream<{self.name}, {self.time_base}>" -VIDEO_STREAM = FakePyAvStream(VIDEO_STREAM_FORMAT, VIDEO_FRAME_RATE) -AUDIO_STREAM = FakePyAvStream(AUDIO_STREAM_FORMAT, AUDIO_SAMPLE_RATE) + +VIDEO_STREAM = FakeAvInputStream(VIDEO_STREAM_FORMAT, VIDEO_FRAME_RATE) +AUDIO_STREAM = FakeAvInputStream(AUDIO_STREAM_FORMAT, AUDIO_SAMPLE_RATE) class PacketSequence: @@ -110,6 +115,9 @@ class PacketSequence: is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL) size = 3 + def __str__(self) -> str: + return f"FakePacket" + return FakePacket() @@ -154,7 +162,7 @@ class FakePyAvBuffer: def add_stream(self, template=None): """Create an output buffer that captures packets for test to examine.""" - class FakeStream: + class FakeAvOutputStream: def __init__(self, capture_packets): self.capture_packets = capture_packets @@ -162,11 +170,15 @@ class FakePyAvBuffer: return def mux(self, packet): + logging.debug("Muxed packet: %s", packet) self.capture_packets.append(packet) + def __str__(self) -> str: + return f"FakeAvOutputStream<{template.name}>" + if template.name == AUDIO_STREAM_FORMAT: - return FakeStream(self.audio_packets) - return FakeStream(self.video_packets) + return FakeAvOutputStream(self.audio_packets) + return FakeAvOutputStream(self.video_packets) def mux(self, packet): """Capture a packet for tests to examine.""" @@ -217,7 +229,7 @@ async def async_decode_stream(hass, packets, py_av=None): if not py_av: py_av = MockPyAv() - py_av.container.packets = packets + py_av.container.packets = iter(packets) # Can't be rewound with patch("av.open", new=py_av.open), patch( "homeassistant.components.stream.core.StreamOutput.put", @@ -273,7 +285,7 @@ async def test_skip_out_of_order_packet(hass): assert not packets[out_of_order_index].is_keyframe packets[out_of_order_index].dts = -9090 - decoded_stream = await async_decode_stream(hass, iter(packets)) + decoded_stream = await async_decode_stream(hass, packets) segments = decoded_stream.segments complete_segments = decoded_stream.complete_segments # Check sequence numbers @@ -309,7 +321,7 @@ async def test_discard_old_packets(hass): # Packets after this one are considered out of order packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = 9090 - decoded_stream = await async_decode_stream(hass, iter(packets)) + decoded_stream = await async_decode_stream(hass, packets) segments = decoded_stream.segments complete_segments = decoded_stream.complete_segments # Check number of segments @@ -331,7 +343,7 @@ async def test_packet_overflow(hass): # Packet is so far out of order, exceeds max gap and looks like overflow packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000 - decoded_stream = await async_decode_stream(hass, iter(packets)) + decoded_stream = await async_decode_stream(hass, packets) segments = decoded_stream.segments complete_segments = decoded_stream.complete_segments # Check number of segments @@ -355,7 +367,7 @@ async def test_skip_initial_bad_packets(hass): for i in range(0, num_bad_packets): packets[i].dts = None - decoded_stream = await async_decode_stream(hass, iter(packets)) + decoded_stream = await async_decode_stream(hass, packets) segments = decoded_stream.segments complete_segments = decoded_stream.complete_segments # Check sequence numbers @@ -385,7 +397,7 @@ async def test_too_many_initial_bad_packets_fails(hass): for i in range(0, num_bad_packets): packets[i].dts = None - decoded_stream = await async_decode_stream(hass, iter(packets)) + decoded_stream = await async_decode_stream(hass, packets) segments = decoded_stream.segments assert len(segments) == 0 assert len(decoded_stream.video_packets) == 0 @@ -405,7 +417,7 @@ async def test_skip_missing_dts(hass): continue packets[i].dts = None - decoded_stream = await async_decode_stream(hass, iter(packets)) + decoded_stream = await async_decode_stream(hass, packets) segments = decoded_stream.segments complete_segments = decoded_stream.complete_segments # Check sequence numbers @@ -426,7 +438,7 @@ async def test_too_many_bad_packets(hass): for i in range(bad_packet_start, bad_packet_start + num_bad_packets): packets[i].dts = None - decoded_stream = await async_decode_stream(hass, iter(packets)) + decoded_stream = await async_decode_stream(hass, packets) complete_segments = decoded_stream.complete_segments assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET) assert len(decoded_stream.video_packets) == bad_packet_start @@ -454,7 +466,7 @@ async def test_audio_packets_not_found(hass): num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1 packets = PacketSequence(num_packets) # Contains only video packets - decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) + decoded_stream = await async_decode_stream(hass, packets, py_av=py_av) complete_segments = decoded_stream.complete_segments assert len(complete_segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET) assert len(decoded_stream.video_packets) == num_packets @@ -474,8 +486,10 @@ async def test_adts_aac_audio(hass): packets[1][0] = 255 packets[1][1] = 241 - decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) + decoded_stream = await async_decode_stream(hass, packets, py_av=py_av) assert len(decoded_stream.audio_packets) == 0 + # All decoded video packets are still preserved + assert len(decoded_stream.video_packets) == num_packets - 1 async def test_audio_is_first_packet(hass): @@ -493,7 +507,7 @@ async def test_audio_is_first_packet(hass): packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) - decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) + decoded_stream = await async_decode_stream(hass, packets, py_av=py_av) complete_segments = decoded_stream.complete_segments # The audio packets are segmented with the video packets assert len(complete_segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET) @@ -511,7 +525,7 @@ async def test_audio_packets_found(hass): packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) - decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) + decoded_stream = await async_decode_stream(hass, packets, py_av=py_av) complete_segments = decoded_stream.complete_segments # The audio packet above is buffered with the video packet assert len(complete_segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET) @@ -529,7 +543,7 @@ async def test_pts_out_of_order(hass): packets[i].pts = packets[i - 1].pts - 1 packets[i].is_keyframe = False - decoded_stream = await async_decode_stream(hass, iter(packets)) + decoded_stream = await async_decode_stream(hass, packets) segments = decoded_stream.segments complete_segments = decoded_stream.complete_segments # Check number of segments