From dd9ce34d18061f2cc128097dc132c120233329fd Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 28 Oct 2024 15:26:43 -0500 Subject: [PATCH] Allow a fixed number of ffmpeg proxy conversions per device (#129246) Allow a fixed number of conversions per device --- .../components/esphome/ffmpeg_proxy.py | 54 +++++++++++++++---- tests/components/esphome/test_ffmpeg_proxy.py | 53 ++++++++++++++++++ 2 files changed, 97 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/esphome/ffmpeg_proxy.py b/homeassistant/components/esphome/ffmpeg_proxy.py index 8f24a478738..5313c67afac 100644 --- a/homeassistant/components/esphome/ffmpeg_proxy.py +++ b/homeassistant/components/esphome/ffmpeg_proxy.py @@ -1,10 +1,12 @@ """HTTP view that converts audio from a URL to a preferred format.""" import asyncio +from collections import defaultdict from dataclasses import dataclass, field from http import HTTPStatus import logging import secrets +from typing import Final from aiohttp import web from aiohttp.abc import AbstractStreamWriter, BaseRequest @@ -17,6 +19,8 @@ from .const import DATA_FFMPEG_PROXY _LOGGER = logging.getLogger(__name__) +_MAX_CONVERSIONS_PER_DEVICE: Final[int] = 2 + def async_create_proxy_url( hass: HomeAssistant, @@ -59,13 +63,18 @@ class FFmpegConversionInfo: proc: asyncio.subprocess.Process | None = None """Subprocess doing ffmpeg conversion.""" + is_finished: bool = False + """True if conversion has finished.""" + @dataclass class FFmpegProxyData: """Data for ffmpeg proxy conversion.""" - # device_id -> info - conversions: dict[str, FFmpegConversionInfo] = field(default_factory=dict) + # device_id -> [info] + conversions: dict[str, list[FFmpegConversionInfo]] = field( + default_factory=lambda: defaultdict(list) + ) def async_create_proxy_url( self, @@ -77,8 +86,15 @@ class FFmpegProxyData: width: int | None, ) -> str: """Create a one-time use proxy URL that automatically converts the media.""" - if (convert_info := self.conversions.pop(device_id, None)) is not None: - # Stop existing conversion before overwriting info + + # Remove completed conversions + device_conversions = [ + info for info in self.conversions[device_id] if not info.is_finished + ] + + while len(device_conversions) >= _MAX_CONVERSIONS_PER_DEVICE: + # Stop oldest conversion before adding a new one + convert_info = device_conversions[0] if (convert_info.proc is not None) and ( convert_info.proc.returncode is None ): @@ -87,12 +103,18 @@ class FFmpegProxyData: ) convert_info.proc.kill() + device_conversions = device_conversions[1:] + convert_id = secrets.token_urlsafe(16) - self.conversions[device_id] = FFmpegConversionInfo( - convert_id, media_url, media_format, rate, channels, width + device_conversions.append( + FFmpegConversionInfo( + convert_id, media_url, media_format, rate, channels, width + ) ) _LOGGER.debug("Media URL allowed by proxy: %s", media_url) + self.conversions[device_id] = device_conversions + return f"/api/esphome/ffmpeg_proxy/{device_id}/{convert_id}.{media_format}" @@ -167,6 +189,7 @@ class FFmpegConvertResponse(web.StreamResponse): *command_args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, + close_fds=False, # use posix_spawn in CPython < 3.13 ) # Only one conversion process per device is allowed @@ -198,6 +221,9 @@ class FFmpegConvertResponse(web.StreamResponse): raise finally: + # Allow conversion info to be removed + self.convert_info.is_finished = True + # Terminate hangs, so kill is used if proc.returncode is None: proc.kill() @@ -224,7 +250,8 @@ class FFmpegProxyView(HomeAssistantView): self, request: web.Request, device_id: str, filename: str ) -> web.StreamResponse: """Start a get request.""" - if (convert_info := self.proxy_data.conversions.get(device_id)) is None: + device_conversions = self.proxy_data.conversions[device_id] + if not device_conversions: return web.Response( body="No proxy URL for device", status=HTTPStatus.NOT_FOUND ) @@ -232,9 +259,16 @@ class FFmpegProxyView(HomeAssistantView): # {id}.mp3 -> id, mp3 convert_id, media_format = filename.rsplit(".") - if (convert_info.convert_id != convert_id) or ( - convert_info.media_format != media_format - ): + # Look up conversion info + convert_info: FFmpegConversionInfo | None = None + for maybe_convert_info in device_conversions: + if (maybe_convert_info.convert_id == convert_id) and ( + maybe_convert_info.media_format == media_format + ): + convert_info = maybe_convert_info + break + + if convert_info is None: return web.Response(body="Invalid proxy URL", status=HTTPStatus.BAD_REQUEST) # Stop previous process if the URL is being reused. diff --git a/tests/components/esphome/test_ffmpeg_proxy.py b/tests/components/esphome/test_ffmpeg_proxy.py index ef657ed8c7b..24650e611e0 100644 --- a/tests/components/esphome/test_ffmpeg_proxy.py +++ b/tests/components/esphome/test_ffmpeg_proxy.py @@ -2,6 +2,7 @@ from http import HTTPStatus import io +import os import tempfile from unittest.mock import patch from urllib.request import pathname2url @@ -232,3 +233,55 @@ async def test_request_same_url_multiple_times( num_frames += len(chunk) // (2 * 2) # 2 channels, 16-bit samples assert num_frames == 22050 * 10 # 10s + + +async def test_max_conversions_per_device( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, +) -> None: + """Test that each device has a maximum number of conversions (currently 2).""" + max_conversions = 2 + device_ids = ["1234", "5678"] + + await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}}) + client = await hass_client() + + with tempfile.TemporaryDirectory() as temp_dir: + wav_paths = [ + os.path.join(temp_dir, f"{i}.wav") for i in range(max_conversions + 1) + ] + for wav_path in wav_paths: + with wave.open(wav_path, "wb") as wav_file: + wav_file.setframerate(16000) + wav_file.setsampwidth(2) + wav_file.setnchannels(1) + wav_file.writeframes(bytes(16000 * 2 * 10)) # 10s + + wav_urls = [pathname2url(p) for p in wav_paths] + + # Each device will have max + 1 conversions + device_urls = { + device_id: [ + async_create_proxy_url( + hass, + device_id, + wav_url, + media_format="wav", + rate=22050, + channels=2, + width=2, + ) + for wav_url in wav_urls + ] + for device_id in device_ids + } + + for urls in device_urls.values(): + # First URL should fail because it was overwritten by the others + req = await client.get(urls[0]) + assert req.status == HTTPStatus.BAD_REQUEST + + # All other URLs should succeed + for url in urls[1:]: + req = await client.get(url) + assert req.status == HTTPStatus.OK