mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 04:07:08 +00:00
Stream the TTS result from webview (#139543)
This commit is contained in:
parent
2d6068b842
commit
b43a7ff1fe
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
@ -379,7 +380,7 @@ class ResultStream:
|
|||||||
"""Class that will stream the result when available."""
|
"""Class that will stream the result when available."""
|
||||||
|
|
||||||
# Streaming/conversion properties
|
# Streaming/conversion properties
|
||||||
url: str
|
token: str
|
||||||
extension: str
|
extension: str
|
||||||
content_type: str
|
content_type: str
|
||||||
|
|
||||||
@ -391,6 +392,11 @@ class ResultStream:
|
|||||||
|
|
||||||
_manager: SpeechManager
|
_manager: SpeechManager
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def url(self) -> str:
|
||||||
|
"""Get the URL to stream the result."""
|
||||||
|
return f"/api/tts_proxy/{self.token}"
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def _result_cache_key(self) -> asyncio.Future[str]:
|
def _result_cache_key(self) -> asyncio.Future[str]:
|
||||||
"""Get the future that returns the cache key."""
|
"""Get the future that returns the cache key."""
|
||||||
@ -401,11 +407,11 @@ class ResultStream:
|
|||||||
"""Set cache key for message to be streamed."""
|
"""Set cache key for message to be streamed."""
|
||||||
self._result_cache_key.set_result(cache_key)
|
self._result_cache_key.set_result(cache_key)
|
||||||
|
|
||||||
async def async_get_result(self) -> bytes:
|
async def async_stream_result(self) -> AsyncGenerator[bytes]:
|
||||||
"""Get the stream of this result."""
|
"""Get the stream of this result."""
|
||||||
cache_key = await self._result_cache_key
|
cache_key = await self._result_cache_key
|
||||||
_extension, data = await self._manager.async_get_tts_audio(cache_key)
|
_extension, data = await self._manager.async_get_tts_audio(cache_key)
|
||||||
return data
|
yield data
|
||||||
|
|
||||||
|
|
||||||
def _hash_options(options: dict) -> str:
|
def _hash_options(options: dict) -> str:
|
||||||
@ -603,7 +609,7 @@ class SpeechManager:
|
|||||||
token = f"{secrets.token_urlsafe(16)}.{extension}"
|
token = f"{secrets.token_urlsafe(16)}.{extension}"
|
||||||
content, _ = mimetypes.guess_type(token)
|
content, _ = mimetypes.guess_type(token)
|
||||||
result_stream = ResultStream(
|
result_stream = ResultStream(
|
||||||
url=f"/api/tts_proxy/{token}",
|
token=token,
|
||||||
extension=extension,
|
extension=extension,
|
||||||
content_type=content or "audio/mpeg",
|
content_type=content or "audio/mpeg",
|
||||||
use_file_cache=use_file_cache,
|
use_file_cache=use_file_cache,
|
||||||
@ -1027,20 +1033,32 @@ class TextToSpeechView(HomeAssistantView):
|
|||||||
"""Initialize a tts view."""
|
"""Initialize a tts view."""
|
||||||
self.manager = manager
|
self.manager = manager
|
||||||
|
|
||||||
async def get(self, request: web.Request, token: str) -> web.Response:
|
async def get(self, request: web.Request, token: str) -> web.StreamResponse:
|
||||||
"""Start a get request."""
|
"""Start a get request."""
|
||||||
stream = self.manager.token_to_stream.get(token)
|
stream = self.manager.token_to_stream.get(token)
|
||||||
|
|
||||||
if stream is None:
|
if stream is None:
|
||||||
return web.Response(status=HTTPStatus.NOT_FOUND)
|
return web.Response(status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
|
response: web.StreamResponse | None = None
|
||||||
try:
|
try:
|
||||||
data = await stream.async_get_result()
|
async for data in stream.async_stream_result():
|
||||||
except HomeAssistantError as err:
|
if response is None:
|
||||||
_LOGGER.error("Error on get tts: %s", err)
|
response = web.StreamResponse()
|
||||||
|
response.content_type = stream.content_type
|
||||||
|
await response.prepare(request)
|
||||||
|
|
||||||
|
await response.write(data)
|
||||||
|
# pylint: disable=broad-except
|
||||||
|
except Exception as err: # noqa: BLE001
|
||||||
|
_LOGGER.error("Error streaming tts: %s", err)
|
||||||
|
|
||||||
|
# Empty result or exception happened
|
||||||
|
if response is None:
|
||||||
return web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
return web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
return web.Response(body=data, content_type=stream.content_type)
|
await response.write_eof()
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user