diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py
index 98ce76cafde..31a92c62258 100644
--- a/homeassistant/components/tts/__init__.py
+++ b/homeassistant/components/tts/__init__.py
@@ -62,7 +62,7 @@ from .const import (
     DOMAIN,
     TtsAudioType,
 )
-from .entity import TextToSpeechEntity
+from .entity import TextToSpeechEntity, TTSAudioRequest
 from .helper import get_engine_instance
 from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
 from .media_source import generate_media_source_id, media_source_id_to_kwargs
@@ -795,9 +795,15 @@ class SpeechManager:
                     message, language, options
                 )
             else:
-                extension, data = await engine_instance.internal_async_get_tts_audio(
-                    message, language, options
+
+                async def message_gen() -> AsyncGenerator[str]:
+                    yield message
+
+                tts_result = await engine_instance.internal_async_stream_tts_audio(
+                    TTSAudioRequest(language, options, message_gen())
                 )
+                extension = tts_result.extension
+                data = b"".join([chunk async for chunk in tts_result.data_gen])
 
             if data is None or extension is None:
                 raise HomeAssistantError(
diff --git a/homeassistant/components/tts/entity.py b/homeassistant/components/tts/entity.py
index ef65886452d..199d673398e 100644
--- a/homeassistant/components/tts/entity.py
+++ b/homeassistant/components/tts/entity.py
@@ -1,6 +1,7 @@
 """Entity for Text-to-Speech."""
 
-from collections.abc import Mapping
+from collections.abc import AsyncGenerator, Mapping
+from dataclasses import dataclass
 from functools import partial
 from typing import Any, final
 
@@ -16,6 +17,7 @@ from homeassistant.components.media_player import (
 )
 from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, STATE_UNKNOWN
 from homeassistant.core import callback
+from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers.restore_state import RestoreEntity
 from homeassistant.util import dt as dt_util
 
@@ -31,6 +33,23 @@ CACHED_PROPERTIES_WITH_ATTR_ = {
 }
 
 
+@dataclass
+class TTSAudioRequest:
+    """Request to get TTS audio."""
+
+    language: str
+    options: dict[str, Any]
+    message_gen: AsyncGenerator[str]
+
+
+@dataclass
+class TTSAudioResponse:
+    """Response containing TTS audio stream."""
+
+    extension: str
+    data_gen: AsyncGenerator[bytes]
+
+
 class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
     """Represent a single TTS engine."""
 
@@ -128,19 +147,37 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
         )
 
     @final
-    async def internal_async_get_tts_audio(
-        self, message: str, language: str, options: dict[str, Any]
-    ) -> TtsAudioType:
+    async def internal_async_stream_tts_audio(
+        self, request: TTSAudioRequest
+    ) -> TTSAudioResponse:
         """Process an audio stream to TTS service.
 
         Only streaming content is allowed!
         """
         self.__last_tts_loaded = dt_util.utcnow().isoformat()
         self.async_write_ha_state()
-        return await self.async_get_tts_audio(
-            message=message, language=language, options=options
+        return await self.async_stream_tts_audio(request)
+
+    async def async_stream_tts_audio(
+        self, request: TTSAudioRequest
+    ) -> TTSAudioResponse:
+        """Generate speech from an incoming message.
+
+        The default implementation is backwards compatible with async_get_tts_audio.
+        """
+        message = "".join([chunk async for chunk in request.message_gen])
+        extension, data = await self.async_get_tts_audio(
+            message, request.language, request.options
         )
 
+        if extension is None or data is None:
+            raise HomeAssistantError(f"No TTS from {self.entity_id} for '{message}'")
+
+        async def data_gen() -> AsyncGenerator[bytes]:
+            yield data
+
+        return TTSAudioResponse(extension, data_gen())
+
     def get_tts_audio(
         self, message: str, language: str, options: dict[str, Any]
     ) -> TtsAudioType: