Use buffer at stream start with unsupported audio (#54672)

Add a test that reproduces the issue where resetting the iterator
drops the already read packets. Fix a bug in replace_underlying_iterator
because checking the self._next function turns out not to work since
it points to a bound method so the "is not" check fails.
This commit is contained in:
Allen Porter 2021-08-15 21:02:37 -07:00 committed by GitHub
parent bec42b74fe
commit 094f7d38ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 20 deletions

View File

@ -225,7 +225,7 @@ class PeekIterator(Iterator):
def replace_underlying_iterator(self, new_iterator: Iterator) -> None: def replace_underlying_iterator(self, new_iterator: Iterator) -> None:
"""Replace the underlying iterator while preserving the buffer.""" """Replace the underlying iterator while preserving the buffer."""
self._iterator = new_iterator self._iterator = new_iterator
if self._next is not self._pop_buffer: if not self._buffer:
self._next = self._iterator.__next__ self._next = self._iterator.__next__
def _pop_buffer(self) -> av.Packet: def _pop_buffer(self) -> av.Packet:

View File

@ -15,6 +15,7 @@ failure modes or corner cases like how out of order packets are handled.
import fractions import fractions
import io import io
import logging
import math import math
import threading import threading
from unittest.mock import patch from unittest.mock import patch
@ -52,7 +53,7 @@ SEGMENTS_PER_PACKET = PACKET_DURATION / SEGMENT_DURATION
TIMEOUT = 15 TIMEOUT = 15
class FakePyAvStream: class FakeAvInputStream:
"""A fake pyav Stream.""" """A fake pyav Stream."""
def __init__(self, name, rate): def __init__(self, name, rate):
@ -66,9 +67,13 @@ class FakePyAvStream:
self.codec = FakeCodec() self.codec = FakeCodec()
def __str__(self) -> str:
"""Return a stream name for debugging."""
return f"FakePyAvStream<{self.name}, {self.time_base}>"
VIDEO_STREAM = FakePyAvStream(VIDEO_STREAM_FORMAT, VIDEO_FRAME_RATE)
AUDIO_STREAM = FakePyAvStream(AUDIO_STREAM_FORMAT, AUDIO_SAMPLE_RATE) VIDEO_STREAM = FakeAvInputStream(VIDEO_STREAM_FORMAT, VIDEO_FRAME_RATE)
AUDIO_STREAM = FakeAvInputStream(AUDIO_STREAM_FORMAT, AUDIO_SAMPLE_RATE)
class PacketSequence: class PacketSequence:
@ -110,6 +115,9 @@ class PacketSequence:
is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL) is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL)
size = 3 size = 3
def __str__(self) -> str:
return f"FakePacket<stream={self.stream}, pts={self.pts}, key={self.is_keyframe}>"
return FakePacket() return FakePacket()
@ -154,7 +162,7 @@ class FakePyAvBuffer:
def add_stream(self, template=None): def add_stream(self, template=None):
"""Create an output buffer that captures packets for test to examine.""" """Create an output buffer that captures packets for test to examine."""
class FakeStream: class FakeAvOutputStream:
def __init__(self, capture_packets): def __init__(self, capture_packets):
self.capture_packets = capture_packets self.capture_packets = capture_packets
@ -162,11 +170,15 @@ class FakePyAvBuffer:
return return
def mux(self, packet): def mux(self, packet):
logging.debug("Muxed packet: %s", packet)
self.capture_packets.append(packet) self.capture_packets.append(packet)
def __str__(self) -> str:
return f"FakeAvOutputStream<{template.name}>"
if template.name == AUDIO_STREAM_FORMAT: if template.name == AUDIO_STREAM_FORMAT:
return FakeStream(self.audio_packets) return FakeAvOutputStream(self.audio_packets)
return FakeStream(self.video_packets) return FakeAvOutputStream(self.video_packets)
def mux(self, packet): def mux(self, packet):
"""Capture a packet for tests to examine.""" """Capture a packet for tests to examine."""
@ -217,7 +229,7 @@ async def async_decode_stream(hass, packets, py_av=None):
if not py_av: if not py_av:
py_av = MockPyAv() py_av = MockPyAv()
py_av.container.packets = packets py_av.container.packets = iter(packets) # Can't be rewound
with patch("av.open", new=py_av.open), patch( with patch("av.open", new=py_av.open), patch(
"homeassistant.components.stream.core.StreamOutput.put", "homeassistant.components.stream.core.StreamOutput.put",
@ -273,7 +285,7 @@ async def test_skip_out_of_order_packet(hass):
assert not packets[out_of_order_index].is_keyframe assert not packets[out_of_order_index].is_keyframe
packets[out_of_order_index].dts = -9090 packets[out_of_order_index].dts = -9090
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
# Check sequence numbers # Check sequence numbers
@ -309,7 +321,7 @@ async def test_discard_old_packets(hass):
# Packets after this one are considered out of order # Packets after this one are considered out of order
packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = 9090 packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = 9090
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
# Check number of segments # Check number of segments
@ -331,7 +343,7 @@ async def test_packet_overflow(hass):
# Packet is so far out of order, exceeds max gap and looks like overflow # Packet is so far out of order, exceeds max gap and looks like overflow
packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000 packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
# Check number of segments # Check number of segments
@ -355,7 +367,7 @@ async def test_skip_initial_bad_packets(hass):
for i in range(0, num_bad_packets): for i in range(0, num_bad_packets):
packets[i].dts = None packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
# Check sequence numbers # Check sequence numbers
@ -385,7 +397,7 @@ async def test_too_many_initial_bad_packets_fails(hass):
for i in range(0, num_bad_packets): for i in range(0, num_bad_packets):
packets[i].dts = None packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments segments = decoded_stream.segments
assert len(segments) == 0 assert len(segments) == 0
assert len(decoded_stream.video_packets) == 0 assert len(decoded_stream.video_packets) == 0
@ -405,7 +417,7 @@ async def test_skip_missing_dts(hass):
continue continue
packets[i].dts = None packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
# Check sequence numbers # Check sequence numbers
@ -426,7 +438,7 @@ async def test_too_many_bad_packets(hass):
for i in range(bad_packet_start, bad_packet_start + num_bad_packets): for i in range(bad_packet_start, bad_packet_start + num_bad_packets):
packets[i].dts = None packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, packets)
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == bad_packet_start assert len(decoded_stream.video_packets) == bad_packet_start
@ -454,7 +466,7 @@ async def test_audio_packets_not_found(hass):
num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1 num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
packets = PacketSequence(num_packets) # Contains only video packets packets = PacketSequence(num_packets) # Contains only video packets
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
assert len(complete_segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == num_packets assert len(decoded_stream.video_packets) == num_packets
@ -474,8 +486,10 @@ async def test_adts_aac_audio(hass):
packets[1][0] = 255 packets[1][0] = 255
packets[1][1] = 241 packets[1][1] = 241
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
assert len(decoded_stream.audio_packets) == 0 assert len(decoded_stream.audio_packets) == 0
# All decoded video packets are still preserved
assert len(decoded_stream.video_packets) == num_packets - 1
async def test_audio_is_first_packet(hass): async def test_audio_is_first_packet(hass):
@ -493,7 +507,7 @@ async def test_audio_is_first_packet(hass):
packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
# The audio packets are segmented with the video packets # The audio packets are segmented with the video packets
assert len(complete_segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET)
@ -511,7 +525,7 @@ async def test_audio_packets_found(hass):
packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
# The audio packet above is buffered with the video packet # The audio packet above is buffered with the video packet
assert len(complete_segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET)
@ -529,7 +543,7 @@ async def test_pts_out_of_order(hass):
packets[i].pts = packets[i - 1].pts - 1 packets[i].pts = packets[i - 1].pts - 1
packets[i].is_keyframe = False packets[i].is_keyframe = False
decoded_stream = await async_decode_stream(hass, iter(packets)) decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
# Check number of segments # Check number of segments