From 6d656dd2ecc5a7dcdc1d580358f8d0900f27cd27 Mon Sep 17 00:00:00 2001 From: uvjustin <46082645+uvjustin@users.noreply.github.com> Date: Wed, 16 Jun 2021 02:27:53 +0800 Subject: [PATCH] Speed up record stream audio test (#51901) --- tests/components/stream/common.py | 92 +++++++++++++++--------- tests/components/stream/test_recorder.py | 13 ++-- 2 files changed, 66 insertions(+), 39 deletions(-) diff --git a/tests/components/stream/common.py b/tests/components/stream/common.py index 4c6841d03db..a39e8bdca21 100644 --- a/tests/components/stream/common.py +++ b/tests/components/stream/common.py @@ -8,26 +8,27 @@ import numpy as np AUDIO_SAMPLE_RATE = 8000 -def generate_h264_video(container_format="mp4", audio_codec=None): +def generate_audio_frame(pcm_mulaw=False): + """Generate a blank audio frame.""" + if pcm_mulaw: + audio_frame = av.AudioFrame(format="s16", layout="mono", samples=1) + audio_bytes = b"\x00\x00" + else: + audio_frame = av.AudioFrame(format="dbl", layout="mono", samples=1024) + audio_bytes = b"\x00\x00\x00\x00\x00\x00\x00\x00" * 1024 + audio_frame.planes[0].update(audio_bytes) + audio_frame.sample_rate = AUDIO_SAMPLE_RATE + audio_frame.time_base = Fraction(1, AUDIO_SAMPLE_RATE) + return audio_frame + + +def generate_h264_video(container_format="mp4"): """ Generate a test video. See: http://docs.mikeboers.com/pyav/develop/cookbook/numpy.html """ - def generate_audio_frame(pcm_mulaw=False): - """Generate a blank audio frame.""" - if pcm_mulaw: - audio_frame = av.AudioFrame(format="s16", layout="mono", samples=1) - audio_bytes = b"\x00\x00" - else: - audio_frame = av.AudioFrame(format="dbl", layout="mono", samples=1024) - audio_bytes = b"\x00\x00\x00\x00\x00\x00\x00\x00" * 1024 - audio_frame.planes[0].update(audio_bytes) - audio_frame.sample_rate = AUDIO_SAMPLE_RATE - audio_frame.time_base = Fraction(1, AUDIO_SAMPLE_RATE) - return audio_frame - duration = 5 fps = 24 total_frames = duration * fps @@ -42,6 +43,39 @@ def generate_h264_video(container_format="mp4", audio_codec=None): stream.pix_fmt = "yuv420p" stream.options.update({"g": str(fps), "keyint_min": str(fps)}) + for frame_i in range(total_frames): + + img = np.empty((480, 320, 3)) + img[:, :, 0] = 0.5 + 0.5 * np.sin(2 * np.pi * (0 / 3 + frame_i / total_frames)) + img[:, :, 1] = 0.5 + 0.5 * np.sin(2 * np.pi * (1 / 3 + frame_i / total_frames)) + img[:, :, 2] = 0.5 + 0.5 * np.sin(2 * np.pi * (2 / 3 + frame_i / total_frames)) + + img = np.round(255 * img).astype(np.uint8) + img = np.clip(img, 0, 255) + + frame = av.VideoFrame.from_ndarray(img, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + + # Flush stream + for packet in stream.encode(): + container.mux(packet) + + # Close the file + container.close() + output.seek(0) + + return output + + +def remux_with_audio(source, container_format, audio_codec): + """Remux an existing source with new audio.""" + av_source = av.open(source, mode="r") + output = io.BytesIO() + output.name = "test.mov" if container_format == "mov" else "test.mp4" + container = av.open(output, mode="w", format=container_format) + container.add_stream(template=av_source.streams.video[0]) + a_packet = None last_a_dts = -1 if audio_codec is not None: @@ -57,23 +91,17 @@ def generate_h264_video(container_format="mp4", audio_codec=None): if a_packets: a_packet = a_packets[0] - for frame_i in range(total_frames): - - img = np.empty((480, 320, 3)) - img[:, :, 0] = 0.5 + 0.5 * np.sin(2 * np.pi * (0 / 3 + frame_i / total_frames)) - img[:, :, 1] = 0.5 + 0.5 * np.sin(2 * np.pi * (1 / 3 + frame_i / total_frames)) - img[:, :, 2] = 0.5 + 0.5 * np.sin(2 * np.pi * (2 / 3 + frame_i / total_frames)) - - img = np.round(255 * img).astype(np.uint8) - img = np.clip(img, 0, 255) - - frame = av.VideoFrame.from_ndarray(img, format="rgb24") - for packet in stream.encode(frame): - container.mux(packet) - + # open original source and iterate through video packets + for packet in av_source.demux(video=0): + if not packet.dts: + continue + container.mux(packet) if a_packet is not None: - a_packet.pts = int(frame_i / (fps * a_packet.time_base)) - while a_packet.pts * a_packet.time_base * fps < frame_i + 1: + a_packet.pts = int(packet.dts * packet.time_base / a_packet.time_base) + while ( + a_packet.pts * a_packet.time_base + < (packet.dts + packet.duration) * packet.time_base + ): a_packet.dts = a_packet.pts if ( a_packet.dts > last_a_dts @@ -82,10 +110,6 @@ def generate_h264_video(container_format="mp4", audio_codec=None): last_a_dts = a_packet.dts a_packet.pts += a_packet.duration - # Flush stream - for packet in stream.encode(): - container.mux(packet) - # Close the file container.close() output.seek(0) diff --git a/tests/components/stream/test_recorder.py b/tests/components/stream/test_recorder.py index 72c6dfa197f..31661db3886 100644 --- a/tests/components/stream/test_recorder.py +++ b/tests/components/stream/test_recorder.py @@ -17,7 +17,7 @@ from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util from tests.common import async_fire_time_changed -from tests.components.stream.common import generate_h264_video +from tests.components.stream.common import generate_h264_video, remux_with_audio MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever @@ -190,19 +190,22 @@ async def test_record_stream_audio( """ await async_setup_component(hass, "stream", {"stream": {}}) + # Generate source video with no audio + source = generate_h264_video(container_format="mov") + for a_codec, expected_audio_streams in ( ("aac", 1), # aac is a valid mp4 codec ("pcm_mulaw", 0), # G.711 is not a valid mp4 codec ("empty", 0), # audio stream with no packets (None, 0), # no audio stream ): + + # Remux source video with new audio + source = remux_with_audio(source, "mov", a_codec) # mov can store PCM + record_worker_sync.reset() stream_worker_sync.pause() - # Setup demo track - source = generate_h264_video( - container_format="mov", audio_codec=a_codec - ) # mov can store PCM stream = create_stream(hass, source, {}) with patch.object(hass.config, "is_allowed_path", return_value=True): await stream.async_record("/example/path")