Refactor decompression timestamp validation logic in stream component (#52462)

* Refactor dts validation logic into a separate function

Create a decompression timestamp (dts) validator to move the validation logic
out of the worker into a separate class. This also uses Python's itertools.chain
to chain the initial packets together with the remaining packets in the
container iterator, removing additional inline if statements.

* Reset dts validator when container is reset

* Fix typo in a comment

* Reuse existing dts_validator when disabling audio stream
This commit is contained in:
Allen Porter 2021-07-07 15:29:15 -07:00 committed by GitHub
parent 02d8d25d1d
commit e895b6cd42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from collections import deque from collections import deque
from collections.abc import Iterator, Mapping from collections.abc import Iterator, Mapping
from io import BytesIO from io import BytesIO
import itertools
import logging import logging
from threading import Event from threading import Event
from typing import Any, Callable, cast from typing import Any, Callable, cast
@ -201,7 +202,41 @@ class SegmentBuffer:
self._memory_file.close() self._memory_file.close()
class TimestampValidator:
    """Validate ordering of timestamps for packets in a stream.

    The validator is stateful: it remembers the last accepted decompression
    timestamp (dts) per stream and the number of consecutive packets seen
    without a dts. `is_valid` is shaped as a predicate so it can be handed
    to `filter()` over a packet iterator.
    """

    def __init__(self) -> None:
        """Initialize the TimestampValidator."""
        # Decompression timestamp of last accepted packet in each stream
        self._last_dts: dict[av.stream.Stream, float] = {}
        # Number of consecutive missing decompression timestamps
        # (counted across all streams, reset by any packet that has a dts)
        self._missing_dts = 0

    def is_valid(self, packet: av.Packet) -> bool:
        """Validate the packet timestamp based on ordering within the stream.

        Returns True when the packet has a dts strictly greater than the last
        accepted dts of the same stream, and records that dts. Returns False
        when the packet should be discarded (missing dts, or dts not
        increasing).

        Raises StopIteration when the stream should be terminated: either
        more than MAX_MISSING_DTS consecutive packets have no dts, or the dts
        went backwards by more than MAX_TIMESTAMP_GAP once scaled by
        packet.time_base.

        NOTE(review): because StopIteration is used as the termination
        signal, this must be called from plain iterator machinery such as
        filter(); raised inside a generator body it would surface as
        RuntimeError (PEP 479).
        """
        # Discard packets missing DTS. Terminate if too many are missing.
        if packet.dts is None:
            if self._missing_dts >= MAX_MISSING_DTS:
                raise StopIteration(
                    f"No dts in {MAX_MISSING_DTS+1} consecutive packets"
                )
            self._missing_dts += 1
            return False
        self._missing_dts = 0

        # Discard when dts is not monotonic. Terminate if gap is too wide.
        # Streams never seen before compare against -inf, so their first
        # packet with a dts is always accepted.
        prev_dts = self._last_dts.get(packet.stream, float("-inf"))
        if packet.dts <= prev_dts:
            # time_base converts the dts difference out of stream ticks
            # (presumably into seconds -- confirm against MAX_TIMESTAMP_GAP).
            gap = packet.time_base * (prev_dts - packet.dts)
            if gap > MAX_TIMESTAMP_GAP:
                raise StopIteration(
                    f"Timestamp overflow detected: last dts = {prev_dts}, dts = {packet.dts}"
                )
            return False

        self._last_dts[packet.stream] = packet.dts
        return True
def stream_worker(
source: str, source: str,
options: dict[str, str], options: dict[str, str],
segment_buffer: SegmentBuffer, segment_buffer: SegmentBuffer,
@ -234,10 +269,6 @@ def stream_worker( # noqa: C901
# Iterator for demuxing # Iterator for demuxing
container_packets: Iterator[av.Packet] container_packets: Iterator[av.Packet]
# The decoder timestamps of the latest packet in each stream we processed
last_dts = {video_stream: float("-inf"), audio_stream: float("-inf")}
# Keep track of consecutive packets without a dts to detect end of stream.
missing_dts = 0
# The video dts at the beginning of the segment # The video dts at the beginning of the segment
segment_start_dts: int | None = None segment_start_dts: int | None = None
# Because of problems 1 and 2 below, we need to store the first few packets and replay them # Because of problems 1 and 2 below, we need to store the first few packets and replay them
@ -254,23 +285,17 @@ def stream_worker( # noqa: C901
Also load the first video keyframe dts into segment_start_dts and check if the audio stream really exists. Also load the first video keyframe dts into segment_start_dts and check if the audio stream really exists.
""" """
nonlocal segment_start_dts, audio_stream, container_packets nonlocal segment_start_dts, audio_stream, container_packets
missing_dts = 0
found_audio = False found_audio = False
try: try:
container_packets = container.demux((video_stream, audio_stream)) # Ensure packets are ordered correctly
dts_validator = TimestampValidator()
container_packets = filter(
dts_validator.is_valid, container.demux((video_stream, audio_stream))
)
first_packet: av.Packet | None = None first_packet: av.Packet | None = None
# Get to first video keyframe # Get to first video keyframe
while first_packet is None: while first_packet is None:
packet = next(container_packets) packet = next(container_packets)
if (
packet.dts is None
): # Allow MAX_MISSING_DTS packets with no dts, raise error on the next one
if missing_dts >= MAX_MISSING_DTS:
raise StopIteration(
f"Invalid data - got {MAX_MISSING_DTS+1} packets with missing DTS while initializing"
)
missing_dts += 1
continue
if packet.stream == audio_stream: if packet.stream == audio_stream:
found_audio = True found_audio = True
elif packet.is_keyframe: # video_keyframe elif packet.is_keyframe: # video_keyframe
@ -283,15 +308,6 @@ def stream_worker( # noqa: C901
and len(initial_packets) < PACKETS_TO_WAIT_FOR_AUDIO and len(initial_packets) < PACKETS_TO_WAIT_FOR_AUDIO
): ):
packet = next(container_packets) packet = next(container_packets)
if (
packet.dts is None
): # Allow MAX_MISSING_DTS packet with no dts, raise error on the next one
if missing_dts >= MAX_MISSING_DTS:
raise StopIteration(
f"Invalid data - got {MAX_MISSING_DTS+1} packets with missing DTS while initializing"
)
missing_dts += 1
continue
if packet.stream == audio_stream: if packet.stream == audio_stream:
# detect ADTS AAC and disable audio # detect ADTS AAC and disable audio
if audio_stream.codec.name == "aac" and packet.size > 2: if audio_stream.codec.name == "aac" and packet.size > 2:
@ -300,7 +316,10 @@ def stream_worker( # noqa: C901
_LOGGER.warning( _LOGGER.warning(
"ADTS AAC detected - disabling audio stream" "ADTS AAC detected - disabling audio stream"
) )
container_packets = container.demux(video_stream) container_packets = filter(
dts_validator.is_valid,
container.demux(video_stream),
)
audio_stream = None audio_stream = None
continue continue
found_audio = True found_audio = True
@ -330,42 +349,16 @@ def stream_worker( # noqa: C901
assert isinstance(segment_start_dts, int) assert isinstance(segment_start_dts, int)
segment_buffer.reset(segment_start_dts) segment_buffer.reset(segment_start_dts)
# Rewind the stream and iterate over the initial set of packets again
# filtering out any packets with timestamp ordering issues.
packets = itertools.chain(initial_packets, container_packets)
while not quit_event.is_set(): while not quit_event.is_set():
try: try:
if len(initial_packets) > 0: packet = next(packets)
packet = initial_packets.popleft()
else:
packet = next(container_packets)
if packet.dts is None:
# Allow MAX_MISSING_DTS consecutive packets without dts. Terminate the stream on the next one.
if missing_dts >= MAX_MISSING_DTS:
raise StopIteration(
f"No dts in {MAX_MISSING_DTS+1} consecutive packets"
)
missing_dts += 1
continue
missing_dts = 0
except (av.AVError, StopIteration) as ex: except (av.AVError, StopIteration) as ex:
_LOGGER.error("Error demuxing stream: %s", str(ex)) _LOGGER.error("Error demuxing stream: %s", str(ex))
break break
# Discard packet if dts is not monotonic
if packet.dts <= last_dts[packet.stream]:
if (
packet.time_base * (last_dts[packet.stream] - packet.dts)
> MAX_TIMESTAMP_GAP
):
_LOGGER.warning(
"Timestamp overflow detected: last dts %s, dts = %s, resetting stream",
last_dts[packet.stream],
packet.dts,
)
break
continue
# Update last_dts processed
last_dts[packet.stream] = packet.dts
# Mux packets, and possibly write a segment to the output stream. # Mux packets, and possibly write a segment to the output stream.
# This mutates packet timestamps and stream # This mutates packet timestamps and stream
segment_buffer.mux_packet(packet) segment_buffer.mux_packet(packet)