From 86e9f6643f3328232c0bf2206bbfcc28f0e3728a Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 6 Apr 2023 11:42:55 -0400 Subject: [PATCH] Allow TTS requests to resolve in the background (#90944) --- homeassistant/components/cloud/tts.py | 30 +++++--- homeassistant/components/tts/__init__.py | 89 ++++++++++++++++++------ tests/components/cloud/test_tts.py | 12 ++-- tests/components/tts/test_init.py | 71 +++++++++++++++++++ 4 files changed, 163 insertions(+), 39 deletions(-) diff --git a/homeassistant/components/cloud/tts.py b/homeassistant/components/cloud/tts.py index bbf4ef287d6..a10b0c98cd8 100644 --- a/homeassistant/components/cloud/tts.py +++ b/homeassistant/components/cloud/tts.py @@ -4,11 +4,16 @@ from hass_nabucasa import Cloud from hass_nabucasa.voice import MAP_VOICE, AudioOutput, VoiceError import voluptuous as vol -from homeassistant.components.tts import CONF_LANG, PLATFORM_SCHEMA, Provider +from homeassistant.components.tts import ( + ATTR_AUDIO_OUTPUT, + CONF_LANG, + PLATFORM_SCHEMA, + Provider, +) from .const import DOMAIN -CONF_GENDER = "gender" +ATTR_GENDER = "gender" SUPPORT_LANGUAGES = list({key[0] for key in MAP_VOICE}) @@ -18,8 +23,8 @@ def validate_lang(value): if (lang := value.get(CONF_LANG)) is None: return value - if (gender := value.get(CONF_GENDER)) is None: - gender = value[CONF_GENDER] = next( + if (gender := value.get(ATTR_GENDER)) is None: + gender = value[ATTR_GENDER] = next( (chk_gender for chk_lang, chk_gender in MAP_VOICE if chk_lang == lang), None ) @@ -33,7 +38,7 @@ PLATFORM_SCHEMA = vol.All( PLATFORM_SCHEMA.extend( { vol.Optional(CONF_LANG): str, - vol.Optional(CONF_GENDER): str, + vol.Optional(ATTR_GENDER): str, } ), validate_lang, @@ -49,7 +54,7 @@ async def async_get_engine(hass, config, discovery_info=None): gender = None else: language = config[CONF_LANG] - gender = config[CONF_GENDER] + gender = config[ATTR_GENDER] return CloudProvider(cloud, language, gender) @@ -87,12 +92,15 @@ class CloudProvider(Provider): @property def supported_options(self): """Return list of supported options like voice, emotion.""" - return [CONF_GENDER] + return [ATTR_GENDER, ATTR_AUDIO_OUTPUT] @property def default_options(self): """Return a dict include default options.""" - return {CONF_GENDER: self._gender} + return { + ATTR_GENDER: self._gender, + ATTR_AUDIO_OUTPUT: AudioOutput.MP3, + } async def async_get_tts_audio(self, message, language, options=None): """Load TTS from NabuCasa Cloud.""" @@ -101,10 +109,10 @@ class CloudProvider(Provider): data = await self.cloud.voice.process_tts( message, language, - gender=options[CONF_GENDER], - output=AudioOutput.MP3, + gender=options[ATTR_GENDER], + output=options[ATTR_AUDIO_OUTPUT], ) except VoiceError: return (None, None) - return ("mp3", data) + return (str(options[ATTR_AUDIO_OUTPUT]), data) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 119a013ebf6..c1a827c27bb 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -59,6 +59,7 @@ ATTR_LANGUAGE = "language" ATTR_MESSAGE = "message" ATTR_OPTIONS = "options" ATTR_PLATFORM = "platform" +ATTR_AUDIO_OUTPUT = "audio_output" BASE_URL_KEY = "tts_base_url" @@ -134,6 +135,7 @@ class TTSCache(TypedDict): filename: str voice: bytes + pending: asyncio.Task | None @callback @@ -495,8 +497,11 @@ class SpeechManager: ) extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:] - data = self.mem_cache[cache_key]["voice"] - return extension, data + cached = 
self.mem_cache[cache_key] + if pending := cached.get("pending"): + await pending + cached = self.mem_cache[cache_key] + return extension, cached["voice"] @callback def _generate_cache_key( @@ -527,30 +532,62 @@ class SpeechManager: This method is a coroutine. """ provider = self.providers[engine] - extension, data = await provider.async_get_tts_audio(message, language, options) - if data is None or extension is None: - raise HomeAssistantError(f"No TTS from {engine} for '{message}'") + if options is not None and ATTR_AUDIO_OUTPUT in options: + expected_extension = options[ATTR_AUDIO_OUTPUT] + else: + expected_extension = None - # Create file infos - filename = f"{cache_key}.{extension}".lower() - - # Validate filename - if not _RE_VOICE_FILE.match(filename): - raise HomeAssistantError( - f"TTS filename '{filename}' from {engine} is invalid!" + async def get_tts_data() -> str: + """Handle data available.""" + extension, data = await provider.async_get_tts_audio( + message, language, options ) - # Save to memory - if extension == "mp3": - data = self.write_tags(filename, data, provider, message, language, options) - self._async_store_to_memcache(cache_key, filename, data) + if data is None or extension is None: + raise HomeAssistantError(f"No TTS from {engine} for '{message}'") - if cache: - self.hass.async_create_task( - self._async_save_tts_audio(cache_key, filename, data) - ) + # Create file infos + filename = f"{cache_key}.{extension}".lower() + # Validate filename + if not _RE_VOICE_FILE.match(filename): + raise HomeAssistantError( + f"TTS filename '{filename}' from {engine} is invalid!" + ) + + # Save to memory + if extension == "mp3": + data = self.write_tags( + filename, data, provider, message, language, options + ) + self._async_store_to_memcache(cache_key, filename, data) + + if cache: + self.hass.async_create_task( + self._async_save_tts_audio(cache_key, filename, data) + ) + + return filename + + audio_task = self.hass.async_create_task(get_tts_data()) + + if expected_extension is None: + return await audio_task + + def handle_error(_future: asyncio.Future) -> None: + """Handle error.""" + if audio_task.exception(): + self.mem_cache.pop(cache_key, None) + + audio_task.add_done_callback(handle_error) + + filename = f"{cache_key}.{expected_extension}".lower() + self.mem_cache[cache_key] = { + "filename": filename, + "voice": b"", + "pending": audio_task, + } return filename async def _async_save_tts_audio( @@ -601,7 +638,11 @@ class SpeechManager: self, cache_key: str, filename: str, data: bytes ) -> None: """Store data to memcache and set timer to remove it.""" - self.mem_cache[cache_key] = {"filename": filename, "voice": data} + self.mem_cache[cache_key] = { + "filename": filename, + "voice": data, + "pending": None, + } @callback def async_remove_from_mem() -> None: @@ -628,7 +669,11 @@ class SpeechManager: await self._async_file_to_mem(cache_key) content, _ = mimetypes.guess_type(filename) - return content, self.mem_cache[cache_key]["voice"] + cached = self.mem_cache[cache_key] + if pending := cached.get("pending"): + await pending + cached = self.mem_cache[cache_key] + return content, cached["voice"] @staticmethod def write_tags( diff --git a/tests/components/cloud/test_tts.py b/tests/components/cloud/test_tts.py index a17b0ae2f08..4d2ac35d56d 100644 --- a/tests/components/cloud/test_tts.py +++ b/tests/components/cloud/test_tts.py @@ -58,17 +58,17 @@ async def test_prefs_default_voice( ) assert provider_pref.default_language == "en-US" - assert provider_pref.default_options 
== {"gender": "female"} + assert provider_pref.default_options == {"gender": "female", "audio_output": "mp3"} assert provider_conf.default_language == "fr-FR" - assert provider_conf.default_options == {"gender": "female"} + assert provider_conf.default_options == {"gender": "female", "audio_output": "mp3"} await cloud_prefs.async_update(tts_default_voice=("nl-NL", "male")) await hass.async_block_till_done() assert provider_pref.default_language == "nl-NL" - assert provider_pref.default_options == {"gender": "male"} + assert provider_pref.default_options == {"gender": "male", "audio_output": "mp3"} assert provider_conf.default_language == "fr-FR" - assert provider_conf.default_options == {"gender": "female"} + assert provider_conf.default_options == {"gender": "female", "audio_output": "mp3"} async def test_provider_properties(cloud_with_prefs) -> None: @@ -76,7 +76,7 @@ async def test_provider_properties(cloud_with_prefs) -> None: provider = await tts.async_get_engine( Mock(data={const.DOMAIN: cloud_with_prefs}), None, {} ) - assert provider.supported_options == ["gender"] + assert provider.supported_options == ["gender", "audio_output"] assert "nl-NL" in provider.supported_languages @@ -85,5 +85,5 @@ async def test_get_tts_audio(cloud_with_prefs) -> None: provider = await tts.async_get_engine( Mock(data={const.DOMAIN: cloud_with_prefs}), None, {} ) - assert provider.supported_options == ["gender"] + assert provider.supported_options == ["gender", "audio_output"] assert "nl-NL" in provider.supported_languages diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 694c9ff676c..b6004c13d46 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -1,4 +1,5 @@ """The tests for the TTS component.""" +import asyncio from http import HTTPStatus from typing import Any from unittest.mock import patch @@ -996,3 +997,73 @@ async def test_support_options(hass: HomeAssistant, setup_tts) -> None: await tts.async_support_options(hass, "test", "en", {"invalid_option": "yo"}) is False ) + + +async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None: + """Test async fetching of data.""" + tts_audio = asyncio.Future() + + class ProviderWithAsyncFetching(MockProvider): + """Provider that supports audio output option.""" + + @property + def supported_options(self) -> list[str]: + """Return list of supported options like voice, emotions.""" + return [tts.ATTR_AUDIO_OUTPUT] + + @property + def default_options(self) -> dict[str, str]: + """Return a dict including the default options.""" + return {tts.ATTR_AUDIO_OUTPUT: "mp3"} + + async def async_get_tts_audio( + self, message: str, language: str, options: dict[str, Any] | None = None + ) -> tts.TtsAudioType: + return ("mp3", await tts_audio) + + mock_integration(hass, MockModule(domain="test")) + mock_platform(hass, "test.tts", MockTTS(ProviderWithAsyncFetching)) + assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}}) + + # Test async_get_media_source_audio + media_source_id = tts.generate_media_source_id( + hass, "test message", "test", "en", None, None + ) + + task = hass.async_create_task( + tts.async_get_media_source_audio(hass, media_source_id) + ) + task2 = hass.async_create_task( + tts.async_get_media_source_audio(hass, media_source_id) + ) + + url = await get_media_source_url(hass, media_source_id) + client = await hass_client() + client_get_task = hass.async_create_task(client.get(url)) + + # Make sure that tasks are waiting for our future to 
resolve
+    done, pending = await asyncio.wait((task, task2, client_get_task), timeout=0.1)
+    assert len(done) == 0
+    assert len(pending) == 3
+
+    tts_audio.set_result(b"test")
+
+    assert await task == ("mp3", b"test")
+    assert await task2 == ("mp3", b"test")
+
+    req = await client_get_task
+    assert req.status == HTTPStatus.OK
+    assert await req.read() == b"test"
+
+    # Test error is not cached
+    media_source_id = tts.generate_media_source_id(
+        hass, "test message 2", "test", "en", None, None
+    )
+    tts_audio = asyncio.Future()
+    tts_audio.set_exception(HomeAssistantError("test error"))
+    with pytest.raises(HomeAssistantError):
+        assert await tts.async_get_media_source_audio(hass, media_source_id)
+
+    tts_audio = asyncio.Future()
+    tts_audio.set_result(b"test 2")
+    assert await tts.async_get_media_source_audio(hass, media_source_id) == ("mp3", b"test 2")
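
Illustration (not part of the patch): the change above lets a TTS request return a cache entry whose audio bytes are still being generated, with readers awaiting a "pending" task before touching "voice", and failed generations being dropped from the cache. A minimal, framework-free sketch of that pattern follows; the names PendingCache and fake_tts are invented for this example and do not exist in Home Assistant, where the real logic lives in SpeechManager.

import asyncio


class PendingCache:
    """Cache whose entries may still be generating in the background (sketch)."""

    def __init__(self) -> None:
        self._entries: dict[str, dict] = {}

    def start(self, key: str, filename: str, coro) -> str:
        """Store a placeholder entry and generate the audio in a background task."""
        task = asyncio.create_task(coro)

        def _finish(done_task: asyncio.Task) -> None:
            if done_task.exception():
                # Mirror the patch: a failed generation is never cached.
                self._entries.pop(key, None)
            else:
                self._entries[key] = {
                    "filename": filename,
                    "voice": done_task.result(),
                    "pending": None,
                }

        task.add_done_callback(_finish)
        self._entries[key] = {"filename": filename, "voice": b"", "pending": task}
        return filename

    async def get(self, key: str) -> bytes:
        """Return the audio bytes, waiting for a pending generation if needed."""
        entry = self._entries[key]
        if pending := entry.get("pending"):
            await pending  # re-raises if the background generation failed
            entry = self._entries[key]
        return entry["voice"]


async def fake_tts() -> bytes:
    """Stand-in for a provider's async_get_tts_audio call."""
    await asyncio.sleep(0.1)
    return b"mp3 bytes"


async def main() -> None:
    cache = PendingCache()
    # The filename (and thus the URL) is available immediately; audio resolves later.
    cache.start("cache-key", "cache-key.mp3", fake_tts())
    print(await cache.get("cache-key"))  # b"mp3 bytes"


if __name__ == "__main__":
    asyncio.run(main())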