mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Allow TTS requests to resolve in the background (#90944)
This commit is contained in:
parent
59a02cd08c
commit
86e9f6643f
@ -4,11 +4,16 @@ from hass_nabucasa import Cloud
|
|||||||
from hass_nabucasa.voice import MAP_VOICE, AudioOutput, VoiceError
|
from hass_nabucasa.voice import MAP_VOICE, AudioOutput, VoiceError
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.tts import CONF_LANG, PLATFORM_SCHEMA, Provider
|
from homeassistant.components.tts import (
|
||||||
|
ATTR_AUDIO_OUTPUT,
|
||||||
|
CONF_LANG,
|
||||||
|
PLATFORM_SCHEMA,
|
||||||
|
Provider,
|
||||||
|
)
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
|
|
||||||
CONF_GENDER = "gender"
|
ATTR_GENDER = "gender"
|
||||||
|
|
||||||
SUPPORT_LANGUAGES = list({key[0] for key in MAP_VOICE})
|
SUPPORT_LANGUAGES = list({key[0] for key in MAP_VOICE})
|
||||||
|
|
||||||
@ -18,8 +23,8 @@ def validate_lang(value):
|
|||||||
if (lang := value.get(CONF_LANG)) is None:
|
if (lang := value.get(CONF_LANG)) is None:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
if (gender := value.get(CONF_GENDER)) is None:
|
if (gender := value.get(ATTR_GENDER)) is None:
|
||||||
gender = value[CONF_GENDER] = next(
|
gender = value[ATTR_GENDER] = next(
|
||||||
(chk_gender for chk_lang, chk_gender in MAP_VOICE if chk_lang == lang), None
|
(chk_gender for chk_lang, chk_gender in MAP_VOICE if chk_lang == lang), None
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -33,7 +38,7 @@ PLATFORM_SCHEMA = vol.All(
|
|||||||
PLATFORM_SCHEMA.extend(
|
PLATFORM_SCHEMA.extend(
|
||||||
{
|
{
|
||||||
vol.Optional(CONF_LANG): str,
|
vol.Optional(CONF_LANG): str,
|
||||||
vol.Optional(CONF_GENDER): str,
|
vol.Optional(ATTR_GENDER): str,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
validate_lang,
|
validate_lang,
|
||||||
@ -49,7 +54,7 @@ async def async_get_engine(hass, config, discovery_info=None):
|
|||||||
gender = None
|
gender = None
|
||||||
else:
|
else:
|
||||||
language = config[CONF_LANG]
|
language = config[CONF_LANG]
|
||||||
gender = config[CONF_GENDER]
|
gender = config[ATTR_GENDER]
|
||||||
|
|
||||||
return CloudProvider(cloud, language, gender)
|
return CloudProvider(cloud, language, gender)
|
||||||
|
|
||||||
@ -87,12 +92,15 @@ class CloudProvider(Provider):
|
|||||||
@property
|
@property
|
||||||
def supported_options(self):
|
def supported_options(self):
|
||||||
"""Return list of supported options like voice, emotion."""
|
"""Return list of supported options like voice, emotion."""
|
||||||
return [CONF_GENDER]
|
return [ATTR_GENDER, ATTR_AUDIO_OUTPUT]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_options(self):
|
def default_options(self):
|
||||||
"""Return a dict include default options."""
|
"""Return a dict include default options."""
|
||||||
return {CONF_GENDER: self._gender}
|
return {
|
||||||
|
ATTR_GENDER: self._gender,
|
||||||
|
ATTR_AUDIO_OUTPUT: AudioOutput.MP3,
|
||||||
|
}
|
||||||
|
|
||||||
async def async_get_tts_audio(self, message, language, options=None):
|
async def async_get_tts_audio(self, message, language, options=None):
|
||||||
"""Load TTS from NabuCasa Cloud."""
|
"""Load TTS from NabuCasa Cloud."""
|
||||||
@ -101,10 +109,10 @@ class CloudProvider(Provider):
|
|||||||
data = await self.cloud.voice.process_tts(
|
data = await self.cloud.voice.process_tts(
|
||||||
message,
|
message,
|
||||||
language,
|
language,
|
||||||
gender=options[CONF_GENDER],
|
gender=options[ATTR_GENDER],
|
||||||
output=AudioOutput.MP3,
|
output=options[ATTR_AUDIO_OUTPUT],
|
||||||
)
|
)
|
||||||
except VoiceError:
|
except VoiceError:
|
||||||
return (None, None)
|
return (None, None)
|
||||||
|
|
||||||
return ("mp3", data)
|
return (str(options[ATTR_AUDIO_OUTPUT]), data)
|
||||||
|
@ -59,6 +59,7 @@ ATTR_LANGUAGE = "language"
|
|||||||
ATTR_MESSAGE = "message"
|
ATTR_MESSAGE = "message"
|
||||||
ATTR_OPTIONS = "options"
|
ATTR_OPTIONS = "options"
|
||||||
ATTR_PLATFORM = "platform"
|
ATTR_PLATFORM = "platform"
|
||||||
|
ATTR_AUDIO_OUTPUT = "audio_output"
|
||||||
|
|
||||||
BASE_URL_KEY = "tts_base_url"
|
BASE_URL_KEY = "tts_base_url"
|
||||||
|
|
||||||
@ -134,6 +135,7 @@ class TTSCache(TypedDict):
|
|||||||
|
|
||||||
filename: str
|
filename: str
|
||||||
voice: bytes
|
voice: bytes
|
||||||
|
pending: asyncio.Task | None
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -495,8 +497,11 @@ class SpeechManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:]
|
extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:]
|
||||||
data = self.mem_cache[cache_key]["voice"]
|
cached = self.mem_cache[cache_key]
|
||||||
return extension, data
|
if pending := cached.get("pending"):
|
||||||
|
await pending
|
||||||
|
cached = self.mem_cache[cache_key]
|
||||||
|
return extension, cached["voice"]
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _generate_cache_key(
|
def _generate_cache_key(
|
||||||
@ -527,30 +532,62 @@ class SpeechManager:
|
|||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
provider = self.providers[engine]
|
provider = self.providers[engine]
|
||||||
extension, data = await provider.async_get_tts_audio(message, language, options)
|
|
||||||
|
|
||||||
if data is None or extension is None:
|
if options is not None and ATTR_AUDIO_OUTPUT in options:
|
||||||
raise HomeAssistantError(f"No TTS from {engine} for '{message}'")
|
expected_extension = options[ATTR_AUDIO_OUTPUT]
|
||||||
|
else:
|
||||||
|
expected_extension = None
|
||||||
|
|
||||||
# Create file infos
|
async def get_tts_data() -> str:
|
||||||
filename = f"{cache_key}.{extension}".lower()
|
"""Handle data available."""
|
||||||
|
extension, data = await provider.async_get_tts_audio(
|
||||||
# Validate filename
|
message, language, options
|
||||||
if not _RE_VOICE_FILE.match(filename):
|
|
||||||
raise HomeAssistantError(
|
|
||||||
f"TTS filename '{filename}' from {engine} is invalid!"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save to memory
|
if data is None or extension is None:
|
||||||
if extension == "mp3":
|
raise HomeAssistantError(f"No TTS from {engine} for '{message}'")
|
||||||
data = self.write_tags(filename, data, provider, message, language, options)
|
|
||||||
self._async_store_to_memcache(cache_key, filename, data)
|
|
||||||
|
|
||||||
if cache:
|
# Create file infos
|
||||||
self.hass.async_create_task(
|
filename = f"{cache_key}.{extension}".lower()
|
||||||
self._async_save_tts_audio(cache_key, filename, data)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Validate filename
|
||||||
|
if not _RE_VOICE_FILE.match(filename):
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"TTS filename '{filename}' from {engine} is invalid!"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save to memory
|
||||||
|
if extension == "mp3":
|
||||||
|
data = self.write_tags(
|
||||||
|
filename, data, provider, message, language, options
|
||||||
|
)
|
||||||
|
self._async_store_to_memcache(cache_key, filename, data)
|
||||||
|
|
||||||
|
if cache:
|
||||||
|
self.hass.async_create_task(
|
||||||
|
self._async_save_tts_audio(cache_key, filename, data)
|
||||||
|
)
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|
||||||
|
audio_task = self.hass.async_create_task(get_tts_data())
|
||||||
|
|
||||||
|
if expected_extension is None:
|
||||||
|
return await audio_task
|
||||||
|
|
||||||
|
def handle_error(_future: asyncio.Future) -> None:
    """Drop the placeholder cache entry when the background fetch did not succeed.

    Registered as a done-callback on ``audio_task``. The surrounding code
    stores a cache entry with empty voice data and ``pending`` set to that
    task; if the task fails or is cancelled the entry must be removed so a
    later request starts a fresh fetch instead of awaiting a dead task.
    """
    # Task.exception() raises CancelledError for a cancelled task, so the
    # cancellation check has to come first.
    if audio_task.cancelled() or audio_task.exception() is not None:
        self.mem_cache.pop(cache_key, None)
|
||||||
|
|
||||||
|
audio_task.add_done_callback(handle_error)
|
||||||
|
|
||||||
|
filename = f"{cache_key}.{expected_extension}".lower()
|
||||||
|
self.mem_cache[cache_key] = {
|
||||||
|
"filename": filename,
|
||||||
|
"voice": b"",
|
||||||
|
"pending": audio_task,
|
||||||
|
}
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
async def _async_save_tts_audio(
|
async def _async_save_tts_audio(
|
||||||
@ -601,7 +638,11 @@ class SpeechManager:
|
|||||||
self, cache_key: str, filename: str, data: bytes
|
self, cache_key: str, filename: str, data: bytes
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Store data to memcache and set timer to remove it."""
|
"""Store data to memcache and set timer to remove it."""
|
||||||
self.mem_cache[cache_key] = {"filename": filename, "voice": data}
|
self.mem_cache[cache_key] = {
|
||||||
|
"filename": filename,
|
||||||
|
"voice": data,
|
||||||
|
"pending": None,
|
||||||
|
}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove_from_mem() -> None:
|
def async_remove_from_mem() -> None:
|
||||||
@ -628,7 +669,11 @@ class SpeechManager:
|
|||||||
await self._async_file_to_mem(cache_key)
|
await self._async_file_to_mem(cache_key)
|
||||||
|
|
||||||
content, _ = mimetypes.guess_type(filename)
|
content, _ = mimetypes.guess_type(filename)
|
||||||
return content, self.mem_cache[cache_key]["voice"]
|
cached = self.mem_cache[cache_key]
|
||||||
|
if pending := cached.get("pending"):
|
||||||
|
await pending
|
||||||
|
cached = self.mem_cache[cache_key]
|
||||||
|
return content, cached["voice"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def write_tags(
|
def write_tags(
|
||||||
|
@ -58,17 +58,17 @@ async def test_prefs_default_voice(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert provider_pref.default_language == "en-US"
|
assert provider_pref.default_language == "en-US"
|
||||||
assert provider_pref.default_options == {"gender": "female"}
|
assert provider_pref.default_options == {"gender": "female", "audio_output": "mp3"}
|
||||||
assert provider_conf.default_language == "fr-FR"
|
assert provider_conf.default_language == "fr-FR"
|
||||||
assert provider_conf.default_options == {"gender": "female"}
|
assert provider_conf.default_options == {"gender": "female", "audio_output": "mp3"}
|
||||||
|
|
||||||
await cloud_prefs.async_update(tts_default_voice=("nl-NL", "male"))
|
await cloud_prefs.async_update(tts_default_voice=("nl-NL", "male"))
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert provider_pref.default_language == "nl-NL"
|
assert provider_pref.default_language == "nl-NL"
|
||||||
assert provider_pref.default_options == {"gender": "male"}
|
assert provider_pref.default_options == {"gender": "male", "audio_output": "mp3"}
|
||||||
assert provider_conf.default_language == "fr-FR"
|
assert provider_conf.default_language == "fr-FR"
|
||||||
assert provider_conf.default_options == {"gender": "female"}
|
assert provider_conf.default_options == {"gender": "female", "audio_output": "mp3"}
|
||||||
|
|
||||||
|
|
||||||
async def test_provider_properties(cloud_with_prefs) -> None:
|
async def test_provider_properties(cloud_with_prefs) -> None:
|
||||||
@ -76,7 +76,7 @@ async def test_provider_properties(cloud_with_prefs) -> None:
|
|||||||
provider = await tts.async_get_engine(
|
provider = await tts.async_get_engine(
|
||||||
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
||||||
)
|
)
|
||||||
assert provider.supported_options == ["gender"]
|
assert provider.supported_options == ["gender", "audio_output"]
|
||||||
assert "nl-NL" in provider.supported_languages
|
assert "nl-NL" in provider.supported_languages
|
||||||
|
|
||||||
|
|
||||||
@ -85,5 +85,5 @@ async def test_get_tts_audio(cloud_with_prefs) -> None:
|
|||||||
provider = await tts.async_get_engine(
|
provider = await tts.async_get_engine(
|
||||||
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
||||||
)
|
)
|
||||||
assert provider.supported_options == ["gender"]
|
assert provider.supported_options == ["gender", "audio_output"]
|
||||||
assert "nl-NL" in provider.supported_languages
|
assert "nl-NL" in provider.supported_languages
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""The tests for the TTS component."""
|
"""The tests for the TTS component."""
|
||||||
|
import asyncio
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
@ -996,3 +997,73 @@ async def test_support_options(hass: HomeAssistant, setup_tts) -> None:
|
|||||||
await tts.async_support_options(hass, "test", "en", {"invalid_option": "yo"})
|
await tts.async_support_options(hass, "test", "en", {"invalid_option": "yo"})
|
||||||
is False
|
is False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_fetching_in_async(hass: HomeAssistant, hass_client) -> None:
    """Test that TTS audio can resolve in the background.

    Several consumers (two direct media-source fetches and one HTTP request)
    must all block on the same pending provider call and receive its result.
    A failed provider call must not be cached.
    """
    tts_audio: asyncio.Future = asyncio.Future()

    class ProviderWithAsyncFetching(MockProvider):
        """Provider that supports audio output option."""

        @property
        def supported_options(self) -> list[str]:
            """Return list of supported options like voice, emotions."""
            return [tts.ATTR_AUDIO_OUTPUT]

        @property
        def default_options(self) -> dict[str, str]:
            """Return a dict including the default options."""
            return {tts.ATTR_AUDIO_OUTPUT: "mp3"}

        async def async_get_tts_audio(
            self, message: str, language: str, options: dict[str, Any] | None = None
        ) -> tts.TtsAudioType:
            """Return audio once the test resolves the current future."""
            # ``tts_audio`` is looked up at call time, so rebinding it in the
            # test body below changes what later calls wait on.
            return ("mp3", await tts_audio)

    mock_integration(hass, MockModule(domain="test"))
    mock_platform(hass, "test.tts", MockTTS(ProviderWithAsyncFetching))
    assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})

    # Test async_get_media_source_audio
    media_source_id = tts.generate_media_source_id(
        hass, "test message", "test", "en", None, None
    )

    task = hass.async_create_task(
        tts.async_get_media_source_audio(hass, media_source_id)
    )
    task2 = hass.async_create_task(
        tts.async_get_media_source_audio(hass, media_source_id)
    )

    url = await get_media_source_url(hass, media_source_id)
    client = await hass_client()
    client_get_task = hass.async_create_task(client.get(url))

    # Make sure that tasks are waiting for our future to resolve
    done, pending = await asyncio.wait((task, task2, client_get_task), timeout=0.1)
    assert len(done) == 0
    assert len(pending) == 3

    tts_audio.set_result(b"test")

    assert await task == ("mp3", b"test")
    assert await task2 == ("mp3", b"test")

    req = await client_get_task
    assert req.status == HTTPStatus.OK
    assert await req.read() == b"test"

    # Test error is not cached
    media_source_id = tts.generate_media_source_id(
        hass, "test message 2", "test", "en", None, None
    )
    tts_audio = asyncio.Future()
    tts_audio.set_exception(HomeAssistantError("test error"))
    with pytest.raises(HomeAssistantError):
        # The call raises before producing a value, so there is nothing to assert on.
        await tts.async_get_media_source_audio(hass, media_source_id)

    tts_audio = asyncio.Future()
    tts_audio.set_result(b"test 2")
    # The original compared the result without ``assert``, so this check was a no-op.
    assert await tts.async_get_media_source_audio(hass, media_source_id) == (
        "mp3",
        b"test 2",
    )
|
||||||
|
Loading…
x
Reference in New Issue
Block a user