Stream the TTS result from webview (#139543)

This commit is contained in:
Paulus Schoutsen 2025-02-28 22:01:39 +00:00 committed by GitHub
parent 2d6068b842
commit b43a7ff1fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from datetime import datetime
import hashlib
@ -379,7 +380,7 @@ class ResultStream:
"""Class that will stream the result when available."""
# Streaming/conversion properties
url: str
token: str
extension: str
content_type: str
@ -391,6 +392,11 @@ class ResultStream:
_manager: SpeechManager
@cached_property
def url(self) -> str:
"""Get the URL to stream the result."""
return f"/api/tts_proxy/{self.token}"
@cached_property
def _result_cache_key(self) -> asyncio.Future[str]:
"""Get the future that returns the cache key."""
@ -401,11 +407,11 @@ class ResultStream:
"""Set cache key for message to be streamed."""
self._result_cache_key.set_result(cache_key)
async def async_get_result(self) -> bytes:
async def async_stream_result(self) -> AsyncGenerator[bytes]:
"""Get the stream of this result."""
cache_key = await self._result_cache_key
_extension, data = await self._manager.async_get_tts_audio(cache_key)
return data
yield data
def _hash_options(options: dict) -> str:
@ -603,7 +609,7 @@ class SpeechManager:
token = f"{secrets.token_urlsafe(16)}.{extension}"
content, _ = mimetypes.guess_type(token)
result_stream = ResultStream(
url=f"/api/tts_proxy/{token}",
token=token,
extension=extension,
content_type=content or "audio/mpeg",
use_file_cache=use_file_cache,
@ -1027,20 +1033,32 @@ class TextToSpeechView(HomeAssistantView):
"""Initialize a tts view."""
self.manager = manager
async def get(self, request: web.Request, token: str) -> web.Response:
async def get(self, request: web.Request, token: str) -> web.StreamResponse:
"""Start a get request."""
stream = self.manager.token_to_stream.get(token)
if stream is None:
return web.Response(status=HTTPStatus.NOT_FOUND)
response: web.StreamResponse | None = None
try:
data = await stream.async_get_result()
except HomeAssistantError as err:
_LOGGER.error("Error on get tts: %s", err)
async for data in stream.async_stream_result():
if response is None:
response = web.StreamResponse()
response.content_type = stream.content_type
await response.prepare(request)
await response.write(data)
# pylint: disable=broad-except
except Exception as err: # noqa: BLE001
_LOGGER.error("Error streaming tts: %s", err)
# Empty result or exception happened
if response is None:
return web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
return web.Response(body=data, content_type=stream.content_type)
await response.write_eof()
return response
@websocket_api.websocket_command(