Fix non-monotonic dts error in stream (#53712)

* Use defaultdict for TimestampValidator._last_dts

* Combine filters

* Allow PeekIterator to be updated while preserving buffer

* Fix peek edge case

* Re-add is_valid filter to video-only iterator
This commit is contained in:
uvjustin 2021-07-30 23:02:33 +08:00 committed by GitHub
parent 028f6c4cac
commit 7a200a5d3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,7 @@
"""Provides the worker thread needed for processing streams.""" """Provides the worker thread needed for processing streams."""
from __future__ import annotations from __future__ import annotations
from collections import deque from collections import defaultdict, deque
from collections.abc import Generator, Iterator, Mapping from collections.abc import Generator, Iterator, Mapping
from io import BytesIO from io import BytesIO
import logging import logging
@ -222,6 +222,12 @@ class PeekIterator(Iterator):
"""Return and consume the next item available.""" """Return and consume the next item available."""
return self._next() return self._next()
def replace_underlying_iterator(self, new_iterator: Iterator) -> None:
    """Replace the underlying iterator while preserving the buffer.

    The new iterator always becomes the source for future reads. When the
    buffer still holds items, _next is left untouched so those items are
    not bypassed.
    """
    self._iterator = new_iterator
    # Check the buffer itself rather than comparing bound methods by
    # identity: each access to self._pop_buffer creates a new bound-method
    # object, so `self._next is not self._pop_buffer` was always true and
    # _next was repointed even while buffered items were still pending.
    if not self._buffer:
        self._next = self._iterator.__next__
def _pop_buffer(self) -> av.Packet: def _pop_buffer(self) -> av.Packet:
"""Consume items from the buffer until exhausted.""" """Consume items from the buffer until exhausted."""
if self._buffer: if self._buffer:
@ -248,7 +254,9 @@ class TimestampValidator:
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the TimestampValidator.""" """Initialize the TimestampValidator."""
# Decompression timestamp of last packet in each stream # Decompression timestamp of last packet in each stream
self._last_dts: dict[av.stream.Stream, float] = {} self._last_dts: dict[av.stream.Stream, int | float] = defaultdict(
lambda: float("-inf")
)
# Number of consecutive missing decompression timestamps # Number of consecutive missing decompression timestamps
self._missing_dts = 0 self._missing_dts = 0
@ -264,7 +272,7 @@ class TimestampValidator:
return False return False
self._missing_dts = 0 self._missing_dts = 0
# Discard when dts is not monotonic. Terminate if gap is too wide. # Discard when dts is not monotonic. Terminate if gap is too wide.
prev_dts = self._last_dts.get(packet.stream, float("-inf")) prev_dts = self._last_dts[packet.stream]
if packet.dts <= prev_dts: if packet.dts <= prev_dts:
gap = packet.time_base * (prev_dts - packet.dts) gap = packet.time_base * (prev_dts - packet.dts)
if gap > MAX_TIMESTAMP_GAP: if gap > MAX_TIMESTAMP_GAP:
@ -350,19 +358,25 @@ def stream_worker(
try: try:
if audio_stream and unsupported_audio(container_packets.peek(), audio_stream): if audio_stream and unsupported_audio(container_packets.peek(), audio_stream):
audio_stream = None audio_stream = None
container_packets = PeekIterator( container_packets.replace_underlying_iterator(
filter(dts_validator.is_valid, container.demux(video_stream)) filter(dts_validator.is_valid, container.demux(video_stream))
) )
# Advance to the first keyframe for muxing, then rewind so the muxing # Advance to the first keyframe for muxing, then rewind so the muxing
# loop below can consume. # loop below can consume.
first_keyframe = next(filter(is_keyframe, filter(is_video, container_packets))) first_keyframe = next(
filter(lambda pkt: is_keyframe(pkt) and is_video(pkt), container_packets)
)
# Deal with problem #1 above (bad first packet pts/dts) by recalculating # Deal with problem #1 above (bad first packet pts/dts) by recalculating
# using pts/dts from second packet. Use the peek iterator to advance # using pts/dts from second packet. Use the peek iterator to advance
# without consuming from container_packets. Skip over the first keyframe # without consuming from container_packets. Skip over the first keyframe
# then use the duration from the second video packet to adjust dts. # then use the duration from the second video packet to adjust dts.
next_video_packet = next(filter(is_video, container_packets.peek())) next_video_packet = next(filter(is_video, container_packets.peek()))
start_dts = next_video_packet.dts - next_video_packet.duration # Since the is_valid filter has already been applied before the following
# adjustment, it does not filter out the case where the duration below is
# 0 and both the first_keyframe and next_video_packet end up with the same
# dts. Use "or 1" to deal with this.
start_dts = next_video_packet.dts - (next_video_packet.duration or 1)
first_keyframe.dts = first_keyframe.pts = start_dts first_keyframe.dts = first_keyframe.pts = start_dts
except (av.AVError, StopIteration) as ex: except (av.AVError, StopIteration) as ex:
_LOGGER.error("Error demuxing stream while finding first packet: %s", str(ex)) _LOGGER.error("Error demuxing stream while finding first packet: %s", str(ex))