From d1a483880295901c701e84f14c9f0886cbccf266 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Wed, 18 Sep 2024 23:05:09 -0500 Subject: [PATCH] Allow one reusable proxy URL per ESPHome device (#125845) * Allow one reusable URL per device * Move process to convert info * Stop previous process * Change to 404 * Better error handling --- .../components/esphome/ffmpeg_proxy.py | 102 ++++++++------ tests/components/esphome/test_ffmpeg_proxy.py | 129 +++++++++++++++++- 2 files changed, 183 insertions(+), 48 deletions(-) diff --git a/homeassistant/components/esphome/ffmpeg_proxy.py b/homeassistant/components/esphome/ffmpeg_proxy.py index 1649c628be9..c2bf72c40e5 100644 --- a/homeassistant/components/esphome/ffmpeg_proxy.py +++ b/homeassistant/components/esphome/ffmpeg_proxy.py @@ -1,7 +1,6 @@ """HTTP view that converts audio from a URL to a preferred format.""" import asyncio -from collections import defaultdict from dataclasses import dataclass, field from http import HTTPStatus import logging @@ -28,7 +27,7 @@ def async_create_proxy_url( channels: int | None = None, width: int | None = None, ) -> str: - """Create a one-time use proxy URL that automatically converts the media.""" + """Create a use proxy URL that automatically converts the media.""" data: FFmpegProxyData = hass.data[DATA_FFMPEG_PROXY] return data.async_create_proxy_url( device_id, media_url, media_format, rate, channels, width @@ -39,7 +38,10 @@ def async_create_proxy_url( class FFmpegConversionInfo: """Information for ffmpeg conversion.""" - url: str + convert_id: str + """Unique id for media conversion.""" + + media_url: str """Source URL of media to convert.""" media_format: str @@ -54,18 +56,16 @@ class FFmpegConversionInfo: width: int | None """Target sample width in bytes (None to keep source width).""" + proc: asyncio.subprocess.Process | None = None + """Subprocess doing ffmpeg conversion.""" + @dataclass class FFmpegProxyData: """Data for ffmpeg proxy conversion.""" - # device_id -> convert_id -> info - conversions: dict[str, dict[str, FFmpegConversionInfo]] = field( - default_factory=lambda: defaultdict(dict) - ) - - # device_id -> process - processes: dict[str, asyncio.subprocess.Process] = field(default_factory=dict) + # device_id -> info + conversions: dict[str, FFmpegConversionInfo] = field(default_factory=dict) def async_create_proxy_url( self, @@ -77,9 +77,19 @@ class FFmpegProxyData: width: int | None, ) -> str: """Create a one-time use proxy URL that automatically converts the media.""" + if (convert_info := self.conversions.pop(device_id, None)) is not None: + # Stop existing conversion before overwriting info + if (convert_info.proc is not None) and ( + convert_info.proc.returncode is None + ): + _LOGGER.debug( + "Stopping existing ffmpeg process for device: %s", device_id + ) + convert_info.proc.kill() + convert_id = secrets.token_urlsafe(16) - self.conversions[device_id][convert_id] = FFmpegConversionInfo( - media_url, media_format, rate, channels, width + self.conversions[device_id] = FFmpegConversionInfo( + convert_id, media_url, media_format, rate, channels, width ) _LOGGER.debug("Media URL allowed by proxy: %s", media_url) @@ -128,7 +138,7 @@ class FFmpegConvertResponse(web.StreamResponse): command_args = [ "-i", - self.convert_info.url, + self.convert_info.media_url, "-f", self.convert_info.media_format, ] @@ -156,12 +166,12 @@ class FFmpegConvertResponse(web.StreamResponse): stderr=asyncio.subprocess.PIPE, ) + # Only one conversion process per device is allowed + self.convert_info.proc = proc + assert proc.stdout is not None assert proc.stderr is not None - # Only one conversion process per device is allowed - self.proxy_data.processes[self.device_id] = proc - try: # Pull audio chunks from ffmpeg and pass them to the HTTP client while ( @@ -173,22 +183,26 @@ class FFmpegConvertResponse(web.StreamResponse): ): await writer.write(chunk) await writer.drain() + except asyncio.CancelledError: + raise # don't log error + except: + _LOGGER.exception("Unexpected error during ffmpeg conversion") + + # Process did not exit successfully + stderr_text = "" + while line := await proc.stderr.readline(): + stderr_text += line.decode() + _LOGGER.error("FFmpeg output: %s", stderr_text) + + raise finally: + # Terminate hangs, so kill is used + if proc.returncode is None: + proc.kill() + # Close connection await writer.write_eof() - # Terminate hangs, so kill is used - proc.kill() - - if proc.returncode != 0: - # Process did not exit successfully - stderr_text = "" - while line := await proc.stderr.readline(): - stderr_text += line.decode() - _LOGGER.error("Error shutting down ffmpeg: %s", stderr_text) - else: - _LOGGER.debug("Conversion completed: %s", self.convert_info) - return writer @@ -208,27 +222,25 @@ class FFmpegProxyView(HomeAssistantView): self, request: web.Request, device_id: str, filename: str ) -> web.StreamResponse: """Start a get request.""" - - # {id}.mp3 -> id - convert_id = filename.rsplit(".")[0] - - try: - convert_info = self.proxy_data.conversions[device_id].pop(convert_id) - except KeyError: - _LOGGER.error( - "Unrecognized convert id %s for device: %s", convert_id, device_id - ) + if (convert_info := self.proxy_data.conversions.get(device_id)) is None: return web.Response( - body="Convert id not recognized", status=HTTPStatus.BAD_REQUEST + body="No proxy URL for device", status=HTTPStatus.NOT_FOUND ) - # Stop any existing process - proc = self.proxy_data.processes.pop(device_id, None) - if (proc is not None) and (proc.returncode is None): - _LOGGER.debug("Stopping existing ffmpeg process for device: %s", device_id) + # {id}.mp3 -> id, mp3 + convert_id, media_format = filename.rsplit(".") - # Terminate hangs, so kill is used - proc.kill() + if (convert_info.convert_id != convert_id) or ( + convert_info.media_format != media_format + ): + return web.Response(body="Invalid proxy URL", status=HTTPStatus.BAD_REQUEST) + + # Stop previous process if the URL is being reused. + # We could continue from where the previous connection left off, but + # there would be no media header. + if (convert_info.proc is not None) and (convert_info.proc.returncode is None): + convert_info.proc.kill() + convert_info.proc = None # Stream converted audio back to client return FFmpegConvertResponse( diff --git a/tests/components/esphome/test_ffmpeg_proxy.py b/tests/components/esphome/test_ffmpeg_proxy.py index 577126201df..ef657ed8c7b 100644 --- a/tests/components/esphome/test_ffmpeg_proxy.py +++ b/tests/components/esphome/test_ffmpeg_proxy.py @@ -61,7 +61,7 @@ async def test_proxy_view( # Should fail because we haven't allowed the URL yet req = await client.get(url) - assert req.status == HTTPStatus.BAD_REQUEST + assert req.status == HTTPStatus.NOT_FOUND # Allow the URL with patch( @@ -75,6 +75,12 @@ async def test_proxy_view( == url ) + # Requesting the wrong media format should fail + wrong_url = f"/api/esphome/ffmpeg_proxy/{device_id}/{convert_id}.flac" + req = await client.get(wrong_url) + assert req.status == HTTPStatus.BAD_REQUEST + + # Correct URL req = await client.get(url) assert req.status == HTTPStatus.OK @@ -90,11 +96,11 @@ async def test_proxy_view( assert round(mp3_file.info.length, 0) == 1 -async def test_ffmpeg_error( +async def test_ffmpeg_file_doesnt_exist( hass: HomeAssistant, hass_client: ClientSessionGenerator, ) -> None: - """Test proxy HTTP view with an ffmpeg error.""" + """Test ffmpeg conversion with a file that doesn't exist.""" device_id = "1234" await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}}) @@ -109,3 +115,120 @@ async def test_ffmpeg_error( assert req.status == HTTPStatus.OK mp3_data = await req.content.read() assert not mp3_data + + +async def test_lingering_process( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, +) -> None: + """Test that a new request stops the old ffmpeg process.""" + device_id = "1234" + + await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}}) + client = await hass_client() + + with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as temp_file: + with wave.open(temp_file.name, "wb") as wav_file: + wav_file.setframerate(16000) + wav_file.setsampwidth(2) + wav_file.setnchannels(1) + wav_file.writeframes(bytes(16000 * 2)) # 1s + + temp_file.seek(0) + wav_url = pathname2url(temp_file.name) + url1 = async_create_proxy_url( + hass, + device_id, + wav_url, + media_format="wav", + rate=22050, + channels=2, + width=2, + ) + + # First request will start ffmpeg + req1 = await client.get(url1) + assert req1.status == HTTPStatus.OK + + # Only read part of the data + await req1.content.readexactly(100) + + # Allow another URL + url2 = async_create_proxy_url( + hass, + device_id, + wav_url, + media_format="wav", + rate=22050, + channels=2, + width=2, + ) + + req2 = await client.get(url2) + assert req2.status == HTTPStatus.OK + + wav_data = await req2.content.read() + + # All of the data should be there because this is a new ffmpeg process + with io.BytesIO(wav_data) as wav_io, wave.open(wav_io, "rb") as wav_file: + # We can't use getnframes() here because the WAV header will be incorrect. + # WAV encoders usually go back and update the WAV header after all of + # the frames are written, but ffmpeg can't do that because we're + # streaming the data. + # So instead, we just read and count frames until we run out. + num_frames = 0 + while chunk := wav_file.readframes(1024): + num_frames += len(chunk) // (2 * 2) # 2 channels, 16-bit samples + + assert num_frames == 22050 # 1s + + +async def test_request_same_url_multiple_times( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, +) -> None: + """Test that the ffmpeg process is restarted if the same URL is requested multiple times.""" + device_id = "1234" + + await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}}) + client = await hass_client() + + with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as temp_file: + with wave.open(temp_file.name, "wb") as wav_file: + wav_file.setframerate(16000) + wav_file.setsampwidth(2) + wav_file.setnchannels(1) + wav_file.writeframes(bytes(16000 * 2 * 10)) # 10s + + temp_file.seek(0) + wav_url = pathname2url(temp_file.name) + url = async_create_proxy_url( + hass, + device_id, + wav_url, + media_format="wav", + rate=22050, + channels=2, + width=2, + ) + + # First request will start ffmpeg + req1 = await client.get(url) + assert req1.status == HTTPStatus.OK + + # Only read part of the data + await req1.content.readexactly(100) + + # Second request should restart ffmpeg + req2 = await client.get(url) + assert req2.status == HTTPStatus.OK + + wav_data = await req2.content.read() + + # All of the data should be there because this is a new ffmpeg process + with io.BytesIO(wav_data) as wav_io, wave.open(wav_io, "rb") as wav_file: + num_frames = 0 + while chunk := wav_file.readframes(1024): + num_frames += len(chunk) // (2 * 2) # 2 channels, 16-bit samples + + assert num_frames == 22050 * 10 # 10s