Use buffer at stream start with unsupported audio (#54672)

Add a test that reproduces the issue where resetting the iterator
drops the already read packets. Fix a bug in replace_underlying_iterator
because checking the self._next function turns out not to work since
it points to a bound method so the "is not" check fails.
This commit is contained in:
Allen Porter 2021-08-15 21:02:37 -07:00 committed by GitHub
parent bec42b74fe
commit 094f7d38ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 20 deletions

View File

@ -225,7 +225,7 @@ class PeekIterator(Iterator):
def replace_underlying_iterator(self, new_iterator: Iterator) -> None:
"""Replace the underlying iterator while preserving the buffer."""
self._iterator = new_iterator
if self._next is not self._pop_buffer:
if not self._buffer:
self._next = self._iterator.__next__
def _pop_buffer(self) -> av.Packet:

View File

@ -15,6 +15,7 @@ failure modes or corner cases like how out of order packets are handled.
import fractions
import io
import logging
import math
import threading
from unittest.mock import patch
@ -52,7 +53,7 @@ SEGMENTS_PER_PACKET = PACKET_DURATION / SEGMENT_DURATION
TIMEOUT = 15
class FakePyAvStream:
class FakeAvInputStream:
"""A fake pyav Stream."""
def __init__(self, name, rate):
@ -66,9 +67,13 @@ class FakePyAvStream:
self.codec = FakeCodec()
def __str__(self) -> str:
"""Return a stream name for debugging."""
return f"FakePyAvStream<{self.name}, {self.time_base}>"
VIDEO_STREAM = FakePyAvStream(VIDEO_STREAM_FORMAT, VIDEO_FRAME_RATE)
AUDIO_STREAM = FakePyAvStream(AUDIO_STREAM_FORMAT, AUDIO_SAMPLE_RATE)
VIDEO_STREAM = FakeAvInputStream(VIDEO_STREAM_FORMAT, VIDEO_FRAME_RATE)
AUDIO_STREAM = FakeAvInputStream(AUDIO_STREAM_FORMAT, AUDIO_SAMPLE_RATE)
class PacketSequence:
@ -110,6 +115,9 @@ class PacketSequence:
is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL)
size = 3
def __str__(self) -> str:
return f"FakePacket<stream={self.stream}, pts={self.pts}, key={self.is_keyframe}>"
return FakePacket()
@ -154,7 +162,7 @@ class FakePyAvBuffer:
def add_stream(self, template=None):
"""Create an output buffer that captures packets for test to examine."""
class FakeStream:
class FakeAvOutputStream:
def __init__(self, capture_packets):
self.capture_packets = capture_packets
@ -162,11 +170,15 @@ class FakePyAvBuffer:
return
def mux(self, packet):
logging.debug("Muxed packet: %s", packet)
self.capture_packets.append(packet)
def __str__(self) -> str:
return f"FakeAvOutputStream<{template.name}>"
if template.name == AUDIO_STREAM_FORMAT:
return FakeStream(self.audio_packets)
return FakeStream(self.video_packets)
return FakeAvOutputStream(self.audio_packets)
return FakeAvOutputStream(self.video_packets)
def mux(self, packet):
"""Capture a packet for tests to examine."""
@ -217,7 +229,7 @@ async def async_decode_stream(hass, packets, py_av=None):
if not py_av:
py_av = MockPyAv()
py_av.container.packets = packets
py_av.container.packets = iter(packets) # Can't be rewound
with patch("av.open", new=py_av.open), patch(
"homeassistant.components.stream.core.StreamOutput.put",
@ -273,7 +285,7 @@ async def test_skip_out_of_order_packet(hass):
assert not packets[out_of_order_index].is_keyframe
packets[out_of_order_index].dts = -9090
decoded_stream = await async_decode_stream(hass, iter(packets))
decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check sequence numbers
@ -309,7 +321,7 @@ async def test_discard_old_packets(hass):
# Packets after this one are considered out of order
packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = 9090
decoded_stream = await async_decode_stream(hass, iter(packets))
decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check number of segments
@ -331,7 +343,7 @@ async def test_packet_overflow(hass):
# Packet is so far out of order, exceeds max gap and looks like overflow
packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000
decoded_stream = await async_decode_stream(hass, iter(packets))
decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check number of segments
@ -355,7 +367,7 @@ async def test_skip_initial_bad_packets(hass):
for i in range(0, num_bad_packets):
packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets))
decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check sequence numbers
@ -385,7 +397,7 @@ async def test_too_many_initial_bad_packets_fails(hass):
for i in range(0, num_bad_packets):
packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets))
decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments
assert len(segments) == 0
assert len(decoded_stream.video_packets) == 0
@ -405,7 +417,7 @@ async def test_skip_missing_dts(hass):
continue
packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets))
decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check sequence numbers
@ -426,7 +438,7 @@ async def test_too_many_bad_packets(hass):
for i in range(bad_packet_start, bad_packet_start + num_bad_packets):
packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets))
decoded_stream = await async_decode_stream(hass, packets)
complete_segments = decoded_stream.complete_segments
assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == bad_packet_start
@ -454,7 +466,7 @@ async def test_audio_packets_not_found(hass):
num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
packets = PacketSequence(num_packets) # Contains only video packets
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
complete_segments = decoded_stream.complete_segments
assert len(complete_segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == num_packets
@ -474,8 +486,10 @@ async def test_adts_aac_audio(hass):
packets[1][0] = 255
packets[1][1] = 241
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
assert len(decoded_stream.audio_packets) == 0
# All decoded video packets are still preserved
assert len(decoded_stream.video_packets) == num_packets - 1
async def test_audio_is_first_packet(hass):
@ -493,7 +507,7 @@ async def test_audio_is_first_packet(hass):
packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
complete_segments = decoded_stream.complete_segments
# The audio packets are segmented with the video packets
assert len(complete_segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET)
@ -511,7 +525,7 @@ async def test_audio_packets_found(hass):
packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
complete_segments = decoded_stream.complete_segments
# The audio packet above is buffered with the video packet
assert len(complete_segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET)
@ -529,7 +543,7 @@ async def test_pts_out_of_order(hass):
packets[i].pts = packets[i - 1].pts - 1
packets[i].is_keyframe = False
decoded_stream = await async_decode_stream(hass, iter(packets))
decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check number of segments