mirror of
https://github.com/home-assistant/core.git
synced 2025-07-26 22:57:17 +00:00
Enable TTS streaming implementations (#140176)
* Enable TTS streaming implementations
* Update comment
* Revert type change
This commit is contained in:
parent
d498dbd5ac
commit
1665d9474f
@ -62,7 +62,7 @@ from .const import (
|
|||||||
DOMAIN,
|
DOMAIN,
|
||||||
TtsAudioType,
|
TtsAudioType,
|
||||||
)
|
)
|
||||||
from .entity import TextToSpeechEntity
|
from .entity import TextToSpeechEntity, TTSAudioRequest
|
||||||
from .helper import get_engine_instance
|
from .helper import get_engine_instance
|
||||||
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
|
from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy
|
||||||
from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
from .media_source import generate_media_source_id, media_source_id_to_kwargs
|
||||||
@ -795,9 +795,15 @@ class SpeechManager:
|
|||||||
message, language, options
|
message, language, options
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
extension, data = await engine_instance.internal_async_get_tts_audio(
|
|
||||||
message, language, options
|
async def message_gen() -> AsyncGenerator[str]:
|
||||||
|
yield message
|
||||||
|
|
||||||
|
tts_result = await engine_instance.internal_async_stream_tts_audio(
|
||||||
|
TTSAudioRequest(language, options, message_gen())
|
||||||
)
|
)
|
||||||
|
extension = tts_result.extension
|
||||||
|
data = b"".join([chunk async for chunk in tts_result.data_gen])
|
||||||
|
|
||||||
if data is None or extension is None:
|
if data is None or extension is None:
|
||||||
raise HomeAssistantError(
|
raise HomeAssistantError(
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Entity for Text-to-Speech."""
|
"""Entity for Text-to-Speech."""
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import AsyncGenerator, Mapping
|
||||||
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, final
|
from typing import Any, final
|
||||||
|
|
||||||
@ -16,6 +17,7 @@ from homeassistant.components.media_player import (
|
|||||||
)
|
)
|
||||||
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, STATE_UNKNOWN
|
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.restore_state import RestoreEntity
|
from homeassistant.helpers.restore_state import RestoreEntity
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
@ -31,6 +33,23 @@ CACHED_PROPERTIES_WITH_ATTR_ = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TTSAudioRequest:
|
||||||
|
"""Request to get TTS audio."""
|
||||||
|
|
||||||
|
language: str
|
||||||
|
options: dict[str, Any]
|
||||||
|
message_gen: AsyncGenerator[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TTSAudioResponse:
|
||||||
|
"""Response containing TTS audio stream."""
|
||||||
|
|
||||||
|
extension: str
|
||||||
|
data_gen: AsyncGenerator[bytes]
|
||||||
|
|
||||||
|
|
||||||
class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
|
class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
|
||||||
"""Represent a single TTS engine."""
|
"""Represent a single TTS engine."""
|
||||||
|
|
||||||
@ -128,19 +147,37 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
|
|||||||
)
|
)
|
||||||
|
|
||||||
@final
|
@final
|
||||||
async def internal_async_get_tts_audio(
|
async def internal_async_stream_tts_audio(
|
||||||
self, message: str, language: str, options: dict[str, Any]
|
self, request: TTSAudioRequest
|
||||||
) -> TtsAudioType:
|
) -> TTSAudioResponse:
|
||||||
"""Process an audio stream to TTS service.
|
"""Process an audio stream to TTS service.
|
||||||
|
|
||||||
Only streaming content is allowed!
|
Only streaming content is allowed!
|
||||||
"""
|
"""
|
||||||
self.__last_tts_loaded = dt_util.utcnow().isoformat()
|
self.__last_tts_loaded = dt_util.utcnow().isoformat()
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
return await self.async_get_tts_audio(
|
return await self.async_stream_tts_audio(request)
|
||||||
message=message, language=language, options=options
|
|
||||||
|
async def async_stream_tts_audio(
|
||||||
|
self, request: TTSAudioRequest
|
||||||
|
) -> TTSAudioResponse:
|
||||||
|
"""Generate speech from an incoming message.
|
||||||
|
|
||||||
|
The default implementation is backwards compatible with async_get_tts_audio.
|
||||||
|
"""
|
||||||
|
message = "".join([chunk async for chunk in request.message_gen])
|
||||||
|
extension, data = await self.async_get_tts_audio(
|
||||||
|
message, request.language, request.options
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if extension is None or data is None:
|
||||||
|
raise HomeAssistantError(f"No TTS from {self.entity_id} for '{message}'")
|
||||||
|
|
||||||
|
async def data_gen() -> AsyncGenerator[bytes]:
|
||||||
|
yield data
|
||||||
|
|
||||||
|
return TTSAudioResponse(extension, data_gen())
|
||||||
|
|
||||||
def get_tts_audio(
|
def get_tts_audio(
|
||||||
self, message: str, language: str, options: dict[str, Any]
|
self, message: str, language: str, options: dict[str, Any]
|
||||||
) -> TtsAudioType:
|
) -> TtsAudioType:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user