mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Stream the TTS result from webview (#139543)
This commit is contained in:
parent
2d6068b842
commit
b43a7ff1fe
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
@ -379,7 +380,7 @@ class ResultStream:
|
||||
"""Class that will stream the result when available."""
|
||||
|
||||
# Streaming/conversion properties
|
||||
url: str
|
||||
token: str
|
||||
extension: str
|
||||
content_type: str
|
||||
|
||||
@ -391,6 +392,11 @@ class ResultStream:
|
||||
|
||||
_manager: SpeechManager
|
||||
|
||||
@cached_property
|
||||
def url(self) -> str:
|
||||
"""Get the URL to stream the result."""
|
||||
return f"/api/tts_proxy/{self.token}"
|
||||
|
||||
@cached_property
|
||||
def _result_cache_key(self) -> asyncio.Future[str]:
|
||||
"""Get the future that returns the cache key."""
|
||||
@ -401,11 +407,11 @@ class ResultStream:
|
||||
"""Set cache key for message to be streamed."""
|
||||
self._result_cache_key.set_result(cache_key)
|
||||
|
||||
async def async_get_result(self) -> bytes:
|
||||
async def async_stream_result(self) -> AsyncGenerator[bytes]:
|
||||
"""Get the stream of this result."""
|
||||
cache_key = await self._result_cache_key
|
||||
_extension, data = await self._manager.async_get_tts_audio(cache_key)
|
||||
return data
|
||||
yield data
|
||||
|
||||
|
||||
def _hash_options(options: dict) -> str:
|
||||
@ -603,7 +609,7 @@ class SpeechManager:
|
||||
token = f"{secrets.token_urlsafe(16)}.{extension}"
|
||||
content, _ = mimetypes.guess_type(token)
|
||||
result_stream = ResultStream(
|
||||
url=f"/api/tts_proxy/{token}",
|
||||
token=token,
|
||||
extension=extension,
|
||||
content_type=content or "audio/mpeg",
|
||||
use_file_cache=use_file_cache,
|
||||
@ -1027,20 +1033,32 @@ class TextToSpeechView(HomeAssistantView):
|
||||
"""Initialize a tts view."""
|
||||
self.manager = manager
|
||||
|
||||
async def get(self, request: web.Request, token: str) -> web.Response:
|
||||
async def get(self, request: web.Request, token: str) -> web.StreamResponse:
|
||||
"""Start a get request."""
|
||||
stream = self.manager.token_to_stream.get(token)
|
||||
|
||||
if stream is None:
|
||||
return web.Response(status=HTTPStatus.NOT_FOUND)
|
||||
|
||||
response: web.StreamResponse | None = None
|
||||
try:
|
||||
data = await stream.async_get_result()
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Error on get tts: %s", err)
|
||||
async for data in stream.async_stream_result():
|
||||
if response is None:
|
||||
response = web.StreamResponse()
|
||||
response.content_type = stream.content_type
|
||||
await response.prepare(request)
|
||||
|
||||
await response.write(data)
|
||||
# pylint: disable=broad-except
|
||||
except Exception as err: # noqa: BLE001
|
||||
_LOGGER.error("Error streaming tts: %s", err)
|
||||
|
||||
# Empty result or exception happened
|
||||
if response is None:
|
||||
return web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
return web.Response(body=data, content_type=stream.content_type)
|
||||
await response.write_eof()
|
||||
return response
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
|
Loading…
x
Reference in New Issue
Block a user