Speed up stream tests by 40-50% with shared data (#62300)

This commit is contained in:
Allen Porter 2021-12-18 23:14:21 -08:00 committed by GitHub
parent a6b680cd32
commit 832184bacd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 109 additions and 88 deletions

View File

@ -1,6 +1,7 @@
"""Collection of test helpers.""" """Collection of test helpers."""
from datetime import datetime from datetime import datetime
from fractions import Fraction from fractions import Fraction
import functools
from functools import partial from functools import partial
import io import io
@ -23,6 +24,11 @@ DefaultSegment = partial(
AUDIO_SAMPLE_RATE = 8000 AUDIO_SAMPLE_RATE = 8000
def stream_teardown():
"""Perform test teardown."""
frame_image_data.cache_clear()
def generate_audio_frame(pcm_mulaw=False): def generate_audio_frame(pcm_mulaw=False):
"""Generate a blank audio frame.""" """Generate a blank audio frame."""
if pcm_mulaw: if pcm_mulaw:
@ -37,6 +43,19 @@ def generate_audio_frame(pcm_mulaw=False):
return audio_frame return audio_frame
@functools.lru_cache(maxsize=1024)
def frame_image_data(frame_i, total_frames):
"""Generate image content for a frame of a video."""
img = np.empty((480, 320, 3))
img[:, :, 0] = 0.5 + 0.5 * np.sin(2 * np.pi * (0 / 3 + frame_i / total_frames))
img[:, :, 1] = 0.5 + 0.5 * np.sin(2 * np.pi * (1 / 3 + frame_i / total_frames))
img[:, :, 2] = 0.5 + 0.5 * np.sin(2 * np.pi * (2 / 3 + frame_i / total_frames))
img = np.round(255 * img).astype(np.uint8)
img = np.clip(img, 0, 255)
return img
def generate_video(encoder, container_format, duration): def generate_video(encoder, container_format, duration):
""" """
Generate a test video. Generate a test video.
@ -58,15 +77,7 @@ def generate_video(encoder, container_format, duration):
stream.options.update({"g": str(fps), "keyint_min": str(fps)}) stream.options.update({"g": str(fps), "keyint_min": str(fps)})
for frame_i in range(total_frames): for frame_i in range(total_frames):
img = frame_image_data(frame_i, total_frames)
img = np.empty((480, 320, 3))
img[:, :, 0] = 0.5 + 0.5 * np.sin(2 * np.pi * (0 / 3 + frame_i / total_frames))
img[:, :, 1] = 0.5 + 0.5 * np.sin(2 * np.pi * (1 / 3 + frame_i / total_frames))
img[:, :, 2] = 0.5 + 0.5 * np.sin(2 * np.pi * (2 / 3 + frame_i / total_frames))
img = np.round(255 * img).astype(np.uint8)
img = np.clip(img, 0, 255)
frame = av.VideoFrame.from_ndarray(img, format="rgb24") frame = av.VideoFrame.from_ndarray(img, format="rgb24")
for packet in stream.encode(frame): for packet in stream.encode(frame):
container.mux(packet) container.mux(packet)

View File

@ -25,6 +25,8 @@ import pytest
from homeassistant.components.stream.core import Segment, StreamOutput from homeassistant.components.stream.core import Segment, StreamOutput
from homeassistant.components.stream.worker import StreamState from homeassistant.components.stream.worker import StreamState
from .common import generate_h264_video, stream_teardown
TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout
@ -215,3 +217,16 @@ def hls_sync():
side_effect=sync.response, side_effect=sync.response,
): ):
yield sync yield sync
@pytest.fixture(scope="package")
def h264_video():
"""Generate a video, shared across tests."""
return generate_h264_video()
@pytest.fixture(scope="package", autouse=True)
def fixture_teardown():
"""Destroy package level test state."""
yield
stream_teardown()

View File

@ -20,11 +20,7 @@ from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from tests.common import async_fire_time_changed from tests.common import async_fire_time_changed
from tests.components.stream.common import ( from tests.components.stream.common import FAKE_TIME, DefaultSegment as Segment
FAKE_TIME,
DefaultSegment as Segment,
generate_h264_video,
)
STREAM_SOURCE = "some-stream-source" STREAM_SOURCE = "some-stream-source"
INIT_BYTES = b"init" INIT_BYTES = b"init"
@ -118,7 +114,7 @@ def make_playlist(
return "\n".join(response) return "\n".join(response)
async def test_hls_stream(hass, hls_stream, stream_worker_sync): async def test_hls_stream(hass, hls_stream, stream_worker_sync, h264_video):
""" """
Test hls stream. Test hls stream.
@ -130,8 +126,7 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync):
stream_worker_sync.pause() stream_worker_sync.pause()
# Setup demo HLS track # Setup demo HLS track
source = generate_h264_video() stream = create_stream(hass, h264_video, {})
stream = create_stream(hass, source, {})
# Request stream # Request stream
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
@ -169,15 +164,14 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync):
assert fail_response.status == HTTPStatus.NOT_FOUND assert fail_response.status == HTTPStatus.NOT_FOUND
async def test_stream_timeout(hass, hass_client, stream_worker_sync): async def test_stream_timeout(hass, hass_client, stream_worker_sync, h264_video):
"""Test hls stream timeout.""" """Test hls stream timeout."""
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
stream_worker_sync.pause() stream_worker_sync.pause()
# Setup demo HLS track # Setup demo HLS track
source = generate_h264_video() stream = create_stream(hass, h264_video, {})
stream = create_stream(hass, source, {})
# Request stream # Request stream
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
@ -211,15 +205,16 @@ async def test_stream_timeout(hass, hass_client, stream_worker_sync):
assert fail_response.status == HTTPStatus.NOT_FOUND assert fail_response.status == HTTPStatus.NOT_FOUND
async def test_stream_timeout_after_stop(hass, hass_client, stream_worker_sync): async def test_stream_timeout_after_stop(
hass, hass_client, stream_worker_sync, h264_video
):
"""Test hls stream timeout after the stream has been stopped already.""" """Test hls stream timeout after the stream has been stopped already."""
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
stream_worker_sync.pause() stream_worker_sync.pause()
# Setup demo HLS track # Setup demo HLS track
source = generate_h264_video() stream = create_stream(hass, h264_video, {})
stream = create_stream(hass, source, {})
# Request stream # Request stream
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)

View File

@ -16,17 +16,14 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .common import DefaultSegment as Segment, generate_h264_video, remux_with_audio
from tests.common import async_fire_time_changed from tests.common import async_fire_time_changed
from tests.components.stream.common import (
DefaultSegment as Segment,
generate_h264_video,
remux_with_audio,
)
MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever
async def test_record_stream(hass, hass_client, record_worker_sync): async def test_record_stream(hass, hass_client, record_worker_sync, h264_video):
""" """
Test record stream. Test record stream.
@ -37,8 +34,7 @@ async def test_record_stream(hass, hass_client, record_worker_sync):
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
# Setup demo track # Setup demo track
source = generate_h264_video() stream = create_stream(hass, h264_video, {})
stream = create_stream(hass, source, {})
with patch.object(hass.config, "is_allowed_path", return_value=True): with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path") await stream.async_record("/example/path")
@ -54,13 +50,12 @@ async def test_record_stream(hass, hass_client, record_worker_sync):
async def test_record_lookback( async def test_record_lookback(
hass, hass_client, stream_worker_sync, record_worker_sync hass, hass_client, stream_worker_sync, record_worker_sync, h264_video
): ):
"""Exercise record with loopback.""" """Exercise record with loopback."""
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
source = generate_h264_video() stream = create_stream(hass, h264_video, {})
stream = create_stream(hass, source, {})
# Start an HLS feed to enable lookback # Start an HLS feed to enable lookback
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
@ -74,7 +69,7 @@ async def test_record_lookback(
stream.stop() stream.stop()
async def test_recorder_timeout(hass, hass_client, stream_worker_sync): async def test_recorder_timeout(hass, hass_client, stream_worker_sync, h264_video):
""" """
Test recorder timeout. Test recorder timeout.
@ -87,9 +82,7 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync):
with patch("homeassistant.components.stream.IdleTimer.fire") as mock_timeout: with patch("homeassistant.components.stream.IdleTimer.fire") as mock_timeout:
# Setup demo track # Setup demo track
source = generate_h264_video() stream = create_stream(hass, h264_video, {})
stream = create_stream(hass, source, {})
with patch.object(hass.config, "is_allowed_path", return_value=True): with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path") await stream.async_record("/example/path")
recorder = stream.add_provider(RECORDER_PROVIDER) recorder = stream.add_provider(RECORDER_PROVIDER)
@ -109,13 +102,11 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync):
await hass.async_block_till_done() await hass.async_block_till_done()
async def test_record_path_not_allowed(hass, hass_client): async def test_record_path_not_allowed(hass, hass_client, h264_video):
"""Test where the output path is not allowed by home assistant configuration.""" """Test where the output path is not allowed by home assistant configuration."""
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
# Setup demo track stream = create_stream(hass, h264_video, {})
source = generate_h264_video()
stream = create_stream(hass, source, {})
with patch.object( with patch.object(
hass.config, "is_allowed_path", return_value=False hass.config, "is_allowed_path", return_value=False
), pytest.raises(HomeAssistantError): ), pytest.raises(HomeAssistantError):
@ -136,15 +127,14 @@ def add_parts_to_segment(segment, source):
] ]
async def test_recorder_save(tmpdir): async def test_recorder_save(tmpdir, h264_video):
"""Test recorder save.""" """Test recorder save."""
# Setup # Setup
source = generate_h264_video()
filename = f"{tmpdir}/test.mp4" filename = f"{tmpdir}/test.mp4"
# Run # Run
segment = Segment(sequence=1) segment = Segment(sequence=1)
add_parts_to_segment(segment, source) add_parts_to_segment(segment, h264_video)
segment.duration = 4 segment.duration = 4
recorder_save_worker(filename, [segment]) recorder_save_worker(filename, [segment])
@ -152,18 +142,17 @@ async def test_recorder_save(tmpdir):
assert os.path.exists(filename) assert os.path.exists(filename)
async def test_recorder_discontinuity(tmpdir): async def test_recorder_discontinuity(tmpdir, h264_video):
"""Test recorder save across a discontinuity.""" """Test recorder save across a discontinuity."""
# Setup # Setup
source = generate_h264_video()
filename = f"{tmpdir}/test.mp4" filename = f"{tmpdir}/test.mp4"
# Run # Run
segment_1 = Segment(sequence=1, stream_id=0) segment_1 = Segment(sequence=1, stream_id=0)
add_parts_to_segment(segment_1, source) add_parts_to_segment(segment_1, h264_video)
segment_1.duration = 4 segment_1.duration = 4
segment_2 = Segment(sequence=2, stream_id=1) segment_2 = Segment(sequence=2, stream_id=1)
add_parts_to_segment(segment_2, source) add_parts_to_segment(segment_2, h264_video)
segment_2.duration = 4 segment_2.duration = 4
recorder_save_worker(filename, [segment_1, segment_2]) recorder_save_worker(filename, [segment_1, segment_2])
# Assert # Assert
@ -182,8 +171,29 @@ async def test_recorder_no_segments(tmpdir):
assert not os.path.exists(filename) assert not os.path.exists(filename)
@pytest.fixture(scope="module")
def h264_mov_video():
"""Generate a source video with no audio."""
return generate_h264_video(container_format="mov")
@pytest.mark.parametrize(
"audio_codec,expected_audio_streams",
[
("aac", 1), # aac is a valid mp4 codec
("pcm_mulaw", 0), # G.711 is not a valid mp4 codec
("empty", 0), # audio stream with no packets
(None, 0), # no audio stream
],
)
async def test_record_stream_audio( async def test_record_stream_audio(
hass, hass_client, stream_worker_sync, record_worker_sync hass,
hass_client,
stream_worker_sync,
record_worker_sync,
audio_codec,
expected_audio_streams,
h264_mov_video,
): ):
""" """
Test treatment of different audio inputs. Test treatment of different audio inputs.
@ -193,47 +203,38 @@ async def test_record_stream_audio(
""" """
await async_setup_component(hass, "stream", {"stream": {}}) await async_setup_component(hass, "stream", {"stream": {}})
# Generate source video with no audio # Remux source video with new audio
orig_source = generate_h264_video(container_format="mov") source = remux_with_audio(h264_mov_video, "mov", audio_codec) # mov can store PCM
for a_codec, expected_audio_streams in ( record_worker_sync.reset()
("aac", 1), # aac is a valid mp4 codec stream_worker_sync.pause()
("pcm_mulaw", 0), # G.711 is not a valid mp4 codec
("empty", 0), # audio stream with no packets
(None, 0), # no audio stream
):
# Remux source video with new audio
source = remux_with_audio(orig_source, "mov", a_codec) # mov can store PCM
record_worker_sync.reset() stream = create_stream(hass, source, {})
stream_worker_sync.pause() with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path")
recorder = stream.add_provider(RECORDER_PROVIDER)
stream = create_stream(hass, source, {}) while True:
with patch.object(hass.config, "is_allowed_path", return_value=True): await recorder.recv()
await stream.async_record("/example/path") if not (segment := recorder.last_segment):
recorder = stream.add_provider(RECORDER_PROVIDER) break
last_segment = segment
stream_worker_sync.resume()
while True: result = av.open(
await recorder.recv() BytesIO(last_segment.init + last_segment.get_data()),
if not (segment := recorder.last_segment): "r",
break format="mp4",
last_segment = segment )
stream_worker_sync.resume()
result = av.open( assert len(result.streams.audio) == expected_audio_streams
BytesIO(last_segment.init + last_segment.get_data()), result.close()
"r", stream.stop()
format="mp4", await hass.async_block_till_done()
)
assert len(result.streams.audio) == expected_audio_streams # Verify that the save worker was invoked, then block until its
result.close() # thread completes and is shutdown completely to avoid thread leaks.
stream.stop() await record_worker_sync.join()
await hass.async_block_till_done()
# Verify that the save worker was invoked, then block until its
# thread completes and is shutdown completely to avoid thread leaks.
await record_worker_sync.join()
async def test_recorder_log(hass, caplog): async def test_recorder_log(hass, caplog):

View File

@ -781,7 +781,7 @@ async def test_durations(hass, record_worker_sync):
stream.stop() stream.stop()
async def test_has_keyframe(hass, record_worker_sync): async def test_has_keyframe(hass, record_worker_sync, h264_video):
"""Test that the has_keyframe metadata matches the media.""" """Test that the has_keyframe metadata matches the media."""
await async_setup_component( await async_setup_component(
hass, hass,
@ -797,8 +797,7 @@ async def test_has_keyframe(hass, record_worker_sync):
}, },
) )
source = generate_h264_video() stream = create_stream(hass, h264_video, {})
stream = create_stream(hass, source, {})
# use record_worker_sync to grab output segments # use record_worker_sync to grab output segments
with patch.object(hass.config, "is_allowed_path", return_value=True): with patch.object(hass.config, "is_allowed_path", return_value=True):