Allow TTS requests to resolve in the background (#90944)

This commit is contained in:
Paulus Schoutsen 2023-04-06 11:42:55 -04:00 committed by GitHub
parent 59a02cd08c
commit 86e9f6643f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 163 additions and 39 deletions

View File

@ -4,11 +4,16 @@ from hass_nabucasa import Cloud
from hass_nabucasa.voice import MAP_VOICE, AudioOutput, VoiceError from hass_nabucasa.voice import MAP_VOICE, AudioOutput, VoiceError
import voluptuous as vol import voluptuous as vol
from homeassistant.components.tts import CONF_LANG, PLATFORM_SCHEMA, Provider from homeassistant.components.tts import (
ATTR_AUDIO_OUTPUT,
CONF_LANG,
PLATFORM_SCHEMA,
Provider,
)
from .const import DOMAIN from .const import DOMAIN
CONF_GENDER = "gender" ATTR_GENDER = "gender"
SUPPORT_LANGUAGES = list({key[0] for key in MAP_VOICE}) SUPPORT_LANGUAGES = list({key[0] for key in MAP_VOICE})
@ -18,8 +23,8 @@ def validate_lang(value):
if (lang := value.get(CONF_LANG)) is None: if (lang := value.get(CONF_LANG)) is None:
return value return value
if (gender := value.get(CONF_GENDER)) is None: if (gender := value.get(ATTR_GENDER)) is None:
gender = value[CONF_GENDER] = next( gender = value[ATTR_GENDER] = next(
(chk_gender for chk_lang, chk_gender in MAP_VOICE if chk_lang == lang), None (chk_gender for chk_lang, chk_gender in MAP_VOICE if chk_lang == lang), None
) )
@ -33,7 +38,7 @@ PLATFORM_SCHEMA = vol.All(
PLATFORM_SCHEMA.extend( PLATFORM_SCHEMA.extend(
{ {
vol.Optional(CONF_LANG): str, vol.Optional(CONF_LANG): str,
vol.Optional(CONF_GENDER): str, vol.Optional(ATTR_GENDER): str,
} }
), ),
validate_lang, validate_lang,
@ -49,7 +54,7 @@ async def async_get_engine(hass, config, discovery_info=None):
gender = None gender = None
else: else:
language = config[CONF_LANG] language = config[CONF_LANG]
gender = config[CONF_GENDER] gender = config[ATTR_GENDER]
return CloudProvider(cloud, language, gender) return CloudProvider(cloud, language, gender)
@ -87,12 +92,15 @@ class CloudProvider(Provider):
@property @property
def supported_options(self): def supported_options(self):
"""Return list of supported options like voice, emotion.""" """Return list of supported options like voice, emotion."""
return [CONF_GENDER] return [ATTR_GENDER, ATTR_AUDIO_OUTPUT]
@property @property
def default_options(self): def default_options(self):
"""Return a dict include default options.""" """Return a dict include default options."""
return {CONF_GENDER: self._gender} return {
ATTR_GENDER: self._gender,
ATTR_AUDIO_OUTPUT: AudioOutput.MP3,
}
async def async_get_tts_audio(self, message, language, options=None): async def async_get_tts_audio(self, message, language, options=None):
"""Load TTS from NabuCasa Cloud.""" """Load TTS from NabuCasa Cloud."""
@ -101,10 +109,10 @@ class CloudProvider(Provider):
data = await self.cloud.voice.process_tts( data = await self.cloud.voice.process_tts(
message, message,
language, language,
gender=options[CONF_GENDER], gender=options[ATTR_GENDER],
output=AudioOutput.MP3, output=options[ATTR_AUDIO_OUTPUT],
) )
except VoiceError: except VoiceError:
return (None, None) return (None, None)
return ("mp3", data) return (str(options[ATTR_AUDIO_OUTPUT]), data)

View File

@ -59,6 +59,7 @@ ATTR_LANGUAGE = "language"
ATTR_MESSAGE = "message" ATTR_MESSAGE = "message"
ATTR_OPTIONS = "options" ATTR_OPTIONS = "options"
ATTR_PLATFORM = "platform" ATTR_PLATFORM = "platform"
ATTR_AUDIO_OUTPUT = "audio_output"
BASE_URL_KEY = "tts_base_url" BASE_URL_KEY = "tts_base_url"
@ -134,6 +135,7 @@ class TTSCache(TypedDict):
filename: str filename: str
voice: bytes voice: bytes
pending: asyncio.Task | None
@callback @callback
@ -495,8 +497,11 @@ class SpeechManager:
) )
extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:] extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:]
data = self.mem_cache[cache_key]["voice"] cached = self.mem_cache[cache_key]
return extension, data if pending := cached.get("pending"):
await pending
cached = self.mem_cache[cache_key]
return extension, cached["voice"]
@callback @callback
def _generate_cache_key( def _generate_cache_key(
@ -527,7 +532,17 @@ class SpeechManager:
This method is a coroutine. This method is a coroutine.
""" """
provider = self.providers[engine] provider = self.providers[engine]
extension, data = await provider.async_get_tts_audio(message, language, options)
if options is not None and ATTR_AUDIO_OUTPUT in options:
expected_extension = options[ATTR_AUDIO_OUTPUT]
else:
expected_extension = None
async def get_tts_data() -> str:
"""Handle data available."""
extension, data = await provider.async_get_tts_audio(
message, language, options
)
if data is None or extension is None: if data is None or extension is None:
raise HomeAssistantError(f"No TTS from {engine} for '{message}'") raise HomeAssistantError(f"No TTS from {engine} for '{message}'")
@ -543,7 +558,9 @@ class SpeechManager:
# Save to memory # Save to memory
if extension == "mp3": if extension == "mp3":
data = self.write_tags(filename, data, provider, message, language, options) data = self.write_tags(
filename, data, provider, message, language, options
)
self._async_store_to_memcache(cache_key, filename, data) self._async_store_to_memcache(cache_key, filename, data)
if cache: if cache:
@ -553,6 +570,26 @@ class SpeechManager:
return filename return filename
audio_task = self.hass.async_create_task(get_tts_data())
if expected_extension is None:
return await audio_task
def handle_error(_future: asyncio.Future) -> None:
"""Handle error."""
if audio_task.exception():
self.mem_cache.pop(cache_key, None)
audio_task.add_done_callback(handle_error)
filename = f"{cache_key}.{expected_extension}".lower()
self.mem_cache[cache_key] = {
"filename": filename,
"voice": b"",
"pending": audio_task,
}
return filename
async def _async_save_tts_audio( async def _async_save_tts_audio(
self, cache_key: str, filename: str, data: bytes self, cache_key: str, filename: str, data: bytes
) -> None: ) -> None:
@ -601,7 +638,11 @@ class SpeechManager:
self, cache_key: str, filename: str, data: bytes self, cache_key: str, filename: str, data: bytes
) -> None: ) -> None:
"""Store data to memcache and set timer to remove it.""" """Store data to memcache and set timer to remove it."""
self.mem_cache[cache_key] = {"filename": filename, "voice": data} self.mem_cache[cache_key] = {
"filename": filename,
"voice": data,
"pending": None,
}
@callback @callback
def async_remove_from_mem() -> None: def async_remove_from_mem() -> None:
@ -628,7 +669,11 @@ class SpeechManager:
await self._async_file_to_mem(cache_key) await self._async_file_to_mem(cache_key)
content, _ = mimetypes.guess_type(filename) content, _ = mimetypes.guess_type(filename)
return content, self.mem_cache[cache_key]["voice"] cached = self.mem_cache[cache_key]
if pending := cached.get("pending"):
await pending
cached = self.mem_cache[cache_key]
return content, cached["voice"]
@staticmethod @staticmethod
def write_tags( def write_tags(

View File

@ -58,17 +58,17 @@ async def test_prefs_default_voice(
) )
assert provider_pref.default_language == "en-US" assert provider_pref.default_language == "en-US"
assert provider_pref.default_options == {"gender": "female"} assert provider_pref.default_options == {"gender": "female", "audio_output": "mp3"}
assert provider_conf.default_language == "fr-FR" assert provider_conf.default_language == "fr-FR"
assert provider_conf.default_options == {"gender": "female"} assert provider_conf.default_options == {"gender": "female", "audio_output": "mp3"}
await cloud_prefs.async_update(tts_default_voice=("nl-NL", "male")) await cloud_prefs.async_update(tts_default_voice=("nl-NL", "male"))
await hass.async_block_till_done() await hass.async_block_till_done()
assert provider_pref.default_language == "nl-NL" assert provider_pref.default_language == "nl-NL"
assert provider_pref.default_options == {"gender": "male"} assert provider_pref.default_options == {"gender": "male", "audio_output": "mp3"}
assert provider_conf.default_language == "fr-FR" assert provider_conf.default_language == "fr-FR"
assert provider_conf.default_options == {"gender": "female"} assert provider_conf.default_options == {"gender": "female", "audio_output": "mp3"}
async def test_provider_properties(cloud_with_prefs) -> None: async def test_provider_properties(cloud_with_prefs) -> None:
@ -76,7 +76,7 @@ async def test_provider_properties(cloud_with_prefs) -> None:
provider = await tts.async_get_engine( provider = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {} Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
) )
assert provider.supported_options == ["gender"] assert provider.supported_options == ["gender", "audio_output"]
assert "nl-NL" in provider.supported_languages assert "nl-NL" in provider.supported_languages
@ -85,5 +85,5 @@ async def test_get_tts_audio(cloud_with_prefs) -> None:
provider = await tts.async_get_engine( provider = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {} Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
) )
assert provider.supported_options == ["gender"] assert provider.supported_options == ["gender", "audio_output"]
assert "nl-NL" in provider.supported_languages assert "nl-NL" in provider.supported_languages

View File

@ -1,4 +1,5 @@
"""The tests for the TTS component.""" """The tests for the TTS component."""
import asyncio
from http import HTTPStatus from http import HTTPStatus
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
@ -996,3 +997,73 @@ async def test_support_options(hass: HomeAssistant, setup_tts) -> None:
await tts.async_support_options(hass, "test", "en", {"invalid_option": "yo"}) await tts.async_support_options(hass, "test", "en", {"invalid_option": "yo"})
is False is False
) )
async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None:
    """Test async fetching of data.

    The provider below blocks on a future controlled by the test, so the
    test can check that concurrent readers wait for the audio, that they all
    receive it once the future resolves, and that a failed fetch is not
    cached.
    """
    tts_audio = asyncio.Future()

    class ProviderWithAsyncFetching(MockProvider):
        """Provider that supports audio output option."""

        @property
        def supported_options(self) -> list[str]:
            """Return list of supported options like voice, emotions."""
            return [tts.ATTR_AUDIO_OUTPUT]

        @property
        def default_options(self) -> dict[str, str]:
            """Return a dict including the default options."""
            return {tts.ATTR_AUDIO_OUTPUT: "mp3"}

        async def async_get_tts_audio(
            self, message: str, language: str, options: dict[str, Any] | None = None
        ) -> tts.TtsAudioType:
            """Return mp3 audio once the test-controlled future resolves."""
            # ``tts_audio`` is looked up in the enclosing test scope at call
            # time, so rebinding it later in the test changes what is awaited.
            return ("mp3", await tts_audio)

    mock_integration(hass, MockModule(domain="test"))
    mock_platform(hass, "test.tts", MockTTS(ProviderWithAsyncFetching))
    assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})

    # Test async_get_media_source_audio
    media_source_id = tts.generate_media_source_id(
        hass, "test message", "test", "en", None, None
    )

    task = hass.async_create_task(
        tts.async_get_media_source_audio(hass, media_source_id)
    )
    task2 = hass.async_create_task(
        tts.async_get_media_source_audio(hass, media_source_id)
    )

    url = await get_media_source_url(hass, media_source_id)
    client = await hass_client()
    client_get_task = hass.async_create_task(client.get(url))

    # Make sure that tasks are waiting for our future to resolve
    done, pending = await asyncio.wait((task, task2, client_get_task), timeout=0.1)
    assert len(done) == 0
    assert len(pending) == 3

    tts_audio.set_result(b"test")

    assert await task == ("mp3", b"test")
    assert await task2 == ("mp3", b"test")

    req = await client_get_task
    assert req.status == HTTPStatus.OK
    assert await req.read() == b"test"

    # Test error is not cached
    media_source_id = tts.generate_media_source_id(
        hass, "test message 2", "test", "en", None, None
    )
    tts_audio = asyncio.Future()
    tts_audio.set_exception(HomeAssistantError("test error"))
    with pytest.raises(HomeAssistantError):
        # The call itself is expected to raise; there is no value to assert on.
        await tts.async_get_media_source_audio(hass, media_source_id)

    tts_audio = asyncio.Future()
    tts_audio.set_result(b"test 2")
    # The comparison must be asserted; as a bare expression its result is
    # discarded and the check could never fail.
    assert await tts.async_get_media_source_audio(hass, media_source_id) == (
        "mp3",
        b"test 2",
    )