Mirror of https://github.com/home-assistant/core.git
Track if TTS entity supports streaming input (#144697)

* Track if entity supports streaming
* Make class method

parent c022c32d2f
commit 4994229215
@@ -469,6 +469,7 @@ class ResultStream:
     use_file_cache: bool
     language: str
     options: dict
+    supports_streaming_input: bool
 
     _manager: SpeechManager
 
@@ -484,7 +485,10 @@ class ResultStream:
 
     @callback
     def async_set_message(self, message: str) -> None:
-        """Set message to be generated."""
+        """Set message to be generated.
+
+        This method will leverage a disk cache to speed up generation.
+        """
         self._result_cache.set_result(
             self._manager.async_cache_message_in_memory(
                 engine=self.engine,
@@ -497,7 +501,10 @@ class ResultStream:
 
     @callback
     def async_set_message_stream(self, message_stream: AsyncGenerator[str]) -> None:
-        """Set a stream that will generate the message."""
+        """Set a stream that will generate the message.
+
+        This method can result in faster first byte when generating long responses.
+        """
         self._result_cache.set_result(
             self._manager.async_cache_message_stream_in_memory(
                 engine=self.engine,
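For context, a minimal sketch of how a caller might combine the new supports_streaming_input flag with the two setters above. The speak_reply helper, its arguments, and the reply_stream generator are invented for illustration and are not part of this commit; the tts.async_create_stream / async_set_message / async_set_message_stream calls follow the usage shown in the tests further down.

from collections.abc import AsyncGenerator

from homeassistant.components import tts
from homeassistant.core import HomeAssistant


def speak_reply(
    hass: HomeAssistant,
    entity_id: str,
    reply_text: str,
    reply_stream: AsyncGenerator[str],
) -> tts.ResultStream:
    """Hand a reply to TTS; must run inside the event loop."""
    stream = tts.async_create_stream(hass, entity_id)
    if stream.supports_streaming_input:
        # Streaming input: faster first byte for long responses.
        stream.async_set_message_stream(reply_stream)
    else:
        # Complete message: can be served from the disk cache.
        stream.async_set_message(reply_text)
    return stream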
@@ -726,6 +733,10 @@ class SpeechManager:
         if (engine_instance := get_engine_instance(self.hass, engine)) is None:
             raise HomeAssistantError(f"Provider {engine} not found")
 
+        supports_streaming_input = (
+            isinstance(engine_instance, TextToSpeechEntity)
+            and engine_instance.async_supports_streaming_input()
+        )
         language, options = self.process_options(engine_instance, language, options)
         if use_file_cache is None:
             use_file_cache = self.use_file_cache
@@ -741,6 +752,7 @@ class SpeechManager:
             engine=engine,
             language=language,
             options=options,
+            supports_streaming_input=supports_streaming_input,
             _manager=self,
         )
         self.token_to_stream[token] = result_stream

@@ -89,6 +89,13 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
         """Return a mapping with the default options."""
         return self._attr_default_options
 
+    @classmethod
+    def async_supports_streaming_input(cls) -> bool:
+        """Return if the TTS engine supports streaming input."""
+        return (
+            cls.async_stream_tts_audio is not TextToSpeechEntity.async_stream_tts_audio
+        )
+
     @callback
     def async_get_supported_voices(self, language: str) -> list[Voice] | None:
         """Return a list of supported voices for a language."""

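The new classmethod relies on a standard Python idiom: looking a method up on a class yields the plain function object, so an identity comparison against the base class's function tells you whether a subclass overrode it. A standalone sketch of the idiom (class and method names here are invented for illustration):

class Base:
    @classmethod
    def supports_feature(cls) -> bool:
        # True only when the subclass provides its own feature().
        return cls.feature is not Base.feature

    def feature(self) -> None:
        raise NotImplementedError


class WithFeature(Base):
    def feature(self) -> None:
        print("feature implemented")


class WithoutFeature(Base):
    pass


assert WithFeature.supports_feature() is True
assert WithoutFeature.supports_feature() is False
assert Base.supports_feature() is False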
@@ -281,6 +281,7 @@ class MockResultStream(ResultStream):
             content_type=f"audio/mock-{extension}",
             engine="test-engine",
             use_file_cache=True,
+            supports_streaming_input=True,
             language="en",
             options={},
             _manager=hass.data[DATA_TTS_MANAGER],

@@ -1,5 +1,7 @@
 """Tests for the TTS entity."""
 
+from typing import Any
+
 import pytest
 
 from homeassistant.components import tts
@@ -142,3 +144,34 @@ async def test_tts_entity_subclass_properties(
             if record.exc_info is not None
         ]
     )
+
+
+def test_streaming_supported() -> None:
+    """Test streaming support."""
+    base_entity = tts.TextToSpeechEntity()
+    assert base_entity.async_supports_streaming_input() is False
+
+    class StreamingEntity(tts.TextToSpeechEntity):
+        async def async_stream_tts_audio(self) -> None:
+            pass
+
+    streaming_entity = StreamingEntity()
+    assert streaming_entity.async_supports_streaming_input() is True
+
+    class NonStreamingEntity(tts.TextToSpeechEntity):
+        async def async_get_tts_audio(
+            self, message: str, language: str, options: dict[str, Any]
+        ) -> tts.TtsAudioType:
+            pass
+
+    non_streaming_entity = NonStreamingEntity()
+    assert non_streaming_entity.async_supports_streaming_input() is False
+
+    class SyncNonStreamingEntity(tts.TextToSpeechEntity):
+        def get_tts_audio(
+            self, message: str, language: str, options: dict[str, Any]
+        ) -> tts.TtsAudioType:
+            pass
+
+    sync_non_streaming_entity = SyncNonStreamingEntity()
+    assert sync_non_streaming_entity.async_supports_streaming_input() is False

@@ -4,7 +4,7 @@ import asyncio
 from http import HTTPStatus
 from pathlib import Path
 from typing import Any
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 from freezegun.api import FrozenDateTimeFactory
 import pytest
@@ -1885,6 +1885,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> None
     stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
     assert stream.language == mock_tts_entity.default_language
     assert stream.options == (mock_tts_entity.default_options or {})
+    assert stream.supports_streaming_input is False
     assert tts.async_get_stream(hass, stream.token) is stream
     stream.async_set_message("beer")
     result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
@@ -1905,6 +1906,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> None
     )
 
     mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
+    mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
 
     async def stream_message():
         """Mock stream message."""
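The extra Mock here is presumably needed because async_supports_streaming_input is a classmethod: it compares function objects on the class, so assigning async_stream_tts_audio on the entity instance (as the line above it does) would not change the result by itself. A standalone sketch of that behaviour (names invented for illustration):

class Engine:
    @classmethod
    def supports_streaming(cls) -> bool:
        # Inspects the class attribute, not anything set on an instance.
        return cls.stream is not Engine.stream

    def stream(self) -> None:
        raise NotImplementedError


engine = Engine()
engine.stream = lambda: None  # instance-level patch, as in the test above
assert Engine.supports_streaming() is False
assert engine.supports_streaming() is False  # classmethod still checks the class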
@@ -1913,6 +1915,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> None
         yield "o"
 
     stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
+    assert stream.supports_streaming_input is True
    stream.async_set_message_stream(stream_message())
     result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
     assert result_data == b"hello"