mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
TTS to only use stream entity method when streaming request comes in (#145167)
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
parent
1f6faaacab
commit
e78f4d2a29
@ -852,12 +852,9 @@ class SpeechManager:
|
|||||||
else:
|
else:
|
||||||
_LOGGER.debug("Generating audio for %s", message[0:32])
|
_LOGGER.debug("Generating audio for %s", message[0:32])
|
||||||
|
|
||||||
async def message_stream() -> AsyncGenerator[str]:
|
|
||||||
yield message
|
|
||||||
|
|
||||||
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT)
|
||||||
data_gen = self._async_generate_tts_audio(
|
data_gen = self._async_generate_tts_audio(
|
||||||
engine_instance, message_stream(), language, options
|
engine_instance, message, language, options
|
||||||
)
|
)
|
||||||
|
|
||||||
cache = TTSCache(
|
cache = TTSCache(
|
||||||
@ -931,7 +928,7 @@ class SpeechManager:
|
|||||||
async def _async_generate_tts_audio(
|
async def _async_generate_tts_audio(
|
||||||
self,
|
self,
|
||||||
engine_instance: TextToSpeechEntity | Provider,
|
engine_instance: TextToSpeechEntity | Provider,
|
||||||
message_stream: AsyncGenerator[str],
|
message_or_stream: str | AsyncGenerator[str],
|
||||||
language: str,
|
language: str,
|
||||||
options: dict[str, Any],
|
options: dict[str, Any],
|
||||||
) -> AsyncGenerator[bytes]:
|
) -> AsyncGenerator[bytes]:
|
||||||
@ -979,9 +976,12 @@ class SpeechManager:
|
|||||||
if engine_instance.name is None or engine_instance.name is UNDEFINED:
|
if engine_instance.name is None or engine_instance.name is UNDEFINED:
|
||||||
raise HomeAssistantError("TTS engine name is not set.")
|
raise HomeAssistantError("TTS engine name is not set.")
|
||||||
|
|
||||||
if isinstance(engine_instance, Provider):
|
if isinstance(engine_instance, Provider) or isinstance(message_or_stream, str):
|
||||||
message = "".join([chunk async for chunk in message_stream])
|
if isinstance(message_or_stream, str):
|
||||||
extension, data = await engine_instance.async_get_tts_audio(
|
message = message_or_stream
|
||||||
|
else:
|
||||||
|
message = "".join([chunk async for chunk in message_or_stream])
|
||||||
|
extension, data = await engine_instance.async_internal_get_tts_audio(
|
||||||
message, language, options
|
message, language, options
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -997,7 +997,7 @@ class SpeechManager:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
tts_result = await engine_instance.internal_async_stream_tts_audio(
|
tts_result = await engine_instance.internal_async_stream_tts_audio(
|
||||||
TTSAudioRequest(language, options, message_stream)
|
TTSAudioRequest(language, options, message_or_stream)
|
||||||
)
|
)
|
||||||
extension = tts_result.extension
|
extension = tts_result.extension
|
||||||
data_gen = tts_result.data_gen
|
data_gen = tts_result.data_gen
|
||||||
|
@ -165,6 +165,18 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
|
|||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
return await self.async_stream_tts_audio(request)
|
return await self.async_stream_tts_audio(request)
|
||||||
|
|
||||||
|
@final
|
||||||
|
async def async_internal_get_tts_audio(
|
||||||
|
self, message: str, language: str, options: dict[str, Any]
|
||||||
|
) -> TtsAudioType:
|
||||||
|
"""Load tts audio file from the engine and update state.
|
||||||
|
|
||||||
|
Return a tuple of file extension and data as bytes.
|
||||||
|
"""
|
||||||
|
self.__last_tts_loaded = dt_util.utcnow().isoformat()
|
||||||
|
self.async_write_ha_state()
|
||||||
|
return await self.async_get_tts_audio(message, language, options=options)
|
||||||
|
|
||||||
async def async_stream_tts_audio(
|
async def async_stream_tts_audio(
|
||||||
self, request: TTSAudioRequest
|
self, request: TTSAudioRequest
|
||||||
) -> TTSAudioResponse:
|
) -> TTSAudioResponse:
|
||||||
|
@ -7,7 +7,7 @@ from collections.abc import Coroutine, Mapping
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, final
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -252,3 +252,15 @@ class Provider:
|
|||||||
return await self.hass.async_add_executor_job(
|
return await self.hass.async_add_executor_job(
|
||||||
partial(self.get_tts_audio, message, language, options=options)
|
partial(self.get_tts_audio, message, language, options=options)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@final
|
||||||
|
async def async_internal_get_tts_audio(
|
||||||
|
self, message: str, language: str, options: dict[str, Any]
|
||||||
|
) -> TtsAudioType:
|
||||||
|
"""Load tts audio file from provider.
|
||||||
|
|
||||||
|
Proxies request to mimic the entity interface.
|
||||||
|
|
||||||
|
Return a tuple of file extension and data as bytes.
|
||||||
|
"""
|
||||||
|
return await self.async_get_tts_audio(message, language, options)
|
||||||
|
@ -1627,25 +1627,15 @@ async def test_chat_log_tts_streaming(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
received_tts = []
|
async def async_get_tts_audio(
|
||||||
|
message: str,
|
||||||
async def async_stream_tts_audio(
|
language: str,
|
||||||
request: tts.TTSAudioRequest,
|
options: dict[str, Any] | None = None,
|
||||||
) -> tts.TTSAudioResponse:
|
) -> tts.TTSAudioResponse:
|
||||||
"""Mock stream TTS audio."""
|
"""Mock get TTS audio."""
|
||||||
|
return ("mp3", b"".join([chunk.encode() for chunk in to_stream_tts]))
|
||||||
|
|
||||||
async def gen_data():
|
mock_tts_entity.async_get_tts_audio = async_get_tts_audio
|
||||||
async for msg in request.message_gen:
|
|
||||||
received_tts.append(msg)
|
|
||||||
yield msg.encode()
|
|
||||||
|
|
||||||
return tts.TTSAudioResponse(
|
|
||||||
extension="mp3",
|
|
||||||
data_gen=gen_data(),
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
|
||||||
mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
|
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
@ -1717,7 +1707,6 @@ async def test_chat_log_tts_streaming(
|
|||||||
|
|
||||||
streamed_text = "".join(to_stream_tts)
|
streamed_text = "".join(to_stream_tts)
|
||||||
assert tts_result == streamed_text
|
assert tts_result == streamed_text
|
||||||
assert len(received_tts) == expected_chunks
|
assert expected_chunks == 1
|
||||||
assert "".join(received_tts) == streamed_text
|
|
||||||
|
|
||||||
assert process_events(events) == snapshot
|
assert process_events(events) == snapshot
|
||||||
|
Loading…
x
Reference in New Issue
Block a user