Deduplicate wav creation in esphome ffmpeg_proxy tests (#129484)

This commit is contained in:
Erik Montnemery 2024-10-30 10:35:19 +01:00 committed by GitHub
parent 2aed01b530
commit 79d73c28a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,6 @@
"""Tests for ffmpeg proxy view.""" """Tests for ffmpeg proxy view."""
from collections.abc import Generator
from http import HTTPStatus from http import HTTPStatus
import io import io
import os import os
@ -9,6 +10,7 @@ from urllib.request import pathname2url
import wave import wave
import mutagen import mutagen
import pytest
from homeassistant.components import esphome from homeassistant.components import esphome
from homeassistant.components.esphome.ffmpeg_proxy import async_create_proxy_url from homeassistant.components.esphome.ffmpeg_proxy import async_create_proxy_url
@ -18,6 +20,29 @@ from homeassistant.setup import async_setup_component
from tests.typing import ClientSessionGenerator from tests.typing import ClientSessionGenerator
@pytest.fixture(name="wav_file_length")
def wav_file_length_fixture() -> int:
"""Wanted length of temporary wave file."""
return 1
@pytest.fixture(name="wav_file")
def wav_file_fixture(wav_file_length: int) -> Generator[str]:
"""Create a temporary file and fill it with 1s of silence."""
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as temp_file:
_write_silence(temp_file.name, wav_file_length)
yield temp_file.name
def _write_silence(filename: str, length: int) -> None:
"""Write silence to a file."""
with wave.open(filename, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(16000 * 2 * length)) # length s
async def test_async_create_proxy_url(hass: HomeAssistant) -> None: async def test_async_create_proxy_url(hass: HomeAssistant) -> None:
"""Test that async_create_proxy_url returns the correct format.""" """Test that async_create_proxy_url returns the correct format."""
assert await async_setup_component(hass, "esphome", {}) assert await async_setup_component(hass, "esphome", {})
@ -41,6 +66,7 @@ async def test_async_create_proxy_url(hass: HomeAssistant) -> None:
async def test_proxy_view( async def test_proxy_view(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
wav_file: str,
) -> None: ) -> None:
"""Test proxy HTTP view for converting audio.""" """Test proxy HTTP view for converting audio."""
device_id = "1234" device_id = "1234"
@ -48,14 +74,7 @@ async def test_proxy_view(
await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}}) await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
client = await hass_client() client = await hass_client()
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as temp_file: wav_url = pathname2url(wav_file)
with wave.open(temp_file.name, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(16000 * 2)) # 1s
wav_url = pathname2url(temp_file.name)
convert_id = "test-id" convert_id = "test-id"
url = f"/api/esphome/ffmpeg_proxy/{device_id}/{convert_id}.mp3" url = f"/api/esphome/ffmpeg_proxy/{device_id}/{convert_id}.mp3"
@ -120,6 +139,7 @@ async def test_ffmpeg_file_doesnt_exist(
async def test_lingering_process( async def test_lingering_process(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
wav_file: str,
) -> None: ) -> None:
"""Test that a new request stops the old ffmpeg process.""" """Test that a new request stops the old ffmpeg process."""
device_id = "1234" device_id = "1234"
@ -127,14 +147,7 @@ async def test_lingering_process(
await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}}) await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
client = await hass_client() client = await hass_client()
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as temp_file: wav_url = pathname2url(wav_file)
with wave.open(temp_file.name, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(16000 * 2)) # 1s
wav_url = pathname2url(temp_file.name)
url1 = async_create_proxy_url( url1 = async_create_proxy_url(
hass, hass,
device_id, device_id,
@ -169,22 +182,24 @@ async def test_lingering_process(
wav_data = await req2.content.read() wav_data = await req2.content.read()
# All of the data should be there because this is a new ffmpeg process # All of the data should be there because this is a new ffmpeg process
with io.BytesIO(wav_data) as wav_io, wave.open(wav_io, "rb") as wav_file: with io.BytesIO(wav_data) as wav_io, wave.open(wav_io, "rb") as received_wav_file:
# We can't use getnframes() here because the WAV header will be incorrect. # We can't use getnframes() here because the WAV header will be incorrect.
# WAV encoders usually go back and update the WAV header after all of # WAV encoders usually go back and update the WAV header after all of
# the frames are written, but ffmpeg can't do that because we're # the frames are written, but ffmpeg can't do that because we're
# streaming the data. # streaming the data.
# So instead, we just read and count frames until we run out. # So instead, we just read and count frames until we run out.
num_frames = 0 num_frames = 0
while chunk := wav_file.readframes(1024): while chunk := received_wav_file.readframes(1024):
num_frames += len(chunk) // (2 * 2) # 2 channels, 16-bit samples num_frames += len(chunk) // (2 * 2) # 2 channels, 16-bit samples
assert num_frames == 22050 # 1s assert num_frames == 22050 # 1s
@pytest.mark.parametrize("wav_file_length", [10])
async def test_request_same_url_multiple_times( async def test_request_same_url_multiple_times(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
wav_file: str,
) -> None: ) -> None:
"""Test that the ffmpeg process is restarted if the same URL is requested multiple times.""" """Test that the ffmpeg process is restarted if the same URL is requested multiple times."""
device_id = "1234" device_id = "1234"
@ -192,14 +207,7 @@ async def test_request_same_url_multiple_times(
await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}}) await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
client = await hass_client() client = await hass_client()
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as temp_file: wav_url = pathname2url(wav_file)
with wave.open(temp_file.name, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(16000 * 2 * 10)) # 10s
wav_url = pathname2url(temp_file.name)
url = async_create_proxy_url( url = async_create_proxy_url(
hass, hass,
device_id, device_id,
@ -224,9 +232,9 @@ async def test_request_same_url_multiple_times(
wav_data = await req2.content.read() wav_data = await req2.content.read()
# All of the data should be there because this is a new ffmpeg process # All of the data should be there because this is a new ffmpeg process
with io.BytesIO(wav_data) as wav_io, wave.open(wav_io, "rb") as wav_file: with io.BytesIO(wav_data) as wav_io, wave.open(wav_io, "rb") as received_wav_file:
num_frames = 0 num_frames = 0
while chunk := wav_file.readframes(1024): while chunk := received_wav_file.readframes(1024):
num_frames += len(chunk) // (2 * 2) # 2 channels, 16-bit samples num_frames += len(chunk) // (2 * 2) # 2 channels, 16-bit samples
assert num_frames == 22050 * 10 # 10s assert num_frames == 22050 * 10 # 10s
@ -248,11 +256,7 @@ async def test_max_conversions_per_device(
os.path.join(temp_dir, f"{i}.wav") for i in range(max_conversions + 1) os.path.join(temp_dir, f"{i}.wav") for i in range(max_conversions + 1)
] ]
for wav_path in wav_paths: for wav_path in wav_paths:
with wave.open(wav_path, "wb") as wav_file: _write_silence(wav_path, 10)
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(16000 * 2 * 10)) # 10s
wav_urls = [pathname2url(p) for p in wav_paths] wav_urls = [pathname2url(p) for p in wav_paths]