mirror of
https://github.com/home-assistant/core.git
synced 2025-07-29 08:07:45 +00:00
Track if TTS entity supports streaming input (#144697)
* Track if entity supports streaming * Make class method
This commit is contained in:
parent
c022c32d2f
commit
4994229215
@ -469,6 +469,7 @@ class ResultStream:
|
|||||||
use_file_cache: bool
|
use_file_cache: bool
|
||||||
language: str
|
language: str
|
||||||
options: dict
|
options: dict
|
||||||
|
supports_streaming_input: bool
|
||||||
|
|
||||||
_manager: SpeechManager
|
_manager: SpeechManager
|
||||||
|
|
||||||
@ -484,7 +485,10 @@ class ResultStream:
|
|||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_set_message(self, message: str) -> None:
|
def async_set_message(self, message: str) -> None:
|
||||||
"""Set message to be generated."""
|
"""Set message to be generated.
|
||||||
|
|
||||||
|
This method will leverage a disk cache to speed up generation.
|
||||||
|
"""
|
||||||
self._result_cache.set_result(
|
self._result_cache.set_result(
|
||||||
self._manager.async_cache_message_in_memory(
|
self._manager.async_cache_message_in_memory(
|
||||||
engine=self.engine,
|
engine=self.engine,
|
||||||
@ -497,7 +501,10 @@ class ResultStream:
|
|||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_set_message_stream(self, message_stream: AsyncGenerator[str]) -> None:
|
def async_set_message_stream(self, message_stream: AsyncGenerator[str]) -> None:
|
||||||
"""Set a stream that will generate the message."""
|
"""Set a stream that will generate the message.
|
||||||
|
|
||||||
|
This method can result in faster first byte when generating long responses.
|
||||||
|
"""
|
||||||
self._result_cache.set_result(
|
self._result_cache.set_result(
|
||||||
self._manager.async_cache_message_stream_in_memory(
|
self._manager.async_cache_message_stream_in_memory(
|
||||||
engine=self.engine,
|
engine=self.engine,
|
||||||
@ -726,6 +733,10 @@ class SpeechManager:
|
|||||||
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
if (engine_instance := get_engine_instance(self.hass, engine)) is None:
|
||||||
raise HomeAssistantError(f"Provider {engine} not found")
|
raise HomeAssistantError(f"Provider {engine} not found")
|
||||||
|
|
||||||
|
supports_streaming_input = (
|
||||||
|
isinstance(engine_instance, TextToSpeechEntity)
|
||||||
|
and engine_instance.async_supports_streaming_input()
|
||||||
|
)
|
||||||
language, options = self.process_options(engine_instance, language, options)
|
language, options = self.process_options(engine_instance, language, options)
|
||||||
if use_file_cache is None:
|
if use_file_cache is None:
|
||||||
use_file_cache = self.use_file_cache
|
use_file_cache = self.use_file_cache
|
||||||
@ -741,6 +752,7 @@ class SpeechManager:
|
|||||||
engine=engine,
|
engine=engine,
|
||||||
language=language,
|
language=language,
|
||||||
options=options,
|
options=options,
|
||||||
|
supports_streaming_input=supports_streaming_input,
|
||||||
_manager=self,
|
_manager=self,
|
||||||
)
|
)
|
||||||
self.token_to_stream[token] = result_stream
|
self.token_to_stream[token] = result_stream
|
||||||
|
@ -89,6 +89,13 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
|
|||||||
"""Return a mapping with the default options."""
|
"""Return a mapping with the default options."""
|
||||||
return self._attr_default_options
|
return self._attr_default_options
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def async_supports_streaming_input(cls) -> bool:
|
||||||
|
"""Return if the TTS engine supports streaming input."""
|
||||||
|
return (
|
||||||
|
cls.async_stream_tts_audio is not TextToSpeechEntity.async_stream_tts_audio
|
||||||
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
|
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
|
||||||
"""Return a list of supported voices for a language."""
|
"""Return a list of supported voices for a language."""
|
||||||
|
@ -281,6 +281,7 @@ class MockResultStream(ResultStream):
|
|||||||
content_type=f"audio/mock-{extension}",
|
content_type=f"audio/mock-{extension}",
|
||||||
engine="test-engine",
|
engine="test-engine",
|
||||||
use_file_cache=True,
|
use_file_cache=True,
|
||||||
|
supports_streaming_input=True,
|
||||||
language="en",
|
language="en",
|
||||||
options={},
|
options={},
|
||||||
_manager=hass.data[DATA_TTS_MANAGER],
|
_manager=hass.data[DATA_TTS_MANAGER],
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Tests for the TTS entity."""
|
"""Tests for the TTS entity."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import tts
|
from homeassistant.components import tts
|
||||||
@ -142,3 +144,34 @@ async def test_tts_entity_subclass_properties(
|
|||||||
if record.exc_info is not None
|
if record.exc_info is not None
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_supported() -> None:
|
||||||
|
"""Test streaming support."""
|
||||||
|
base_entity = tts.TextToSpeechEntity()
|
||||||
|
assert base_entity.async_supports_streaming_input() is False
|
||||||
|
|
||||||
|
class StreamingEntity(tts.TextToSpeechEntity):
|
||||||
|
async def async_stream_tts_audio(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
streaming_entity = StreamingEntity()
|
||||||
|
assert streaming_entity.async_supports_streaming_input() is True
|
||||||
|
|
||||||
|
class NonStreamingEntity(tts.TextToSpeechEntity):
|
||||||
|
async def async_get_tts_audio(
|
||||||
|
self, message: str, language: str, options: dict[str, Any]
|
||||||
|
) -> tts.TtsAudioType:
|
||||||
|
pass
|
||||||
|
|
||||||
|
non_streaming_entity = NonStreamingEntity()
|
||||||
|
assert non_streaming_entity.async_supports_streaming_input() is False
|
||||||
|
|
||||||
|
class SyncNonStreamingEntity(tts.TextToSpeechEntity):
|
||||||
|
def get_tts_audio(
|
||||||
|
self, message: str, language: str, options: dict[str, Any]
|
||||||
|
) -> tts.TtsAudioType:
|
||||||
|
pass
|
||||||
|
|
||||||
|
sync_non_streaming_entity = SyncNonStreamingEntity()
|
||||||
|
assert sync_non_streaming_entity.async_supports_streaming_input() is False
|
||||||
|
@ -4,7 +4,7 @@ import asyncio
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
from freezegun.api import FrozenDateTimeFactory
|
from freezegun.api import FrozenDateTimeFactory
|
||||||
import pytest
|
import pytest
|
||||||
@ -1885,6 +1885,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
|
|||||||
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
|
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
|
||||||
assert stream.language == mock_tts_entity.default_language
|
assert stream.language == mock_tts_entity.default_language
|
||||||
assert stream.options == (mock_tts_entity.default_options or {})
|
assert stream.options == (mock_tts_entity.default_options or {})
|
||||||
|
assert stream.supports_streaming_input is False
|
||||||
assert tts.async_get_stream(hass, stream.token) is stream
|
assert tts.async_get_stream(hass, stream.token) is stream
|
||||||
stream.async_set_message("beer")
|
stream.async_set_message("beer")
|
||||||
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||||
@ -1905,6 +1906,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
|
|||||||
)
|
)
|
||||||
|
|
||||||
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio
|
||||||
|
mock_tts_entity.async_supports_streaming_input = Mock(return_value=True)
|
||||||
|
|
||||||
async def stream_message():
|
async def stream_message():
|
||||||
"""Mock stream message."""
|
"""Mock stream message."""
|
||||||
@ -1913,6 +1915,7 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No
|
|||||||
yield "o"
|
yield "o"
|
||||||
|
|
||||||
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
|
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
|
||||||
|
assert stream.supports_streaming_input is True
|
||||||
stream.async_set_message_stream(stream_message())
|
stream.async_set_message_stream(stream_message())
|
||||||
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||||
assert result_data == b"hello"
|
assert result_data == b"hello"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user