Track if TTS entity supports streaming input (#144697)

* Track if entity supports streaming

* Make class method
This commit is contained in:
Paulus Schoutsen 2025-05-12 13:44:39 -04:00 committed by GitHub
parent c022c32d2f
commit 4994229215
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 59 additions and 3 deletions

View File

@ -469,6 +469,7 @@ class ResultStream:
use_file_cache: bool use_file_cache: bool
language: str language: str
options: dict options: dict
supports_streaming_input: bool
_manager: SpeechManager _manager: SpeechManager
@ -484,7 +485,10 @@ class ResultStream:
@callback @callback
def async_set_message(self, message: str) -> None: def async_set_message(self, message: str) -> None:
"""Set message to be generated.""" """Set message to be generated.
This method will leverage a disk cache to speed up generation.
"""
self._result_cache.set_result( self._result_cache.set_result(
self._manager.async_cache_message_in_memory( self._manager.async_cache_message_in_memory(
engine=self.engine, engine=self.engine,
@ -497,7 +501,10 @@ class ResultStream:
@callback @callback
def async_set_message_stream(self, message_stream: AsyncGenerator[str]) -> None: def async_set_message_stream(self, message_stream: AsyncGenerator[str]) -> None:
"""Set a stream that will generate the message.""" """Set a stream that will generate the message.
This method can result in faster first byte when generating long responses.
"""
self._result_cache.set_result( self._result_cache.set_result(
self._manager.async_cache_message_stream_in_memory( self._manager.async_cache_message_stream_in_memory(
engine=self.engine, engine=self.engine,
@ -726,6 +733,10 @@ class SpeechManager:
if (engine_instance := get_engine_instance(self.hass, engine)) is None: if (engine_instance := get_engine_instance(self.hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found") raise HomeAssistantError(f"Provider {engine} not found")
supports_streaming_input = (
isinstance(engine_instance, TextToSpeechEntity)
and engine_instance.async_supports_streaming_input()
)
language, options = self.process_options(engine_instance, language, options) language, options = self.process_options(engine_instance, language, options)
if use_file_cache is None: if use_file_cache is None:
use_file_cache = self.use_file_cache use_file_cache = self.use_file_cache
@ -741,6 +752,7 @@ class SpeechManager:
engine=engine, engine=engine,
language=language, language=language,
options=options, options=options,
supports_streaming_input=supports_streaming_input,
_manager=self, _manager=self,
) )
self.token_to_stream[token] = result_stream self.token_to_stream[token] = result_stream

View File

@ -89,6 +89,13 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
"""Return a mapping with the default options.""" """Return a mapping with the default options."""
return self._attr_default_options return self._attr_default_options
@classmethod
def async_supports_streaming_input(cls) -> bool:
    """Report whether this entity class accepts streaming text input.

    A class counts as streaming-capable when its ``async_stream_tts_audio``
    attribute is a different object from the one defined on
    ``TextToSpeechEntity`` itself, i.e. the method has been overridden
    somewhere between the base class and ``cls``.
    """
    # Compare by identity on the class attributes: a class that merely
    # inherits the method resolves to the very same function object.
    base_impl = TextToSpeechEntity.async_stream_tts_audio
    return cls.async_stream_tts_audio is not base_impl
@callback @callback
def async_get_supported_voices(self, language: str) -> list[Voice] | None: def async_get_supported_voices(self, language: str) -> list[Voice] | None:
"""Return a list of supported voices for a language.""" """Return a list of supported voices for a language."""

View File

@ -281,6 +281,7 @@ class MockResultStream(ResultStream):
content_type=f"audio/mock-{extension}", content_type=f"audio/mock-{extension}",
engine="test-engine", engine="test-engine",
use_file_cache=True, use_file_cache=True,
supports_streaming_input=True,
language="en", language="en",
options={}, options={},
_manager=hass.data[DATA_TTS_MANAGER], _manager=hass.data[DATA_TTS_MANAGER],

View File

@ -1,5 +1,7 @@
"""Tests for the TTS entity.""" """Tests for the TTS entity."""
from typing import Any
import pytest import pytest
from homeassistant.components import tts from homeassistant.components import tts
@ -142,3 +144,34 @@ async def test_tts_entity_subclass_properties(
if record.exc_info is not None if record.exc_info is not None
] ]
) )
def test_streaming_supported() -> None:
    """Verify streaming-input detection for the base entity and subclasses.

    Only a subclass that overrides ``async_stream_tts_audio`` is expected to
    report streaming support; overriding either of the non-streaming audio
    methods must leave the result ``False``.
    """

    def _reports_streaming(entity_cls: type[tts.TextToSpeechEntity]) -> bool:
        # Instantiate first so the check is exercised through an instance,
        # exactly as the entity would be queried at runtime.
        entity = entity_cls()
        return entity.async_supports_streaming_input()

    # The unmodified base class does not stream.
    assert _reports_streaming(tts.TextToSpeechEntity) is False

    class StreamingEntity(tts.TextToSpeechEntity):
        async def async_stream_tts_audio(self) -> None:
            pass

    # Overriding the streaming method flips the result.
    assert _reports_streaming(StreamingEntity) is True

    class NonStreamingEntity(tts.TextToSpeechEntity):
        async def async_get_tts_audio(
            self, message: str, language: str, options: dict[str, Any]
        ) -> tts.TtsAudioType:
            pass

    # Overriding only the async one-shot method is not streaming support.
    assert _reports_streaming(NonStreamingEntity) is False

    class SyncNonStreamingEntity(tts.TextToSpeechEntity):
        def get_tts_audio(
            self, message: str, language: str, options: dict[str, Any]
        ) -> tts.TtsAudioType:
            pass

    # Nor is overriding only the synchronous one-shot method.
    assert _reports_streaming(SyncNonStreamingEntity) is False

View File

@ -4,7 +4,7 @@ import asyncio
from http import HTTPStatus from http import HTTPStatus
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, Mock, patch
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
import pytest import pytest
@ -1885,6 +1885,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id) stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
assert stream.language == mock_tts_entity.default_language assert stream.language == mock_tts_entity.default_language
assert stream.options == (mock_tts_entity.default_options or {}) assert stream.options == (mock_tts_entity.default_options or {})
assert stream.supports_streaming_input is False
assert tts.async_get_stream(hass, stream.token) is stream assert tts.async_get_stream(hass, stream.token) is stream
stream.async_set_message("beer") stream.async_set_message("beer")
result_data = b"".join([chunk async for chunk in stream.async_stream_result()]) result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
@ -1905,6 +1906,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
) )
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
async def stream_message(): async def stream_message():
"""Mock stream message.""" """Mock stream message."""
@ -1913,6 +1915,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
yield "o" yield "o"
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id) stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
assert stream.supports_streaming_input is True
stream.async_set_message_stream(stream_message()) stream.async_set_message_stream(stream_message())
result_data = b"".join([chunk async for chunk in stream.async_stream_result()]) result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
assert result_data == b"hello" assert result_data == b"hello"