Refactor the logic for peeking into the start of the stream (#52699)

* Reset dts validator when container is reset

* Reuse existing dts_validator when disabling audio stream

* Refactor peek logic at the start of a stream

Add a PeekingIterator to support rewinding an iterator so that the code
for adjusting audio streams and start pts can be inlined in the worker.

* Simplification and readability improvements

* Remove unnecessary verbiage from comments and pydoc

* Address pylint errors

* Remove rewind function and just mux the first packet separately

* More cleanup after removing rewind()

* Skip check to self._buffer on every iteration
This commit is contained in:
Allen Porter 2021-07-27 08:53:42 -07:00 committed by GitHub
parent f1eb35b1a5
commit 022ba31999
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 105 additions and 84 deletions

View File

@ -2,9 +2,8 @@
from __future__ import annotations
from collections import deque
from collections.abc import Iterator, Mapping
from collections.abc import Generator, Iterator, Mapping
from io import BytesIO
import itertools
import logging
from threading import Event
from typing import Any, Callable, cast
@ -202,6 +201,48 @@ class SegmentBuffer:
self._memory_file.close()
class PeekIterator(Iterator):
"""An Iterator that may allow multiple passes.
This may be consumed like a normal Iterator, however also supports a
peek() method that buffers consumed items from the iterator.
"""
def __init__(self, iterator: Iterator[av.Packet]) -> None:
"""Initialize PeekIterator."""
self._iterator = iterator
self._buffer: deque[av.Packet] = deque()
# A pointer to either _iterator or _buffer
self._next = self._iterator.__next__
def __iter__(self) -> Iterator:
"""Return an iterator."""
return self
def __next__(self) -> av.Packet:
"""Return and consume the next item available."""
return self._next()
def _pop_buffer(self) -> av.Packet:
"""Consume items from the buffer until exhausted."""
if self._buffer:
return self._buffer.popleft()
# The buffer is empty, so change to consume from the iterator
self._next = self._iterator.__next__
return self._next()
def peek(self) -> Generator[av.Packet, None, None]:
"""Return items without consuming from the iterator."""
# Items consumed are added to a buffer for future calls to __next__
# or peek. First iterate over the buffer from previous calls to peek.
self._next = self._pop_buffer
for packet in self._buffer:
yield packet
for packet in self._iterator:
self._buffer.append(packet)
yield packet
class TimestampValidator:
"""Validate ordering of timestamps for packets in a stream."""
@ -236,6 +277,31 @@ class TimestampValidator:
return True
def is_keyframe(packet: av.Packet) -> Any:
"""Return true if the packet is a keyframe."""
return packet.is_keyframe
def unsupported_audio(packets: Iterator[av.Packet], audio_stream: Any) -> bool:
"""Detect ADTS AAC, which is not supported by pyav."""
if not audio_stream:
return False
for count, packet in enumerate(packets):
if count >= PACKETS_TO_WAIT_FOR_AUDIO:
# Some streams declare an audio stream and never send any packets
_LOGGER.warning("Audio stream not found")
break
if packet.stream == audio_stream:
# detect ADTS AAC and disable audio
if audio_stream.codec.name == "aac" and packet.size > 2:
with memoryview(packet) as packet_view:
if packet_view[0] == 0xFF and packet_view[1] & 0xF0 == 0xF0:
_LOGGER.warning("ADTS AAC detected - disabling audio stream")
return True
break
return False
def stream_worker(
source: str,
options: dict[str, str],
@ -267,100 +333,55 @@ def stream_worker(
if audio_stream and audio_stream.profile is None:
audio_stream = None
# Iterator for demuxing
container_packets: Iterator[av.Packet]
# The video dts at the beginning of the segment
segment_start_dts: int | None = None
# Because of problems 1 and 2 below, we need to store the first few packets and replay them
initial_packets: deque[av.Packet] = deque()
dts_validator = TimestampValidator()
container_packets = PeekIterator(
filter(dts_validator.is_valid, container.demux((video_stream, audio_stream)))
)
def is_video(packet: av.Packet) -> Any:
"""Return true if the packet is for the video stream."""
return packet.stream == video_stream
# Have to work around two problems with RTSP feeds in ffmpeg
# 1 - first frame has bad pts/dts https://trac.ffmpeg.org/ticket/5018
# 2 - seeking can be problematic https://trac.ffmpeg.org/ticket/7815
def peek_first_dts() -> bool:
"""Initialize by peeking into the first few packets of the stream.
Deal with problem #1 above (bad first packet pts/dts) by recalculating using pts/dts from second packet.
Also load the first video keyframe dts into segment_start_dts and check if the audio stream really exists.
"""
nonlocal segment_start_dts, audio_stream, container_packets
found_audio = False
try:
# Ensure packets are ordered correctly
dts_validator = TimestampValidator()
container_packets = filter(
dts_validator.is_valid, container.demux((video_stream, audio_stream))
#
# Use a peeking iterator to peek into the start of the stream, ensuring
# everything looks good, then go back to the start when muxing below.
try:
if audio_stream and unsupported_audio(container_packets.peek(), audio_stream):
audio_stream = None
container_packets = PeekIterator(
filter(dts_validator.is_valid, container.demux(video_stream))
)
first_packet: av.Packet | None = None
# Get to first video keyframe
while first_packet is None:
packet = next(container_packets)
if packet.stream == audio_stream:
found_audio = True
elif packet.is_keyframe: # video_keyframe
first_packet = packet
initial_packets.append(packet)
# Get first_dts from subsequent frame to first keyframe
while segment_start_dts is None or (
audio_stream
and not found_audio
and len(initial_packets) < PACKETS_TO_WAIT_FOR_AUDIO
):
packet = next(container_packets)
if packet.stream == audio_stream:
# detect ADTS AAC and disable audio
if audio_stream.codec.name == "aac" and packet.size > 2:
with memoryview(packet) as packet_view:
if packet_view[0] == 0xFF and packet_view[1] & 0xF0 == 0xF0:
_LOGGER.warning(
"ADTS AAC detected - disabling audio stream"
)
container_packets = filter(
dts_validator.is_valid,
container.demux(video_stream),
)
audio_stream = None
continue
found_audio = True
elif (
segment_start_dts is None
): # This is the second video frame to calculate first_dts from
segment_start_dts = packet.dts - packet.duration
first_packet.pts = first_packet.dts = segment_start_dts
initial_packets.append(packet)
if audio_stream and not found_audio:
_LOGGER.warning(
"Audio stream not found"
) # Some streams declare an audio stream and never send any packets
except (av.AVError, StopIteration) as ex:
_LOGGER.error(
"Error demuxing stream while finding first packet: %s", str(ex)
)
return False
return True
if not peek_first_dts():
# Advance to the first keyframe for muxing, then rewind so the muxing
# loop below can consume.
first_keyframe = next(filter(is_keyframe, filter(is_video, container_packets)))
# Deal with problem #1 above (bad first packet pts/dts) by recalculating
# using pts/dts from second packet. Use the peek iterator to advance
# without consuming from container_packets. Skip over the first keyframe
# then use the duration from the second video packet to adjust dts.
next_video_packet = next(filter(is_video, container_packets.peek()))
start_dts = next_video_packet.dts - next_video_packet.duration
first_keyframe.dts = first_keyframe.pts = start_dts
except (av.AVError, StopIteration) as ex:
_LOGGER.error("Error demuxing stream while finding first packet: %s", str(ex))
container.close()
return
segment_buffer.set_streams(video_stream, audio_stream)
assert isinstance(segment_start_dts, int)
segment_buffer.reset(segment_start_dts)
segment_buffer.reset(start_dts)
# Mux the first keyframe, then proceed through the rest of the packets
segment_buffer.mux_packet(first_keyframe)
# Rewind the stream and iterate over the initial set of packets again
# filtering out any packets with timestamp ordering issues.
packets = itertools.chain(initial_packets, container_packets)
while not quit_event.is_set():
try:
packet = next(packets)
packet = next(container_packets)
except (av.AVError, StopIteration) as ex:
_LOGGER.error("Error demuxing stream: %s", str(ex))
break
# Mux packets, and possibly write a segment to the output stream.
# This mutates packet timestamps and stream
segment_buffer.mux_packet(packet)
# Close stream

View File

@ -606,10 +606,10 @@ async def test_update_stream_source(hass):
nonlocal last_stream_source
if not isinstance(stream_source, io.BytesIO):
last_stream_source = stream_source
# Let test know the thread is running
worker_open.set()
# Block worker thread until test wakes up
worker_wake.wait()
# Let test know the thread is running
worker_open.set()
# Block worker thread until test wakes up
worker_wake.wait()
return py_av.open(stream_source, args, kwargs)
with patch("av.open", new=blocking_open), patch(