From 4994229215979f6a150fc3cdc2950e0f305a77b2 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 12 May 2025 13:44:39 -0400 Subject: [PATCH] Track if TTS entity supports streaming input (#144697) * Track if entity supports streaming * Make class method --- homeassistant/components/tts/__init__.py | 16 ++++++++++-- homeassistant/components/tts/entity.py | 7 +++++ tests/components/tts/common.py | 1 + tests/components/tts/test_entity.py | 33 ++++++++++++++++++++++++ tests/components/tts/test_init.py | 5 +++- 5 files changed, 59 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index b279af31803..526be21ad76 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -469,6 +469,7 @@ class ResultStream: use_file_cache: bool language: str options: dict + supports_streaming_input: bool _manager: SpeechManager @@ -484,7 +485,10 @@ class ResultStream: @callback def async_set_message(self, message: str) -> None: - """Set message to be generated.""" + """Set message to be generated. + + This method will leverage a disk cache to speed up generation. + """ self._result_cache.set_result( self._manager.async_cache_message_in_memory( engine=self.engine, @@ -497,7 +501,10 @@ class ResultStream: @callback def async_set_message_stream(self, message_stream: AsyncGenerator[str]) -> None: - """Set a stream that will generate the message.""" + """Set a stream that will generate the message. + + This method can result in faster first byte when generating long responses. + """ self._result_cache.set_result( self._manager.async_cache_message_stream_in_memory( engine=self.engine, @@ -726,6 +733,10 @@ class SpeechManager: if (engine_instance := get_engine_instance(self.hass, engine)) is None: raise HomeAssistantError(f"Provider {engine} not found") + supports_streaming_input = ( + isinstance(engine_instance, TextToSpeechEntity) + and engine_instance.async_supports_streaming_input() + ) language, options = self.process_options(engine_instance, language, options) if use_file_cache is None: use_file_cache = self.use_file_cache @@ -741,6 +752,7 @@ class SpeechManager: engine=engine, language=language, options=options, + supports_streaming_input=supports_streaming_input, _manager=self, ) self.token_to_stream[token] = result_stream diff --git a/homeassistant/components/tts/entity.py b/homeassistant/components/tts/entity.py index 199d673398e..1f01a41c5ab 100644 --- a/homeassistant/components/tts/entity.py +++ b/homeassistant/components/tts/entity.py @@ -89,6 +89,13 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH """Return a mapping with the default options.""" return self._attr_default_options + @classmethod + def async_supports_streaming_input(cls) -> bool: + """Return if the TTS engine supports streaming input.""" + return ( + cls.async_stream_tts_audio is not TextToSpeechEntity.async_stream_tts_audio + ) + @callback def async_get_supported_voices(self, language: str) -> list[Voice] | None: """Return a list of supported voices for a language.""" diff --git a/tests/components/tts/common.py b/tests/components/tts/common.py index c21db66dfac..171334c136a 100644 --- a/tests/components/tts/common.py +++ b/tests/components/tts/common.py @@ -281,6 +281,7 @@ class MockResultStream(ResultStream): content_type=f"audio/mock-{extension}", engine="test-engine", use_file_cache=True, + supports_streaming_input=True, language="en", options={}, _manager=hass.data[DATA_TTS_MANAGER], diff --git a/tests/components/tts/test_entity.py b/tests/components/tts/test_entity.py index d82ec6a5d2b..8648ca95e93 100644 --- a/tests/components/tts/test_entity.py +++ b/tests/components/tts/test_entity.py @@ -1,5 +1,7 @@ """Tests for the TTS entity.""" +from typing import Any + import pytest from homeassistant.components import tts @@ -142,3 +144,34 @@ async def test_tts_entity_subclass_properties( if record.exc_info is not None ] ) + + +def test_streaming_supported() -> None: + """Test streaming support.""" + base_entity = tts.TextToSpeechEntity() + assert base_entity.async_supports_streaming_input() is False + + class StreamingEntity(tts.TextToSpeechEntity): + async def async_stream_tts_audio(self) -> None: + pass + + streaming_entity = StreamingEntity() + assert streaming_entity.async_supports_streaming_input() is True + + class NonStreamingEntity(tts.TextToSpeechEntity): + async def async_get_tts_audio( + self, message: str, language: str, options: dict[str, Any] + ) -> tts.TtsAudioType: + pass + + non_streaming_entity = NonStreamingEntity() + assert non_streaming_entity.async_supports_streaming_input() is False + + class SyncNonStreamingEntity(tts.TextToSpeechEntity): + def get_tts_audio( + self, message: str, language: str, options: dict[str, Any] + ) -> tts.TtsAudioType: + pass + + sync_non_streaming_entity = SyncNonStreamingEntity() + assert sync_non_streaming_entity.async_supports_streaming_input() is False diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index ea281506f3a..ccb62959eba 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -4,7 +4,7 @@ import asyncio from http import HTTPStatus from pathlib import Path from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch from freezegun.api import FrozenDateTimeFactory import pytest @@ -1885,6 +1885,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No stream = tts.async_create_stream(hass, mock_tts_entity.entity_id) assert stream.language == mock_tts_entity.default_language assert stream.options == (mock_tts_entity.default_options or {}) + assert stream.supports_streaming_input is False assert tts.async_get_stream(hass, stream.token) is stream stream.async_set_message("beer") result_data = b"".join([chunk async for chunk in stream.async_stream_result()]) @@ -1905,6 +1906,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No ) mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio + mock_tts_entity.async_supports_streaming_input = Mock(return_value=True) async def stream_message(): """Mock stream message.""" @@ -1913,6 +1915,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No yield "o" stream = tts.async_create_stream(hass, mock_tts_entity.entity_id) + assert stream.supports_streaming_input is True stream.async_set_message_stream(stream_message()) result_data = b"".join([chunk async for chunk in stream.async_stream_result()]) assert result_data == b"hello"