diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 8182d375f96..22c388cae9f 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -3,8 +3,8 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncGenerator -from dataclasses import dataclass +from collections.abc import AsyncGenerator, MutableMapping +from dataclasses import dataclass, field from datetime import datetime import hashlib from http import HTTPStatus @@ -15,7 +15,7 @@ import os import re import secrets from time import monotonic -from typing import Any, Final +from typing import Any, Final, Generic, Protocol, TypeVar from aiohttp import web import mutagen @@ -60,10 +60,10 @@ from .const import ( DOMAIN, TtsAudioType, ) -from .entity import TextToSpeechEntity, TTSAudioRequest +from .entity import TextToSpeechEntity, TTSAudioRequest, TTSAudioResponse from .helper import get_engine_instance from .legacy import PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, Provider, async_setup_legacy -from .media_source import generate_media_source_id, media_source_id_to_kwargs +from .media_source import generate_media_source_id, parse_media_source_id from .models import Voice __all__ = [ @@ -79,6 +79,7 @@ __all__ = [ "Provider", "ResultStream", "SampleFormat", + "TTSAudioResponse", "TextToSpeechEntity", "TtsAudioType", "Voice", @@ -264,7 +265,7 @@ def async_create_stream( @callback def async_get_stream(hass: HomeAssistant, token: str) -> ResultStream | None: """Return a result stream given a token.""" - return hass.data[DATA_TTS_MANAGER].token_to_stream.get(token) + return hass.data[DATA_TTS_MANAGER].async_get_result_stream(token) async def async_get_media_source_audio( @@ -272,12 +273,11 @@ async def async_get_media_source_audio( media_source_id: str, ) -> tuple[str, bytes]: """Get TTS audio as extension, data.""" - manager = hass.data[DATA_TTS_MANAGER] - cache = manager.async_cache_message_in_memory( - **media_source_id_to_kwargs(media_source_id) - ) - data = b"".join([chunk async for chunk in cache.async_stream_data()]) - return cache.extension, data + parsed = parse_media_source_id(media_source_id) + stream = hass.data[DATA_TTS_MANAGER].async_create_result_stream(**parsed["options"]) + stream.async_set_message(parsed["message"]) + data = b"".join([chunk async for chunk in stream.async_stream_result()]) + return stream.extension, data @callback @@ -457,6 +457,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: class ResultStream: """Class that will stream the result when available.""" + last_used: float = field(default_factory=monotonic, init=False) + # Streaming/conversion properties token: str extension: str @@ -480,11 +482,6 @@ class ResultStream: """Get the future that returns the cache.""" return asyncio.Future() - @callback - def async_set_message_cache(self, cache: TTSCache) -> None: - """Set cache containing message audio to be streamed.""" - self._result_cache.set_result(cache) - @callback def async_set_message(self, message: str) -> None: """Set message to be generated.""" @@ -504,6 +501,8 @@ class ResultStream: async for chunk in cache.async_stream_data(): yield chunk + self.last_used = monotonic() + def _hash_options(options: dict) -> str: """Hashes an options dictionary.""" @@ -515,13 +514,25 @@ def _hash_options(options: dict) -> str: return opts_hash.hexdigest() -class MemcacheCleanup: +class HasLastUsed(Protocol): + """Protocol for objects that have a last_used attribute.""" + + last_used: float + + +T = TypeVar("T", bound=HasLastUsed) + + +class DictCleaning(Generic[T]): """Helper to clean up the stale sessions.""" unsub: CALLBACK_TYPE | None = None def __init__( - self, hass: HomeAssistant, maxage: float, memcache: dict[str, TTSCache] + self, + hass: HomeAssistant, + maxage: float, + memcache: MutableMapping[str, T], ) -> None: """Initialize the cleanup.""" self.hass = hass @@ -588,8 +599,9 @@ class SpeechManager: self.file_cache: dict[str, str] = {} self.mem_cache: dict[str, TTSCache] = {} self.token_to_stream: dict[str, ResultStream] = {} - self.memcache_cleanup = MemcacheCleanup( - hass, memory_cache_maxage, self.mem_cache + self.memcache_cleanup = DictCleaning(hass, memory_cache_maxage, self.mem_cache) + self.token_to_stream_cleanup = DictCleaning( + hass, memory_cache_maxage, self.token_to_stream ) def _init_cache(self) -> dict[str, str]: @@ -679,11 +691,21 @@ class SpeechManager: return language, merged_options + @callback + def async_get_result_stream( + self, + token: str, + ) -> ResultStream | None: + """Return a result stream given a token.""" + stream = self.token_to_stream.get(token, None) + if stream: + stream.last_used = monotonic() + return stream + @callback def async_create_result_stream( self, engine: str, - message: str | None = None, use_file_cache: bool | None = None, language: str | None = None, options: dict | None = None, @@ -710,23 +732,7 @@ class SpeechManager: _manager=self, ) self.token_to_stream[token] = result_stream - - if message is None: - return result_stream - - # We added this method as an alternative to stream.async_set_message - # to avoid the options being processed twice - result_stream.async_set_message_cache( - self._async_ensure_cached_in_memory( - engine=engine, - engine_instance=engine_instance, - message=message, - use_file_cache=use_file_cache, - language=language, - options=options, - ) - ) - + self.token_to_stream_cleanup.schedule() return result_stream @callback @@ -734,41 +740,17 @@ class SpeechManager: self, engine: str, message: str, - use_file_cache: bool | None = None, - language: str | None = None, - options: dict | None = None, - ) -> TTSCache: - """Make sure a message is cached in memory and returns cache key.""" - if (engine_instance := get_engine_instance(self.hass, engine)) is None: - raise HomeAssistantError(f"Provider {engine} not found") - - language, options = self.process_options(engine_instance, language, options) - if use_file_cache is None: - use_file_cache = self.use_file_cache - - return self._async_ensure_cached_in_memory( - engine=engine, - engine_instance=engine_instance, - message=message, - use_file_cache=use_file_cache, - language=language, - options=options, - ) - - @callback - def _async_ensure_cached_in_memory( - self, - engine: str, - engine_instance: TextToSpeechEntity | Provider, - message: str, use_file_cache: bool, language: str, options: dict, ) -> TTSCache: - """Ensure a message is cached. + """Make sure a message will be cached in memory and returns cache object. Requires options, language to be processed. """ + if (engine_instance := get_engine_instance(self.hass, engine)) is None: + raise HomeAssistantError(f"Provider {engine} not found") + options_key = _hash_options(options) if options else "-" msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest() cache_key = KEY_PATTERN.format( @@ -789,9 +771,13 @@ class SpeechManager: store_to_disk = False else: _LOGGER.debug("Generating audio for %s", message[0:32]) + + async def message_stream() -> AsyncGenerator[str]: + yield message + extension = options.get(ATTR_PREFERRED_FORMAT, _DEFAULT_FORMAT) data_gen = self._async_generate_tts_audio( - engine_instance, message, language, options + engine_instance, message_stream(), language, options ) cache = TTSCache( @@ -799,7 +785,6 @@ class SpeechManager: extension=extension, data_gen=data_gen, ) - self.mem_cache[cache_key] = cache self.hass.async_create_background_task( self._load_data_into_cache( @@ -866,7 +851,7 @@ class SpeechManager: async def _async_generate_tts_audio( self, engine_instance: TextToSpeechEntity | Provider, - message: str, + message_stream: AsyncGenerator[str], language: str, options: dict[str, Any], ) -> AsyncGenerator[bytes]: @@ -915,6 +900,7 @@ class SpeechManager: raise HomeAssistantError("TTS engine name is not set.") if isinstance(engine_instance, Provider): + message = "".join([chunk async for chunk in message_stream]) extension, data = await engine_instance.async_get_tts_audio( message, language, options ) @@ -930,12 +916,8 @@ class SpeechManager: data_gen = make_data_generator(data) else: - - async def message_gen() -> AsyncGenerator[str]: - yield message - tts_result = await engine_instance.internal_async_stream_tts_audio( - TTSAudioRequest(language, options, message_gen()) + TTSAudioRequest(language, options, message_stream) ) extension = tts_result.extension data_gen = tts_result.data_gen @@ -1096,7 +1078,6 @@ class TextToSpeechUrlView(HomeAssistantView): try: stream = self.manager.async_create_result_stream( engine, - message, use_file_cache=use_file_cache, language=language, options=options, @@ -1105,6 +1086,8 @@ class TextToSpeechUrlView(HomeAssistantView): _LOGGER.error("Error on init tts: %s", err) return self.json({"error": err}, HTTPStatus.BAD_REQUEST) + stream.async_set_message(message) + base = get_url(self.manager.hass) url = base + stream.url diff --git a/homeassistant/components/tts/media_source.py b/homeassistant/components/tts/media_source.py index aa2cd6e7555..97d2ab549bc 100644 --- a/homeassistant/components/tts/media_source.py +++ b/homeassistant/components/tts/media_source.py @@ -69,14 +69,20 @@ class MediaSourceOptions(TypedDict): """Media source options.""" engine: str - message: str language: str | None options: dict | None use_file_cache: bool | None +class ParsedMediaSourceId(TypedDict): + """Parsed media source ID.""" + + options: MediaSourceOptions + message: str + + @callback -def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions: +def parse_media_source_id(media_source_id: str) -> ParsedMediaSourceId: """Turn a media source ID into options.""" parsed = URL(media_source_id) if URL_QUERY_TTS_OPTIONS in parsed.query: @@ -94,7 +100,6 @@ def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions: raise Unresolvable("No message specified.") kwargs: MediaSourceOptions = { "engine": parsed.name, - "message": parsed.query["message"], "language": parsed.query.get("language"), "options": options, "use_file_cache": None, @@ -102,7 +107,7 @@ def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions: if "cache" in parsed.query: kwargs["use_file_cache"] = parsed.query["cache"] == "true" - return kwargs + return {"message": parsed.query["message"], "options": kwargs} class TTSMediaSource(MediaSource): @@ -118,9 +123,11 @@ class TTSMediaSource(MediaSource): async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: """Resolve media to a url.""" try: + parsed = parse_media_source_id(item.identifier) stream = self.hass.data[DATA_TTS_MANAGER].async_create_result_stream( - **media_source_id_to_kwargs(item.identifier) + **parsed["options"] ) + stream.async_set_message(parsed["message"]) except Unresolvable: raise except HomeAssistantError as err: diff --git a/tests/components/tts/common.py b/tests/components/tts/common.py index 99c698771f7..c21db66dfac 100644 --- a/tests/components/tts/common.py +++ b/tests/components/tts/common.py @@ -42,6 +42,7 @@ from tests.typing import ClientSessionGenerator DEFAULT_LANG = "en_US" SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"] TEST_DOMAIN = "test" +MOCK_DATA = b"123" def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]: @@ -164,7 +165,7 @@ class BaseProvider: self, message: str, language: str, options: dict[str, Any] ) -> TtsAudioType: """Load TTS dat.""" - return ("mp3", b"") + return ("mp3", MOCK_DATA) class MockTTSProvider(BaseProvider, Provider): diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 4e17bc68a5e..99f4b008c68 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -27,6 +27,7 @@ from homeassistant.util import dt as dt_util from .common import ( DEFAULT_LANG, + MOCK_DATA, TEST_DOMAIN, MockResultStream, MockTTS, @@ -808,7 +809,7 @@ async def test_service_receive_voice( await hass.async_block_till_done() client = await hass_client() req = await client.get(url) - tts_data = b"" + tts_data = MOCK_DATA tts_data = tts.SpeechManager.write_tags( f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3", tts_data, @@ -879,7 +880,7 @@ async def test_service_receive_voice_german( await hass.async_block_till_done() client = await hass_client() req = await client.get(url) - tts_data = b"" + tts_data = MOCK_DATA tts_data = tts.SpeechManager.write_tags( "42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3", tts_data, @@ -1021,7 +1022,7 @@ async def test_setup_legacy_cache_dir( """Set up a TTS platform with cache and call service without cache.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) - tts_data = b"" + tts_data = MOCK_DATA cache_file = ( mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" ) @@ -1059,7 +1060,7 @@ async def test_setup_cache_dir( """Set up a TTS platform with cache and call service without cache.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) - tts_data = b"" + tts_data = MOCK_DATA cache_file = mock_tts_cache_dir / ( "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3" ) @@ -1165,7 +1166,7 @@ async def test_legacy_cannot_retrieve_without_token( hass_client: ClientSessionGenerator, ) -> None: """Verify that a TTS cannot be retrieved by filename directly.""" - tts_data = b"" + tts_data = MOCK_DATA cache_file = ( mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" ) @@ -1188,7 +1189,7 @@ async def test_cannot_retrieve_without_token( hass_client: ClientSessionGenerator, ) -> None: """Verify that a TTS cannot be retrieved by filename directly.""" - tts_data = b"" + tts_data = MOCK_DATA cache_file = mock_tts_cache_dir / ( "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3" ) @@ -1845,6 +1846,9 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No assert stream.language == mock_tts_entity.default_language assert stream.options == (mock_tts_entity.default_options or {}) assert tts.async_get_stream(hass, stream.token) is stream + stream.async_set_message("beer") + result_data = b"".join([chunk async for chunk in stream.async_stream_result()]) + assert result_data == MOCK_DATA data = b"beer" stream2 = MockResultStream(hass, "wav", data) diff --git a/tests/components/tts/test_media_source.py b/tests/components/tts/test_media_source.py index 9e50cc6b512..4ff0a44a4bb 100644 --- a/tests/components/tts/test_media_source.py +++ b/tests/components/tts/test_media_source.py @@ -9,9 +9,8 @@ import pytest from homeassistant.components import media_source from homeassistant.components.media_player import BrowseError from homeassistant.components.tts.media_source import ( - MediaSourceOptions, generate_media_source_id, - media_source_id_to_kwargs, + parse_media_source_id, ) from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -249,13 +248,13 @@ async def test_resolving_errors(hass: HomeAssistant, setup: str, engine: str) -> ], indirect=["setup"], ) -async def test_generate_media_source_id_and_media_source_id_to_kwargs( +async def test_generate_media_source_id_and_parse_media_source_id( hass: HomeAssistant, setup: str, result_engine: str, ) -> None: - """Test media_source_id and media_source_id_to_kwargs.""" - kwargs: MediaSourceOptions = { + """Test media_source_id and parse_media_source_id.""" + kwargs = { "engine": None, "message": "hello", "language": "en_US", @@ -263,12 +262,14 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs( "cache": True, } media_source_id = generate_media_source_id(hass, **kwargs) - assert media_source_id_to_kwargs(media_source_id) == { - "engine": result_engine, + assert parse_media_source_id(media_source_id) == { "message": "hello", - "language": "en_US", - "options": {"age": 5}, - "use_file_cache": True, + "options": { + "engine": result_engine, + "language": "en_US", + "options": {"age": 5}, + "use_file_cache": True, + }, } kwargs = { @@ -279,12 +280,14 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs( "cache": True, } media_source_id = generate_media_source_id(hass, **kwargs) - assert media_source_id_to_kwargs(media_source_id) == { - "engine": result_engine, + assert parse_media_source_id(media_source_id) == { "message": "hello", - "language": "en_US", - "options": {"age": [5, 6]}, - "use_file_cache": True, + "options": { + "engine": result_engine, + "language": "en_US", + "options": {"age": [5, 6]}, + "use_file_cache": True, + }, } kwargs = { @@ -295,10 +298,12 @@ async def test_generate_media_source_id_and_media_source_id_to_kwargs( "cache": True, } media_source_id = generate_media_source_id(hass, **kwargs) - assert media_source_id_to_kwargs(media_source_id) == { - "engine": result_engine, + assert parse_media_source_id(media_source_id) == { "message": "hello", - "language": "en_US", - "options": {"age": {"k1": [5, 6], "k2": "v2"}}, - "use_file_cache": True, + "options": { + "engine": result_engine, + "language": "en_US", + "options": {"age": {"k1": [5, 6], "k2": "v2"}}, + "use_file_cache": True, + }, }