Allow a fixed number of ffmpeg proxy conversions per device (#129246)

Allow a fixed number of conversions per device
This commit is contained in:
Michael Hansen 2024-10-28 15:26:43 -05:00 committed by GitHub
parent 73f2d972e4
commit dd9ce34d18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 97 additions and 10 deletions

View File

@ -1,10 +1,12 @@
"""HTTP view that converts audio from a URL to a preferred format.""" """HTTP view that converts audio from a URL to a preferred format."""
import asyncio import asyncio
from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus from http import HTTPStatus
import logging import logging
import secrets import secrets
from typing import Final
from aiohttp import web from aiohttp import web
from aiohttp.abc import AbstractStreamWriter, BaseRequest from aiohttp.abc import AbstractStreamWriter, BaseRequest
@ -17,6 +19,8 @@ from .const import DATA_FFMPEG_PROXY
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_MAX_CONVERSIONS_PER_DEVICE: Final[int] = 2
def async_create_proxy_url( def async_create_proxy_url(
hass: HomeAssistant, hass: HomeAssistant,
@ -59,13 +63,18 @@ class FFmpegConversionInfo:
proc: asyncio.subprocess.Process | None = None proc: asyncio.subprocess.Process | None = None
"""Subprocess doing ffmpeg conversion.""" """Subprocess doing ffmpeg conversion."""
is_finished: bool = False
"""True if conversion has finished."""
@dataclass @dataclass
class FFmpegProxyData: class FFmpegProxyData:
"""Data for ffmpeg proxy conversion.""" """Data for ffmpeg proxy conversion."""
# device_id -> info # device_id -> [info]
conversions: dict[str, FFmpegConversionInfo] = field(default_factory=dict) conversions: dict[str, list[FFmpegConversionInfo]] = field(
default_factory=lambda: defaultdict(list)
)
def async_create_proxy_url( def async_create_proxy_url(
self, self,
@ -77,8 +86,15 @@ class FFmpegProxyData:
width: int | None, width: int | None,
) -> str: ) -> str:
"""Create a one-time use proxy URL that automatically converts the media.""" """Create a one-time use proxy URL that automatically converts the media."""
if (convert_info := self.conversions.pop(device_id, None)) is not None:
# Stop existing conversion before overwriting info # Remove completed conversions
device_conversions = [
info for info in self.conversions[device_id] if not info.is_finished
]
while len(device_conversions) >= _MAX_CONVERSIONS_PER_DEVICE:
# Stop oldest conversion before adding a new one
convert_info = device_conversions[0]
if (convert_info.proc is not None) and ( if (convert_info.proc is not None) and (
convert_info.proc.returncode is None convert_info.proc.returncode is None
): ):
@ -87,12 +103,18 @@ class FFmpegProxyData:
) )
convert_info.proc.kill() convert_info.proc.kill()
device_conversions = device_conversions[1:]
convert_id = secrets.token_urlsafe(16) convert_id = secrets.token_urlsafe(16)
self.conversions[device_id] = FFmpegConversionInfo( device_conversions.append(
convert_id, media_url, media_format, rate, channels, width FFmpegConversionInfo(
convert_id, media_url, media_format, rate, channels, width
)
) )
_LOGGER.debug("Media URL allowed by proxy: %s", media_url) _LOGGER.debug("Media URL allowed by proxy: %s", media_url)
self.conversions[device_id] = device_conversions
return f"/api/esphome/ffmpeg_proxy/{device_id}/{convert_id}.{media_format}" return f"/api/esphome/ffmpeg_proxy/{device_id}/{convert_id}.{media_format}"
@ -167,6 +189,7 @@ class FFmpegConvertResponse(web.StreamResponse):
*command_args, *command_args,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
close_fds=False, # use posix_spawn in CPython < 3.13
) )
# Only one conversion process per device is allowed # Only one conversion process per device is allowed
@ -198,6 +221,9 @@ class FFmpegConvertResponse(web.StreamResponse):
raise raise
finally: finally:
# Allow conversion info to be removed
self.convert_info.is_finished = True
# Terminate hangs, so kill is used # Terminate hangs, so kill is used
if proc.returncode is None: if proc.returncode is None:
proc.kill() proc.kill()
@ -224,7 +250,8 @@ class FFmpegProxyView(HomeAssistantView):
self, request: web.Request, device_id: str, filename: str self, request: web.Request, device_id: str, filename: str
) -> web.StreamResponse: ) -> web.StreamResponse:
"""Start a get request.""" """Start a get request."""
if (convert_info := self.proxy_data.conversions.get(device_id)) is None: device_conversions = self.proxy_data.conversions[device_id]
if not device_conversions:
return web.Response( return web.Response(
body="No proxy URL for device", status=HTTPStatus.NOT_FOUND body="No proxy URL for device", status=HTTPStatus.NOT_FOUND
) )
@ -232,9 +259,16 @@ class FFmpegProxyView(HomeAssistantView):
# {id}.mp3 -> id, mp3 # {id}.mp3 -> id, mp3
convert_id, media_format = filename.rsplit(".") convert_id, media_format = filename.rsplit(".")
if (convert_info.convert_id != convert_id) or ( # Look up conversion info
convert_info.media_format != media_format convert_info: FFmpegConversionInfo | None = None
): for maybe_convert_info in device_conversions:
if (maybe_convert_info.convert_id == convert_id) and (
maybe_convert_info.media_format == media_format
):
convert_info = maybe_convert_info
break
if convert_info is None:
return web.Response(body="Invalid proxy URL", status=HTTPStatus.BAD_REQUEST) return web.Response(body="Invalid proxy URL", status=HTTPStatus.BAD_REQUEST)
# Stop previous process if the URL is being reused. # Stop previous process if the URL is being reused.

View File

@ -2,6 +2,7 @@
from http import HTTPStatus from http import HTTPStatus
import io import io
import os
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import patch
from urllib.request import pathname2url from urllib.request import pathname2url
@ -232,3 +233,55 @@ async def test_request_same_url_multiple_times(
num_frames += len(chunk) // (2 * 2) # 2 channels, 16-bit samples num_frames += len(chunk) // (2 * 2) # 2 channels, 16-bit samples
assert num_frames == 22050 * 10 # 10s assert num_frames == 22050 * 10 # 10s
async def test_max_conversions_per_device(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
) -> None:
"""Test that each device has a maximum number of conversions (currently 2)."""
max_conversions = 2
device_ids = ["1234", "5678"]
await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
client = await hass_client()
with tempfile.TemporaryDirectory() as temp_dir:
wav_paths = [
os.path.join(temp_dir, f"{i}.wav") for i in range(max_conversions + 1)
]
for wav_path in wav_paths:
with wave.open(wav_path, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(16000 * 2 * 10)) # 10s
wav_urls = [pathname2url(p) for p in wav_paths]
# Each device will have max + 1 conversions
device_urls = {
device_id: [
async_create_proxy_url(
hass,
device_id,
wav_url,
media_format="wav",
rate=22050,
channels=2,
width=2,
)
for wav_url in wav_urls
]
for device_id in device_ids
}
for urls in device_urls.values():
# First URL should fail because it was overwritten by the others
req = await client.get(urls[0])
assert req.status == HTTPStatus.BAD_REQUEST
# All other URLs should succeed
for url in urls[1:]:
req = await client.get(url)
assert req.status == HTTPStatus.OK