mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
TTS to only use stream entity method when streaming request comes in (#145167)
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
parent
1f6faaacab
commit
e78f4d2a29
@ -852,12 +852,9 @@ class SpeechManager:
|
||||
else:
|
||||
_LOGGER.debug("Generating audio for %s", message[0:32])
|
||||
|
||||
async def message_stream() -> AsyncGenerator[str]:
|
||||
yield message
|
||||
|
||||
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
||||
data_gen = self._async_generate_tts_audio(
|
||||
engine_instance, message_stream(), language, options
|
||||
engine_instance, message, language, options
|
||||
)
|
||||
|
||||
cache = TTSCache(
|
||||
@ -931,7 +928,7 @@ class SpeechManager:
|
||||
async def _async_generate_tts_audio(
|
||||
self,
|
||||
engine_instance: TextToSpeechEntity | Provider,
|
||||
message_stream: AsyncGenerator[str],
|
||||
message_or_stream: str | AsyncGenerator[str],
|
||||
language: str,
|
||||
options: dict[str, Any],
|
||||
) -> AsyncGenerator[bytes]:
|
||||
@ -979,9 +976,12 @@ class SpeechManager:
|
||||
if engine_instance.name is None or engine_instance.name is UNDEFINED:
|
||||
raise HomeAssistantError("TTS engine name is not set.")
|
||||
|
||||
if isinstance(engine_instance, Provider):
|
||||
message = "".join([chunk async for chunk in message_stream])
|
||||
extension, data = await engine_instance.async_get_tts_audio(
|
||||
if isinstance(engine_instance, Provider) or isinstance(message_or_stream, str):
|
||||
if isinstance(message_or_stream, str):
|
||||
message = message_or_stream
|
||||
else:
|
||||
message = "".join([chunk async for chunk in message_or_stream])
|
||||
extension, data = await engine_instance.async_internal_get_tts_audio(
|
||||
message, language, options
|
||||
)
|
||||
|
||||
@ -997,7 +997,7 @@ class SpeechManager:
|
||||
|
||||
else:
|
||||
tts_result = await engine_instance.internal_async_stream_tts_audio(
|
||||
TTSAudioRequest(language, options, message_stream)
|
||||
TTSAudioRequest(language, options, message_or_stream)
|
||||
)
|
||||
extension = tts_result.extension
|
||||
data_gen = tts_result.data_gen
|
||||
|
@ -165,6 +165,18 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
|
||||
self.async_write_ha_state()
|
||||
return await self.async_stream_tts_audio(request)
|
||||
|
||||
@final
|
||||
async def async_internal_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from the engine and update state.
|
||||
|
||||
Return a tuple of file extension and data as bytes.
|
||||
"""
|
||||
self.__last_tts_loaded = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
return await self.async_get_tts_audio(message, language, options=options)
|
||||
|
||||
async def async_stream_tts_audio(
|
||||
self, request: TTSAudioRequest
|
||||
) -> TTSAudioResponse:
|
||||
|
@ -7,7 +7,7 @@ from collections.abc import Coroutine, Mapping
|
||||
from functools import partial
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -252,3 +252,15 @@ class Provider:
|
||||
return await self.hass.async_add_executor_job(
|
||||
partial(self.get_tts_audio, message, language, options=options)
|
||||
)
|
||||
|
||||
@final
|
||||
async def async_internal_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from provider.
|
||||
|
||||
Proxies request to mimic the entity interface.
|
||||
|
||||
Return a tuple of file extension and data as bytes.
|
||||
"""
|
||||
return await self.async_get_tts_audio(message, language, options)
|
||||
|
@ -1627,25 +1627,15 @@ async def test_chat_log_tts_streaming(
|
||||
),
|
||||
)
|
||||
|
||||
received_tts = []
|
||||
|
||||
async def async_stream_tts_audio(
|
||||
request: tts.TTSAudioRequest,
|
||||
async def async_get_tts_audio(
|
||||
message: str,
|
||||
language: str,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tts.TTSAudioResponse:
|
||||
"""Mock stream TTS audio."""
|
||||
"""Mock get TTS audio."""
|
||||
return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
|
||||
|
||||
async def gen_data():
|
||||
async for msg in request.message_gen:
|
||||
received_tts.append(msg)
|
||||
yield msg.encode()
|
||||
|
||||
return tts.TTSAudioResponse(
|
||||
extension="mp3",
|
||||
data_gen=gen_data(),
|
||||
)
|
||||
|
||||
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
||||
mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
|
||||
mock_tts_entity.async_get_tts_audio = async_get_tts_audio
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
@ -1717,7 +1707,6 @@ async def test_chat_log_tts_streaming(
|
||||
|
||||
streamed_text = "".join(to_stream_tts)
|
||||
assert tts_result == streamed_text
|
||||
assert len(received_tts) == expected_chunks
|
||||
assert "".join(received_tts) == streamed_text
|
||||
assert expected_chunks == 1
|
||||
|
||||
assert process_events(events) == snapshot
|
||||
|
Loading…
x
Reference in New Issue
Block a user