Enable TTS streaming implementations (#140176)

* Enable TTS streaming implementations

* Update comment

* Revert type change
This commit is contained in:
Paulus Schoutsen 2025-03-10 15:12:37 -04:00 committed by GitHub
parent d498dbd5ac
commit 1665d9474f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 9 deletions

View File

@ -62,7 +62,7 @@ from .const import (
DOMAIN, DOMAIN,
TtsAudioType, TtsAudioType,
) )
from .entity import TextToSpeechEntity from .entity import TextToSpeechEntity, TTSAudioRequest
from .helper import get_engine_instance from .helper import get_engine_instance
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
from .media_source import generate_media_source_id, media_source_id_to_kwargs from .media_source import generate_media_source_id, media_source_id_to_kwargs
@ -795,9 +795,15 @@ class SpeechManager:
message, language, options message, language, options
) )
else: else:
extension, data = await engine_instance.internal_async_get_tts_audio(
message, language, options async def message_gen() -> AsyncGenerator[str]:
yield message
tts_result = await engine_instance.internal_async_stream_tts_audio(
TTSAudioRequest(language, options, message_gen())
) )
extension = tts_result.extension
data = b"".join([chunk async for chunk in tts_result.data_gen])
if data is None or extension is None: if data is None or extension is None:
raise HomeAssistantError( raise HomeAssistantError(

View File

@ -1,6 +1,7 @@
"""Entity for Text-to-Speech.""" """Entity for Text-to-Speech."""
from collections.abc import Mapping from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass
from functools import partial from functools import partial
from typing import Any, final from typing import Any, final
@ -16,6 +17,7 @@ from homeassistant.components.media_player import (
) )
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, STATE_UNKNOWN from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -31,6 +33,23 @@ CACHED_PROPERTIES_WITH_ATTR_ = {
} }
# Bundles the inputs of one streaming TTS call; the entity's
# async_stream_tts_audio reads all three fields.
@dataclass
class TTSAudioRequest:
"""Request to get TTS audio."""
# Language forwarded to async_get_tts_audio by the default implementation.
language: str
# Engine options, forwarded alongside the language.
options: dict[str, Any]
# Async generator yielding the message text in chunks; the default
# implementation joins every chunk into one string before synthesizing.
message_gen: AsyncGenerator[str]
# Result of one streaming TTS call: the audio format plus the audio itself.
@dataclass
class TTSAudioResponse:
"""Response containing TTS audio stream."""
# Extension value returned by the engine alongside the audio data.
extension: str
# Async generator yielding the audio bytes; the SpeechManager caller shown
# in this change drains it and concatenates the chunks with b"".join.
data_gen: AsyncGenerator[bytes]
class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_): class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
"""Represent a single TTS engine.""" """Represent a single TTS engine."""
@ -128,19 +147,37 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
) )
@final @final
async def internal_async_get_tts_audio( async def internal_async_stream_tts_audio(
self, message: str, language: str, options: dict[str, Any] self, request: TTSAudioRequest
) -> TtsAudioType: ) -> TTSAudioResponse:
"""Process an audio stream to TTS service. """Process an audio stream to TTS service.
Only streaming content is allowed! Only streaming content is allowed!
""" """
self.__last_tts_loaded = dt_util.utcnow().isoformat() self.__last_tts_loaded = dt_util.utcnow().isoformat()
self.async_write_ha_state() self.async_write_ha_state()
return await self.async_get_tts_audio( return await self.async_stream_tts_audio(request)
message=message, language=language, options=options
async def async_stream_tts_audio(
    self, request: TTSAudioRequest
) -> TTSAudioResponse:
    """Turn a streamed text request into a streamed audio response.

    Default implementation, kept backwards compatible with
    async_get_tts_audio: every text chunk of the request is collected into
    a single message, that message is synthesized in one call, and the
    resulting audio is handed back as a generator yielding one chunk.

    Raises HomeAssistantError when the engine returns no extension or no
    audio data.
    """
    # Drain the incoming text stream completely before synthesizing.
    text_parts: list[str] = []
    async for text_chunk in request.message_gen:
        text_parts.append(text_chunk)
    message = "".join(text_parts)

    result = await self.async_get_tts_audio(
        message, request.language, request.options
    )
    extension, data = result

    if extension is None or data is None:
        raise HomeAssistantError(f"No TTS from {self.entity_id} for '{message}'")

    async def _single_chunk() -> AsyncGenerator[bytes]:
        # The non-streaming engine produced everything at once: emit it whole.
        yield data

    return TTSAudioResponse(extension, _single_chunk())
def get_tts_audio( def get_tts_audio(
self, message: str, language: str, options: dict[str, Any] self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType: ) -> TtsAudioType: