mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Enable TTS streaming implementations (#140176)
* Enable TTS streaming implementations
* Update comment
* Revert type change
This commit is contained in:
parent
d498dbd5ac
commit
1665d9474f
@ -62,7 +62,7 @@ from .const import (
|
||||
DOMAIN,
|
||||
TtsAudioType,
|
||||
)
|
||||
from .entity import TextToSpeechEntity
|
||||
from .entity import TextToSpeechEntity, TTSAudioRequest
|
||||
from .helper import get_engine_instance
|
||||
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
|
||||
from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
||||
@ -795,9 +795,15 @@ class SpeechManager:
|
||||
message, language, options
|
||||
)
|
||||
else:
|
||||
extension, data = await engine_instance.internal_async_get_tts_audio(
|
||||
message, language, options
|
||||
|
||||
async def message_gen() -> AsyncGenerator[str]:
|
||||
yield message
|
||||
|
||||
tts_result = await engine_instance.internal_async_stream_tts_audio(
|
||||
TTSAudioRequest(language, options, message_gen())
|
||||
)
|
||||
extension = tts_result.extension
|
||||
data = b"".join([chunk async for chunk in tts_result.data_gen])
|
||||
|
||||
if data is None or extension is None:
|
||||
raise HomeAssistantError(
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Entity for Text-to-Speech."""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, final
|
||||
|
||||
@ -16,6 +17,7 @@ from homeassistant.components.media_player import (
|
||||
)
|
||||
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
@ -31,6 +33,23 @@ CACHED_PROPERTIES_WITH_ATTR_ = {
|
||||
}
|
||||
|
||||
|
||||
@dataclass
class TTSAudioRequest:
    """Request to get TTS audio.

    Bundles the inputs of a streaming text-to-speech call into one object.
    Field order matters: callers construct this positionally as
    (language, options, message_gen).
    """

    # Language handed through to the engine's async_get_tts_audio.
    language: str
    # Engine options handed through to the engine's async_get_tts_audio.
    options: dict[str, Any]
    # Async generator yielding the message text in one or more chunks.
    message_gen: AsyncGenerator[str]
|
||||
|
||||
|
||||
@dataclass
class TTSAudioResponse:
    """Response containing TTS audio stream.

    Field order matters: the default streaming implementation constructs this
    positionally as (extension, data_gen).
    """

    # Extension reported by the engine for the audio
    # (presumably the file extension of the audio format - confirm).
    extension: str
    # Async generator yielding the audio bytes in one or more chunks.
    data_gen: AsyncGenerator[bytes]
|
||||
|
||||
|
||||
class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
|
||||
"""Represent a single TTS engine."""
|
||||
|
||||
@ -128,19 +147,37 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
|
||||
)
|
||||
|
||||
@final
|
||||
async def internal_async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
async def internal_async_stream_tts_audio(
|
||||
self, request: TTSAudioRequest
|
||||
) -> TTSAudioResponse:
|
||||
"""Process an audio stream to TTS service.
|
||||
|
||||
Only streaming content is allowed!
|
||||
"""
|
||||
self.__last_tts_loaded = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
return await self.async_get_tts_audio(
|
||||
message=message, language=language, options=options
|
||||
return await self.async_stream_tts_audio(request)
|
||||
|
||||
async def async_stream_tts_audio(
    self, request: TTSAudioRequest
) -> TTSAudioResponse:
    """Generate speech from an incoming message.

    The default implementation is backwards compatible with async_get_tts_audio.

    The streamed message chunks are collected into a single string, which is
    passed to async_get_tts_audio together with the request's language and
    options. The resulting audio is returned as a stream of exactly one chunk.

    Raises HomeAssistantError if async_get_tts_audio returns no extension or
    no audio data.
    """
    # async_get_tts_audio takes the complete text, so drain the stream first.
    message = "".join([chunk async for chunk in request.message_gen])
    extension, data = await self.async_get_tts_audio(
        message, request.language, request.options
    )

    if extension is None or data is None:
        raise HomeAssistantError(f"No TTS from {self.entity_id} for '{message}'")

    async def data_gen() -> AsyncGenerator[bytes]:
        # Wrap the already complete audio as a single-chunk stream.
        yield data

    return TTSAudioResponse(extension, data_gen())
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
|
Loading…
x
Reference in New Issue
Block a user