Allow one reusable proxy URL per ESPHome device (#125845)

* Allow one reusable URL per device

* Move process to convert info

* Stop previous process

* Change to 404

* Better error handling
This commit is contained in:
Michael Hansen 2024-09-18 23:05:09 -05:00 committed by GitHub
parent f8274cd5c2
commit d1a4838802
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 183 additions and 48 deletions

View File

@ -1,7 +1,6 @@
"""HTTP view that converts audio from a URL to a preferred format.""" """HTTP view that converts audio from a URL to a preferred format."""
import asyncio import asyncio
from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus from http import HTTPStatus
import logging import logging
@ -28,7 +27,7 @@ def async_create_proxy_url(
channels: int | None = None, channels: int | None = None,
width: int | None = None, width: int | None = None,
) -> str: ) -> str:
"""Create a one-time use proxy URL that automatically converts the media.""" """Create a use proxy URL that automatically converts the media."""
data: FFmpegProxyData = hass.data[DATA_FFMPEG_PROXY] data: FFmpegProxyData = hass.data[DATA_FFMPEG_PROXY]
return data.async_create_proxy_url( return data.async_create_proxy_url(
device_id, media_url, media_format, rate, channels, width device_id, media_url, media_format, rate, channels, width
@ -39,7 +38,10 @@ def async_create_proxy_url(
class FFmpegConversionInfo: class FFmpegConversionInfo:
"""Information for ffmpeg conversion.""" """Information for ffmpeg conversion."""
url: str convert_id: str
"""Unique id for media conversion."""
media_url: str
"""Source URL of media to convert.""" """Source URL of media to convert."""
media_format: str media_format: str
@ -54,18 +56,16 @@ class FFmpegConversionInfo:
width: int | None width: int | None
"""Target sample width in bytes (None to keep source width).""" """Target sample width in bytes (None to keep source width)."""
proc: asyncio.subprocess.Process | None = None
"""Subprocess doing ffmpeg conversion."""
@dataclass @dataclass
class FFmpegProxyData: class FFmpegProxyData:
"""Data for ffmpeg proxy conversion.""" """Data for ffmpeg proxy conversion."""
# device_id -> convert_id -> info # device_id -> info
conversions: dict[str, dict[str, FFmpegConversionInfo]] = field( conversions: dict[str, FFmpegConversionInfo] = field(default_factory=dict)
default_factory=lambda: defaultdict(dict)
)
# device_id -> process
processes: dict[str, asyncio.subprocess.Process] = field(default_factory=dict)
def async_create_proxy_url( def async_create_proxy_url(
self, self,
@ -77,9 +77,19 @@ class FFmpegProxyData:
width: int | None, width: int | None,
) -> str: ) -> str:
"""Create a one-time use proxy URL that automatically converts the media.""" """Create a one-time use proxy URL that automatically converts the media."""
if (convert_info := self.conversions.pop(device_id, None)) is not None:
# Stop existing conversion before overwriting info
if (convert_info.proc is not None) and (
convert_info.proc.returncode is None
):
_LOGGER.debug(
"Stopping existing ffmpeg process for device: %s", device_id
)
convert_info.proc.kill()
convert_id = secrets.token_urlsafe(16) convert_id = secrets.token_urlsafe(16)
self.conversions[device_id][convert_id] = FFmpegConversionInfo( self.conversions[device_id] = FFmpegConversionInfo(
media_url, media_format, rate, channels, width convert_id, media_url, media_format, rate, channels, width
) )
_LOGGER.debug("Media URL allowed by proxy: %s", media_url) _LOGGER.debug("Media URL allowed by proxy: %s", media_url)
@ -128,7 +138,7 @@ class FFmpegConvertResponse(web.StreamResponse):
command_args = [ command_args = [
"-i", "-i",
self.convert_info.url, self.convert_info.media_url,
"-f", "-f",
self.convert_info.media_format, self.convert_info.media_format,
] ]
@ -156,12 +166,12 @@ class FFmpegConvertResponse(web.StreamResponse):
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
) )
# Only one conversion process per device is allowed
self.convert_info.proc = proc
assert proc.stdout is not None assert proc.stdout is not None
assert proc.stderr is not None assert proc.stderr is not None
# Only one conversion process per device is allowed
self.proxy_data.processes[self.device_id] = proc
try: try:
# Pull audio chunks from ffmpeg and pass them to the HTTP client # Pull audio chunks from ffmpeg and pass them to the HTTP client
while ( while (
@ -173,22 +183,26 @@ class FFmpegConvertResponse(web.StreamResponse):
): ):
await writer.write(chunk) await writer.write(chunk)
await writer.drain() await writer.drain()
except asyncio.CancelledError:
raise # don't log error
except:
_LOGGER.exception("Unexpected error during ffmpeg conversion")
# Process did not exit successfully
stderr_text = ""
while line := await proc.stderr.readline():
stderr_text += line.decode()
_LOGGER.error("FFmpeg output: %s", stderr_text)
raise
finally: finally:
# Terminate hangs, so kill is used
if proc.returncode is None:
proc.kill()
# Close connection # Close connection
await writer.write_eof() await writer.write_eof()
# Terminate hangs, so kill is used
proc.kill()
if proc.returncode != 0:
# Process did not exit successfully
stderr_text = ""
while line := await proc.stderr.readline():
stderr_text += line.decode()
_LOGGER.error("Error shutting down ffmpeg: %s", stderr_text)
else:
_LOGGER.debug("Conversion completed: %s", self.convert_info)
return writer return writer
@ -208,27 +222,25 @@ class FFmpegProxyView(HomeAssistantView):
self, request: web.Request, device_id: str, filename: str self, request: web.Request, device_id: str, filename: str
) -> web.StreamResponse: ) -> web.StreamResponse:
"""Start a get request.""" """Start a get request."""
if (convert_info := self.proxy_data.conversions.get(device_id)) is None:
# {id}.mp3 -> id
convert_id = filename.rsplit(".")[0]
try:
convert_info = self.proxy_data.conversions[device_id].pop(convert_id)
except KeyError:
_LOGGER.error(
"Unrecognized convert id %s for device: %s", convert_id, device_id
)
return web.Response( return web.Response(
body="Convert id not recognized", status=HTTPStatus.BAD_REQUEST body="No proxy URL for device", status=HTTPStatus.NOT_FOUND
) )
# Stop any existing process # {id}.mp3 -> id, mp3
proc = self.proxy_data.processes.pop(device_id, None) convert_id, media_format = filename.rsplit(".")
if (proc is not None) and (proc.returncode is None):
_LOGGER.debug("Stopping existing ffmpeg process for device: %s", device_id)
# Terminate hangs, so kill is used if (convert_info.convert_id != convert_id) or (
proc.kill() convert_info.media_format != media_format
):
return web.Response(body="Invalid proxy URL", status=HTTPStatus.BAD_REQUEST)
# Stop previous process if the URL is being reused.
# We could continue from where the previous connection left off, but
# there would be no media header.
if (convert_info.proc is not None) and (convert_info.proc.returncode is None):
convert_info.proc.kill()
convert_info.proc = None
# Stream converted audio back to client # Stream converted audio back to client
return FFmpegConvertResponse( return FFmpegConvertResponse(

View File

@ -61,7 +61,7 @@ async def test_proxy_view(
# Should fail because we haven't allowed the URL yet # Should fail because we haven't allowed the URL yet
req = await client.get(url) req = await client.get(url)
assert req.status == HTTPStatus.BAD_REQUEST assert req.status == HTTPStatus.NOT_FOUND
# Allow the URL # Allow the URL
with patch( with patch(
@ -75,6 +75,12 @@ async def test_proxy_view(
== url == url
) )
# Requesting the wrong media format should fail
wrong_url = f"/api/esphome/ffmpeg_proxy/{device_id}/{convert_id}.flac"
req = await client.get(wrong_url)
assert req.status == HTTPStatus.BAD_REQUEST
# Correct URL
req = await client.get(url) req = await client.get(url)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
@ -90,11 +96,11 @@ async def test_proxy_view(
assert round(mp3_file.info.length, 0) == 1 assert round(mp3_file.info.length, 0) == 1
async def test_ffmpeg_error( async def test_ffmpeg_file_doesnt_exist(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Test proxy HTTP view with an ffmpeg error.""" """Test ffmpeg conversion with a file that doesn't exist."""
device_id = "1234" device_id = "1234"
await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}}) await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
@ -109,3 +115,120 @@ async def test_ffmpeg_error(
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
mp3_data = await req.content.read() mp3_data = await req.content.read()
assert not mp3_data assert not mp3_data
async def test_lingering_process(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
) -> None:
    """Check that a second proxy URL for the same device yields a full conversion.

    A 1 second WAV file is written to disk and a proxy URL is created for it.
    The first response is only partially read, then a second proxy URL is
    created for the same device and its response is read to the end. The
    second response must contain the entire (resampled) second of audio.
    """
    device_id = "1234"
    source_rate = 16000
    source_width = 2  # bytes per sample
    target_rate = 22050
    target_channels = 2
    target_width = 2  # bytes per sample
    target_frame_bytes = target_channels * target_width

    await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
    client = await hass_client()

    with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as temp_file:
        # Source media: 1 second of mono, 16-bit silence at 16 kHz
        with wave.open(temp_file.name, "wb") as wav_writer:
            wav_writer.setnchannels(1)
            wav_writer.setsampwidth(source_width)
            wav_writer.setframerate(source_rate)
            wav_writer.writeframes(bytes(source_rate * source_width))  # 1s

        temp_file.seek(0)
        wav_url = pathname2url(temp_file.name)

        first_url = async_create_proxy_url(
            hass,
            device_id,
            wav_url,
            media_format="wav",
            rate=target_rate,
            channels=target_channels,
            width=target_width,
        )

        # The first request starts an ffmpeg process
        first_response = await client.get(first_url)
        assert first_response.status == HTTPStatus.OK

        # Consume only the beginning of the stream
        await first_response.content.readexactly(100)

        # Allow a second URL for the same device
        second_url = async_create_proxy_url(
            hass,
            device_id,
            wav_url,
            media_format="wav",
            rate=target_rate,
            channels=target_channels,
            width=target_width,
        )

        second_response = await client.get(second_url)
        assert second_response.status == HTTPStatus.OK

        wav_data = await second_response.content.read()

        # All of the data should be there because this is a new ffmpeg process
        with io.BytesIO(wav_data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
            # We can't use getnframes() here because the WAV header will be
            # incorrect. WAV encoders usually go back and update the WAV header
            # after all of the frames are written, but ffmpeg can't do that
            # because we're streaming the data.
            # So instead, we just read and count frames until we run out.
            num_frames = sum(
                len(chunk) // target_frame_bytes
                for chunk in iter(lambda: wav_reader.readframes(1024), b"")
            )

            assert num_frames == target_rate  # 1s
async def test_request_same_url_multiple_times(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
) -> None:
    """Check that requesting the same proxy URL twice yields a full conversion.

    A 10 second WAV file is written to disk and a single proxy URL is created
    for it. The first response is only partially read, then the same URL is
    requested again and read to the end. The second response must contain all
    10 seconds of (resampled) audio.
    """
    device_id = "1234"
    seconds = 10
    source_rate = 16000
    source_width = 2  # bytes per sample
    target_rate = 22050
    target_channels = 2
    target_width = 2  # bytes per sample
    target_frame_bytes = target_channels * target_width

    await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
    client = await hass_client()

    with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as temp_file:
        # Source media: 10 seconds of mono, 16-bit silence at 16 kHz
        with wave.open(temp_file.name, "wb") as wav_writer:
            wav_writer.setnchannels(1)
            wav_writer.setsampwidth(source_width)
            wav_writer.setframerate(source_rate)
            wav_writer.writeframes(bytes(source_rate * source_width * seconds))  # 10s

        temp_file.seek(0)
        wav_url = pathname2url(temp_file.name)

        url = async_create_proxy_url(
            hass,
            device_id,
            wav_url,
            media_format="wav",
            rate=target_rate,
            channels=target_channels,
            width=target_width,
        )

        # The first request starts an ffmpeg process
        first_response = await client.get(url)
        assert first_response.status == HTTPStatus.OK

        # Consume only the beginning of the stream
        await first_response.content.readexactly(100)

        # Requesting the same URL again should restart ffmpeg
        second_response = await client.get(url)
        assert second_response.status == HTTPStatus.OK

        wav_data = await second_response.content.read()

        # All of the data should be there because this is a new ffmpeg process
        with io.BytesIO(wav_data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
            # Count frames by reading until the stream is exhausted
            num_frames = sum(
                len(chunk) // target_frame_bytes
                for chunk in iter(lambda: wav_reader.readframes(1024), b"")
            )

            assert num_frames == target_rate * seconds  # 10s