diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index 72625fa0f5a..69def43b2a2 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -1,7 +1,7 @@ """Provides the worker thread needed for processing streams.""" from __future__ import annotations -from collections import deque +from collections import defaultdict, deque from collections.abc import Generator, Iterator, Mapping from io import BytesIO import logging @@ -222,6 +222,12 @@ class PeekIterator(Iterator): """Return and consume the next item available.""" return self._next() + def replace_underlying_iterator(self, new_iterator: Iterator) -> None: + """Replace the underlying iterator while preserving the buffer.""" + self._iterator = new_iterator + if self._next is not self._pop_buffer: + self._next = self._iterator.__next__ + def _pop_buffer(self) -> av.Packet: """Consume items from the buffer until exhausted.""" if self._buffer: @@ -248,7 +254,9 @@ class TimestampValidator: def __init__(self) -> None: """Initialize the TimestampValidator.""" # Decompression timestamp of last packet in each stream - self._last_dts: dict[av.stream.Stream, float] = {} + self._last_dts: dict[av.stream.Stream, int | float] = defaultdict( + lambda: float("-inf") + ) # Number of consecutive missing decompression timestamps self._missing_dts = 0 @@ -264,7 +272,7 @@ class TimestampValidator: return False self._missing_dts = 0 # Discard when dts is not monotonic. Terminate if gap is too wide. - prev_dts = self._last_dts.get(packet.stream, float("-inf")) + prev_dts = self._last_dts[packet.stream] if packet.dts <= prev_dts: gap = packet.time_base * (prev_dts - packet.dts) if gap > MAX_TIMESTAMP_GAP: @@ -350,19 +358,25 @@ def stream_worker( try: if audio_stream and unsupported_audio(container_packets.peek(), audio_stream): audio_stream = None - container_packets = PeekIterator( + container_packets.replace_underlying_iterator( filter(dts_validator.is_valid, container.demux(video_stream)) ) # Advance to the first keyframe for muxing, then rewind so the muxing # loop below can consume. - first_keyframe = next(filter(is_keyframe, filter(is_video, container_packets))) + first_keyframe = next( + filter(lambda pkt: is_keyframe(pkt) and is_video(pkt), container_packets) + ) # Deal with problem #1 above (bad first packet pts/dts) by recalculating # using pts/dts from second packet. Use the peek iterator to advance # without consuming from container_packets. Skip over the first keyframe # then use the duration from the second video packet to adjust dts. next_video_packet = next(filter(is_video, container_packets.peek())) - start_dts = next_video_packet.dts - next_video_packet.duration + # Since the is_valid filter has already been applied before the following + # adjustment, it does not filter out the case where the duration below is + # 0 and both the first_keyframe and next_video_packet end up with the same + # dts. Use "or 1" to deal with this. + start_dts = next_video_packet.dts - (next_video_packet.duration or 1) first_keyframe.dts = first_keyframe.pts = start_dts except (av.AVError, StopIteration) as ex: _LOGGER.error("Error demuxing stream while finding first packet: %s", str(ex))