Stream the TTS result from webview (#139543)

This commit is contained in:
Paulus Schoutsen 2025-02-28 22:01:39 +00:00 committed by GitHub
parent 2d6068b842
commit b43a7ff1fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
import hashlib import hashlib
@@ -379,7 +380,7 @@ class ResultStream:
"""Class that will stream the result when available.""" """Class that will stream the result when available."""
# Streaming/conversion properties # Streaming/conversion properties
url: str token: str
extension: str extension: str
content_type: str content_type: str
@@ -391,6 +392,11 @@ class ResultStream:
_manager: SpeechManager _manager: SpeechManager
@cached_property
def url(self) -> str:
"""Get the URL to stream the result."""
return f"/api/tts_proxy/{self.token}"
@cached_property @cached_property
def _result_cache_key(self) -> asyncio.Future[str]: def _result_cache_key(self) -> asyncio.Future[str]:
"""Get the future that returns the cache key.""" """Get the future that returns the cache key."""
@@ -401,11 +407,11 @@ class ResultStream:
"""Set cache key for message to be streamed.""" """Set cache key for message to be streamed."""
self._result_cache_key.set_result(cache_key) self._result_cache_key.set_result(cache_key)
async def async_get_result(self) -> bytes: async def async_stream_result(self) -> AsyncGenerator[bytes]:
"""Get the stream of this result.""" """Get the stream of this result."""
cache_key = await self._result_cache_key cache_key = await self._result_cache_key
_extension, data = await self._manager.async_get_tts_audio(cache_key) _extension, data = await self._manager.async_get_tts_audio(cache_key)
return data yield data
def _hash_options(options: dict) -> str: def _hash_options(options: dict) -> str:
@@ -603,7 +609,7 @@ class SpeechManager:
token = f"{secrets.token_urlsafe(16)}.{extension}" token = f"{secrets.token_urlsafe(16)}.{extension}"
content, _ = mimetypes.guess_type(token) content, _ = mimetypes.guess_type(token)
result_stream = ResultStream( result_stream = ResultStream(
url=f"/api/tts_proxy/{token}", token=token,
extension=extension, extension=extension,
content_type=content or "audio/mpeg", content_type=content or "audio/mpeg",
use_file_cache=use_file_cache, use_file_cache=use_file_cache,
@@ -1027,20 +1033,32 @@ class TextToSpeechView(HomeAssistantView):
"""Initialize a tts view.""" """Initialize a tts view."""
self.manager = manager self.manager = manager
async def get(self, request: web.Request, token: str) -> web.Response: async def get(self, request: web.Request, token: str) -> web.StreamResponse:
"""Start a get request.""" """Start a get request."""
stream = self.manager.token_to_stream.get(token) stream = self.manager.token_to_stream.get(token)
if stream is None: if stream is None:
return web.Response(status=HTTPStatus.NOT_FOUND) return web.Response(status=HTTPStatus.NOT_FOUND)
response: web.StreamResponse | None = None
try: try:
data = await stream.async_get_result() async for data in stream.async_stream_result():
except HomeAssistantError as err: if response is None:
_LOGGER.error("Error on get tts: %s", err) response = web.StreamResponse()
response.content_type = stream.content_type
await response.prepare(request)
await response.write(data)
# pylint: disable=broad-except
except Exception as err: # noqa: BLE001
_LOGGER.error("Error streaming tts: %s", err)
# Empty result or exception happened
if response is None:
return web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR) return web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
return web.Response(body=data, content_type=stream.content_type) await response.write_eof()
return response
@websocket_api.websocket_command( @websocket_api.websocket_command(