Enable TTS streaming implementations (#140176)

* Enable TTS streaming implementations

* Update comment

* Revert type change
This commit is contained in:
Paulus Schoutsen 2025-03-10 15:12:37 -04:00 committed by GitHub
parent d498dbd5ac
commit 1665d9474f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 9 deletions

View File

@ -62,7 +62,7 @@ from .const import (
DOMAIN,
TtsAudioType,
)
from .entity import TextToSpeechEntity
from .entity import TextToSpeechEntity, TTSAudioRequest
from .helper import get_engine_instance
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
from .media_source import generate_media_source_id, media_source_id_to_kwargs
@ -795,9 +795,15 @@ class SpeechManager:
message, language, options
)
else:
extension, data = await engine_instance.internal_async_get_tts_audio(
message, language, options
async def message_gen() -> AsyncGenerator[str]:
yield message
tts_result = await engine_instance.internal_async_stream_tts_audio(
TTSAudioRequest(language, options, message_gen())
)
extension = tts_result.extension
data = b"".join([chunk async for chunk in tts_result.data_gen])
if data is None or extension is None:
raise HomeAssistantError(

View File

@ -1,6 +1,7 @@
"""Entity for Text-to-Speech."""
from collections.abc import Mapping
from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass
from functools import partial
from typing import Any, final
@ -16,6 +17,7 @@ from homeassistant.components.media_player import (
)
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util
@ -31,6 +33,23 @@ CACHED_PROPERTIES_WITH_ATTR_ = {
}
@dataclass
class TTSAudioRequest:
    """Input handed to a TTS engine when audio is requested.

    The text to speak is not passed as one string; it arrives through
    ``message_gen`` as a sequence of string chunks.
    """

    # Language the audio is requested in.
    language: str
    # Engine options forwarded together with the request.
    options: dict[str, Any]
    # Async generator yielding the message text chunk by chunk.
    message_gen: AsyncGenerator[str]
@dataclass
class TTSAudioResponse:
    """Result returned by a TTS engine: an extension plus streamed audio.

    The audio is not held as one bytes object; consumers read it from
    ``data_gen`` chunk by chunk.
    """

    # Extension reported for the audio (presumably the file/format
    # extension -- confirm against the engine implementations).
    extension: str
    # Async generator yielding the audio data as bytes chunks.
    data_gen: AsyncGenerator[bytes]
class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
"""Represent a single TTS engine."""
@ -128,19 +147,37 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
)
@final
async def internal_async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
async def internal_async_stream_tts_audio(
self, request: TTSAudioRequest
) -> TTSAudioResponse:
"""Process an audio stream to TTS service.
Only streaming content is allowed!
"""
self.__last_tts_loaded = dt_util.utcnow().isoformat()
self.async_write_ha_state()
return await self.async_get_tts_audio(
message=message, language=language, options=options
return await self.async_stream_tts_audio(request)
async def async_stream_tts_audio(
    self, request: TTSAudioRequest
) -> TTSAudioResponse:
    """Turn a streamed text request into a streamed audio response.

    This default keeps engines that only implement async_get_tts_audio
    working: the incoming text chunks are collected into one message,
    async_get_tts_audio is awaited with it, and the returned audio is
    exposed as a generator that yields a single chunk.

    Raises HomeAssistantError when async_get_tts_audio returns None for
    either the extension or the audio data.
    """
    # Drain the text stream completely; the non-streaming API needs the
    # whole message up front.
    text_parts = [part async for part in request.message_gen]
    message = "".join(text_parts)

    result = await self.async_get_tts_audio(
        message, request.language, request.options
    )
    extension, data = result

    if data is None or extension is None:
        raise HomeAssistantError(f"No TTS from {self.entity_id} for '{message}'")

    async def single_chunk() -> AsyncGenerator[bytes]:
        # The complete audio is delivered as one chunk.
        yield data

    return TTSAudioResponse(extension, single_chunk())
def get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType: