mirror of
https://github.com/home-assistant/core.git
synced 2025-06-10 08:07:06 +00:00
Allow streaming text into TTS ResultStream objects (#143745)
Allow streaming messages into TTS ResultStream
This commit is contained in:
parent
ae118da5a1
commit
5dab9ba01b
@ -42,7 +42,7 @@ from homeassistant.helpers.entity_component import EntityComponent
|
|||||||
from homeassistant.helpers.event import async_call_later
|
from homeassistant.helpers.event import async_call_later
|
||||||
from homeassistant.helpers.network import get_url
|
from homeassistant.helpers.network import get_url
|
||||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType
|
from homeassistant.helpers.typing import UNDEFINED, ConfigType
|
||||||
from homeassistant.util import language as language_util
|
from homeassistant.util import language as language_util, ulid as ulid_util
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
ATTR_CACHE,
|
ATTR_CACHE,
|
||||||
@ -495,6 +495,18 @@ class ResultStream:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_set_message_stream(self, message_stream: AsyncGenerator[str]) -> None:
|
||||||
|
"""Set a stream that will generate the message."""
|
||||||
|
self._result_cache.set_result(
|
||||||
|
self._manager.async_cache_message_stream_in_memory(
|
||||||
|
engine=self.engine,
|
||||||
|
message_stream=message_stream,
|
||||||
|
language=self.language,
|
||||||
|
options=self.options,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def async_stream_result(self) -> AsyncGenerator[bytes]:
|
async def async_stream_result(self) -> AsyncGenerator[bytes]:
|
||||||
"""Get the stream of this result."""
|
"""Get the stream of this result."""
|
||||||
cache = await self._result_cache
|
cache = await self._result_cache
|
||||||
@ -735,6 +747,42 @@ class SpeechManager:
|
|||||||
self.token_to_stream_cleanup.schedule()
|
self.token_to_stream_cleanup.schedule()
|
||||||
return result_stream
|
return result_stream
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_cache_message_stream_in_memory(
|
||||||
|
self,
|
||||||
|
engine: str,
|
||||||
|
message_stream: AsyncGenerator[str],
|
||||||
|
language: str,
|
||||||
|
options: dict,
|
||||||
|
) -> TTSCache:
|
||||||
|
"""Make sure a message stream will be cached in memory and returns cache object.
|
||||||
|
|
||||||
|
Requires options, language to be processed.
|
||||||
|
"""
|
||||||
|
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||||
|
raise HomeAssistantError(f"Provider {engine} not found")
|
||||||
|
|
||||||
|
cache_key = ulid_util.ulid_now()
|
||||||
|
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
||||||
|
data_gen = self._async_generate_tts_audio(
|
||||||
|
engine_instance, message_stream, language, options
|
||||||
|
)
|
||||||
|
|
||||||
|
cache = TTSCache(
|
||||||
|
cache_key=cache_key,
|
||||||
|
extension=extension,
|
||||||
|
data_gen=data_gen,
|
||||||
|
)
|
||||||
|
self.mem_cache[cache_key] = cache
|
||||||
|
self.hass.async_create_background_task(
|
||||||
|
self._load_data_into_cache(
|
||||||
|
cache, engine_instance, "[Streaming TTS]", False, language, options
|
||||||
|
),
|
||||||
|
f"tts_load_data_into_cache_{engine_instance.name}",
|
||||||
|
)
|
||||||
|
self.memcache_cleanup.schedule()
|
||||||
|
return cache
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_cache_message_in_memory(
|
def async_cache_message_in_memory(
|
||||||
self,
|
self,
|
||||||
|
@ -1842,6 +1842,7 @@ async def test_default_engine_prefer_cloud_entity(
|
|||||||
async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> None:
|
async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> None:
|
||||||
"""Test creating streams."""
|
"""Test creating streams."""
|
||||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||||
|
|
||||||
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
|
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
|
||||||
assert stream.language == mock_tts_entity.default_language
|
assert stream.language == mock_tts_entity.default_language
|
||||||
assert stream.options == (mock_tts_entity.default_options or {})
|
assert stream.options == (mock_tts_entity.default_options or {})
|
||||||
@ -1850,6 +1851,33 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
|
|||||||
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||||
assert result_data == MOCK_DATA
|
assert result_data == MOCK_DATA
|
||||||
|
|
||||||
|
async def async_stream_tts_audio(
|
||||||
|
request: tts.TTSAudioRequest,
|
||||||
|
) -> tts.TTSAudioResponse:
|
||||||
|
"""Mock stream TTS audio."""
|
||||||
|
|
||||||
|
async def gen_data():
|
||||||
|
async for msg in request.message_gen:
|
||||||
|
yield msg.encode()
|
||||||
|
|
||||||
|
return tts.TTSAudioResponse(
|
||||||
|
extension="mp3",
|
||||||
|
data_gen=gen_data(),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
||||||
|
|
||||||
|
async def stream_message():
|
||||||
|
"""Mock stream message."""
|
||||||
|
yield "he"
|
||||||
|
yield "ll"
|
||||||
|
yield "o"
|
||||||
|
|
||||||
|
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
|
||||||
|
stream.async_set_message_stream(stream_message())
|
||||||
|
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||||
|
assert result_data == b"hello"
|
||||||
|
|
||||||
data = b"beer"
|
data = b"beer"
|
||||||
stream2 = MockResultStream(hass, "wav", data)
|
stream2 = MockResultStream(hass, "wav", data)
|
||||||
assert tts.async_get_stream(hass, stream2.token) is stream2
|
assert tts.async_get_stream(hass, stream2.token) is stream2
|
||||||
|
Loading…
x
Reference in New Issue
Block a user