From 697e7b3a201fab27eb3bdda1ed81c698193ac58f Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 25 Sep 2022 20:53:20 -0400 Subject: [PATCH] TTS Cleanup and expose get audio (#79065) --- .../components/media_source/__init__.py | 3 +- homeassistant/components/tts/__init__.py | 168 ++++++++++++------ homeassistant/components/tts/media_source.py | 94 ++++++++-- tests/components/tts/test_init.py | 70 +++++++- 4 files changed, 250 insertions(+), 85 deletions(-) diff --git a/homeassistant/components/media_source/__init__.py b/homeassistant/components/media_source/__init__.py index a882798687e..47a5d7f6969 100644 --- a/homeassistant/components/media_source/__init__.py +++ b/homeassistant/components/media_source/__init__.py @@ -34,7 +34,7 @@ from .const import ( URI_SCHEME_REGEX, ) from .error import MediaSourceError, Unresolvable -from .models import BrowseMediaSource, MediaSourceItem, PlayMedia +from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia __all__ = [ "DOMAIN", @@ -46,6 +46,7 @@ __all__ = [ "PlayMedia", "MediaSourceItem", "Unresolvable", + "MediaSource", "MediaSourceError", "MEDIA_CLASS_MAP", "MEDIA_MIME_TYPES", diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 7def2c84bc0..757c33e2653 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -11,7 +11,7 @@ import mimetypes import os from pathlib import Path import re -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast from aiohttp import web import mutagen @@ -28,7 +28,6 @@ from homeassistant.components.media_player import ( SERVICE_PLAY_MEDIA, MediaType, ) -from homeassistant.components.media_source import generate_media_source_id from homeassistant.const import ( ATTR_ENTITY_ID, CONF_DESCRIPTION, @@ -48,6 +47,7 @@ from homeassistant.util.network import normalize_url from homeassistant.util.yaml import load_yaml from .const import DOMAIN +from .media_source import generate_media_source_id, media_source_id_to_kwargs _LOGGER = logging.getLogger(__name__) @@ -74,9 +74,6 @@ DEFAULT_CACHE = True DEFAULT_CACHE_DIR = "tts" DEFAULT_TIME_MEMORY = 300 -MEM_CACHE_FILENAME = "filename" -MEM_CACHE_VOICE = "voice" - SERVICE_CLEAR_CACHE = "clear_cache" SERVICE_SAY = "say" @@ -131,6 +128,24 @@ SCHEMA_SERVICE_SAY = vol.Schema( SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({}) +class TTSCache(TypedDict): + """Cached TTS file.""" + + filename: str + voice: bytes + + +async def async_get_media_source_audio( + hass: HomeAssistant, + media_source_id: str, +) -> tuple[str, bytes]: + """Get TTS audio as extension, data.""" + manager: SpeechManager = hass.data[DOMAIN] + return await manager.async_get_tts_audio( + **media_source_id_to_kwargs(media_source_id), + ) + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up TTS.""" tts = SpeechManager(hass) @@ -197,21 +212,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_say_handle(service: ServiceCall) -> None: """Service handle for say.""" entity_ids = service.data[ATTR_ENTITY_ID] - message = service.data[ATTR_MESSAGE] - cache = service.data.get(ATTR_CACHE) - language = service.data.get(ATTR_LANGUAGE) - options = service.data.get(ATTR_OPTIONS) - - tts.process_options(p_type, language, options) - params = { - "message": message, - } - if cache is not None: - params["cache"] = "true" if cache else "false" - if language is not None: - params["language"] = language - if options is not None: - params.update(options) await hass.services.async_call( DOMAIN_MP, @@ -219,8 +219,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: { ATTR_ENTITY_ID: entity_ids, ATTR_MEDIA_CONTENT_ID: generate_media_source_id( - DOMAIN, - str(yarl.URL.build(path=p_type, query=params)), + hass, + engine=p_type, + message=service.data[ATTR_MESSAGE], + language=service.data.get(ATTR_LANGUAGE), + options=service.data.get(ATTR_OPTIONS), + cache=service.data.get(ATTR_CACHE), ), ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC, ATTR_MEDIA_ANNOUNCE: True, @@ -296,7 +300,7 @@ class SpeechManager: self.time_memory = DEFAULT_TIME_MEMORY self.base_url: str | None = None self.file_cache: dict[str, str] = {} - self.mem_cache: dict[str, dict[str, str | bytes]] = {} + self.mem_cache: dict[str, TTSCache] = {} async def async_init_cache( self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None @@ -380,10 +384,11 @@ class SpeechManager: options = options or provider.default_options if options is not None: + supported_options = provider.supported_options or [] invalid_opts = [ opt_name for opt_name in options.keys() - if opt_name not in (provider.supported_options or []) + if opt_name not in supported_options ] if invalid_opts: raise HomeAssistantError(f"Invalid options found: {invalid_opts}") @@ -403,25 +408,25 @@ class SpeechManager: This method is a coroutine. """ language, options = self.process_options(engine, language, options) - options_key = _hash_options(options) if options else "-" - msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest() + cache_key = self._generate_cache_key(message, language, options, engine) use_cache = cache if cache is not None else self.use_cache - key = KEY_PATTERN.format( - msg_hash, language.replace("_", "-"), options_key, engine - ).lower() - # Is speech already in memory - if key in self.mem_cache: - filename = cast(str, self.mem_cache[key][MEM_CACHE_FILENAME]) + if cache_key in self.mem_cache: + filename = self.mem_cache[cache_key]["filename"] # Is file store in file cache - elif use_cache and key in self.file_cache: - filename = self.file_cache[key] - self.hass.async_create_task(self.async_file_to_mem(key)) + elif use_cache and cache_key in self.file_cache: + filename = self.file_cache[cache_key] + self.hass.async_create_task(self._async_file_to_mem(cache_key)) # Load speech from provider into memory else: - filename = await self.async_get_tts_audio( - engine, key, message, use_cache, language, options + filename = await self._async_get_tts_audio( + engine, + cache_key, + message, + use_cache, + language, + options, ) return f"/api/tts_proxy/{filename}" @@ -429,13 +434,54 @@ class SpeechManager: async def async_get_tts_audio( self, engine: str, - key: str, + message: str, + cache: bool | None = None, + language: str | None = None, + options: dict | None = None, + ) -> tuple[str, bytes]: + """Fetch TTS audio.""" + language, options = self.process_options(engine, language, options) + cache_key = self._generate_cache_key(message, language, options, engine) + use_cache = cache if cache is not None else self.use_cache + + # If we have the file, load it into memory if necessary + if cache_key not in self.mem_cache: + if use_cache and cache_key in self.file_cache: + await self._async_file_to_mem(cache_key) + else: + await self._async_get_tts_audio( + engine, cache_key, message, use_cache, language, options + ) + + extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:] + data = self.mem_cache[cache_key]["voice"] + return extension, data + + @callback + def _generate_cache_key( + self, + message: str, + language: str, + options: dict | None, + engine: str, + ) -> str: + """Generate a cache key for a message.""" + options_key = _hash_options(options) if options else "-" + msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest() + return KEY_PATTERN.format( + msg_hash, language.replace("_", "-"), options_key, engine + ).lower() + + async def _async_get_tts_audio( + self, + engine: str, + cache_key: str, message: str, cache: bool, language: str, options: dict | None, ) -> str: - """Receive TTS and store for view in cache. + """Receive TTS, store for view in cache and return filename. This method is a coroutine. """ @@ -446,7 +492,7 @@ class SpeechManager: raise HomeAssistantError(f"No TTS from {engine} for '{message}'") # Create file infos - filename = f"{key}.{extension}".lower() + filename = f"{cache_key}.{extension}".lower() # Validate filename if not _RE_VOICE_FILE.match(filename): @@ -456,14 +502,18 @@ class SpeechManager: # Save to memory data = self.write_tags(filename, data, provider, message, language, options) - self._async_store_to_memcache(key, filename, data) + self._async_store_to_memcache(cache_key, filename, data) if cache: - self.hass.async_create_task(self.async_save_tts_audio(key, filename, data)) + self.hass.async_create_task( + self._async_save_tts_audio(cache_key, filename, data) + ) return filename - async def async_save_tts_audio(self, key: str, filename: str, data: bytes) -> None: + async def _async_save_tts_audio( + self, cache_key: str, filename: str, data: bytes + ) -> None: """Store voice data to file and file_cache. This method is a coroutine. @@ -477,17 +527,17 @@ class SpeechManager: try: await self.hass.async_add_executor_job(save_speech) - self.file_cache[key] = filename + self.file_cache[cache_key] = filename except OSError as err: _LOGGER.error("Can't write %s: %s", filename, err) - async def async_file_to_mem(self, key: str) -> None: + async def _async_file_to_mem(self, cache_key: str) -> None: """Load voice from file cache into memory. This method is a coroutine. """ - if not (filename := self.file_cache.get(key)): - raise HomeAssistantError(f"Key {key} not in file cache!") + if not (filename := self.file_cache.get(cache_key)): + raise HomeAssistantError(f"Key {cache_key} not in file cache!") voice_file = os.path.join(self.cache_dir, filename) @@ -499,20 +549,22 @@ class SpeechManager: try: data = await self.hass.async_add_executor_job(load_speech) except OSError as err: - del self.file_cache[key] + del self.file_cache[cache_key] raise HomeAssistantError(f"Can't read {voice_file}") from err - self._async_store_to_memcache(key, filename, data) + self._async_store_to_memcache(cache_key, filename, data) @callback - def _async_store_to_memcache(self, key: str, filename: str, data: bytes) -> None: + def _async_store_to_memcache( + self, cache_key: str, filename: str, data: bytes + ) -> None: """Store data to memcache and set timer to remove it.""" - self.mem_cache[key] = {MEM_CACHE_FILENAME: filename, MEM_CACHE_VOICE: data} + self.mem_cache[cache_key] = {"filename": filename, "voice": data} @callback def async_remove_from_mem() -> None: """Cleanup memcache.""" - self.mem_cache.pop(key, None) + self.mem_cache.pop(cache_key, None) self.hass.loop.call_later(self.time_memory, async_remove_from_mem) @@ -524,17 +576,17 @@ class SpeechManager: if not (record := _RE_VOICE_FILE.match(filename.lower())): raise HomeAssistantError("Wrong tts file format!") - key = KEY_PATTERN.format( + cache_key = KEY_PATTERN.format( record.group(1), record.group(2), record.group(3), record.group(4) ) - if key not in self.mem_cache: - if key not in self.file_cache: - raise HomeAssistantError(f"{key} not in cache!") - await self.async_file_to_mem(key) + if cache_key not in self.mem_cache: + if cache_key not in self.file_cache: + raise HomeAssistantError(f"{cache_key} not in cache!") + await self._async_file_to_mem(cache_key) content, _ = mimetypes.guess_type(filename) - return content, cast(bytes, self.mem_cache[key][MEM_CACHE_VOICE]) + return content, self.mem_cache[cache_key]["voice"] @staticmethod def write_tags( diff --git a/homeassistant/components/tts/media_source.py b/homeassistant/components/tts/media_source.py index eda64c804b8..c197632c11e 100644 --- a/homeassistant/components/tts/media_source.py +++ b/homeassistant/components/tts/media_source.py @@ -2,17 +2,18 @@ from __future__ import annotations import mimetypes -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, TypedDict from yarl import URL from homeassistant.components.media_player import BrowseError, MediaClass -from homeassistant.components.media_source.error import Unresolvable -from homeassistant.components.media_source.models import ( +from homeassistant.components.media_source import ( BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia, + Unresolvable, + generate_media_source_id as ms_generate_media_source_id, ) from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError @@ -29,6 +30,75 @@ async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource: return TTSMediaSource(hass) +@callback +def generate_media_source_id( + hass: HomeAssistant, + message: str, + engine: str | None = None, + language: str | None = None, + options: dict | None = None, + cache: bool | None = None, +) -> str: + """Generate a media source ID for text-to-speech.""" + manager: SpeechManager = hass.data[DOMAIN] + + if engine is not None: + pass + elif not manager.providers: + raise HomeAssistantError("No TTS providers available") + elif "cloud" in manager.providers: + engine = "cloud" + else: + engine = next(iter(manager.providers)) + + manager.process_options(engine, language, options) + params = { + "message": message, + } + if cache is not None: + params["cache"] = "true" if cache else "false" + if language is not None: + params["language"] = language + if options is not None: + params.update(options) + + return ms_generate_media_source_id( + DOMAIN, + str(URL.build(path=engine, query=params)), + ) + + +class MediaSourceOptions(TypedDict): + """Media source options.""" + + engine: str + message: str + language: str | None + options: dict | None + cache: bool | None + + +@callback +def media_source_id_to_kwargs(media_source_id: str) -> MediaSourceOptions: + """Turn a media source ID into options.""" + parsed = URL(media_source_id) + if "message" not in parsed.query: + raise Unresolvable("No message specified.") + + options = dict(parsed.query) + kwargs: MediaSourceOptions = { + "engine": parsed.name, + "message": options.pop("message"), + "language": options.pop("language", None), + "options": options, + "cache": None, + } + if "cache" in options: + kwargs["cache"] = options.pop("cache") == "true" + + return kwargs + + class TTSMediaSource(MediaSource): """Provide text-to-speech providers as media sources.""" @@ -41,24 +111,12 @@ class TTSMediaSource(MediaSource): async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: """Resolve media to a url.""" - parsed = URL(item.identifier) - if "message" not in parsed.query: - raise Unresolvable("No message specified.") - - options = dict(parsed.query) - kwargs: dict[str, Any] = { - "engine": parsed.name, - "message": options.pop("message"), - "language": options.pop("language", None), - "options": options, - } - if "cache" in options: - kwargs["cache"] = options.pop("cache") == "true" - manager: SpeechManager = self.hass.data[DOMAIN] try: - url = await manager.async_get_url_path(**kwargs) + url = await manager.async_get_url_path( + **media_source_id_to_kwargs(item.identifier) + ) except HomeAssistantError as err: raise Unresolvable(str(err)) from err diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 61c5ab00180..f521cbda58d 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -49,13 +49,18 @@ async def internal_url_mock(hass): ) -async def test_setup_component_demo(hass): +@pytest.fixture +async def setup_tts(hass): + """Mock TTS.""" + with patch("homeassistant.components.demo.async_setup", return_value=True): + assert await async_setup_component( + hass, tts.DOMAIN, {"tts": {"platform": "demo"}} + ) + await hass.async_block_till_done() + + +async def test_setup_component_demo(hass, setup_tts): """Set up the demo platform with defaults.""" - config = {tts.DOMAIN: {"platform": "demo"}} - - with assert_setup_component(1, tts.DOMAIN): - assert await async_setup_component(hass, tts.DOMAIN, config) - assert hass.services.has_service(tts.DOMAIN, "demo_say") assert hass.services.has_service(tts.DOMAIN, "clear_cache") assert f"{tts.DOMAIN}.demo" in hass.config.components @@ -421,12 +426,14 @@ async def test_setup_component_and_test_service_with_receive_voice( with assert_setup_component(1, tts.DOMAIN): assert await async_setup_component(hass, tts.DOMAIN, config) + message = "There is someone at the door." + await hass.services.async_call( tts.DOMAIN, "demo_say", { "entity_id": "media_player.something", - tts.ATTR_MESSAGE: "There is someone at the door.", + tts.ATTR_MESSAGE: message, }, blocking=True, ) @@ -440,13 +447,19 @@ async def test_setup_component_and_test_service_with_receive_voice( "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_demo.mp3", demo_data, demo_provider, - "There is someone at the door.", + message, "en", None, ) assert req.status == HTTPStatus.OK assert await req.read() == demo_data + extension, data = await tts.async_get_media_source_audio( + hass, calls[0].data[ATTR_MEDIA_CONTENT_ID] + ) + assert extension == "mp3" + assert demo_data == data + async def test_setup_component_and_test_service_with_receive_voice_german( hass, demo_provider, hass_client @@ -736,3 +749,44 @@ def test_invalid_base_url(value): """Test we catch bad base urls.""" with pytest.raises(vol.Invalid): tts.valid_base_url(value) + + +@pytest.mark.parametrize( + "engine,language,options,cache,result_engine,result_query", + ( + (None, None, None, None, "demo", ""), + (None, "de", None, None, "demo", "language=de"), + (None, "de", {"voice": "henk"}, None, "demo", "language=de&voice=henk"), + (None, "de", None, True, "demo", "cache=true&language=de"), + ), +) +async def test_generate_media_source_id( + hass, setup_tts, engine, language, options, cache, result_engine, result_query +): + """Test generating a media source ID.""" + media_source_id = tts.generate_media_source_id( + hass, "msg", engine, language, options, cache + ) + + assert media_source_id.startswith("media-source://tts/") + _, _, engine_query = media_source_id.rpartition("/") + engine, _, query = engine_query.partition("?") + assert engine == result_engine + assert query.startswith("message=msg") + assert query[12:] == result_query + + +@pytest.mark.parametrize( + "engine,language,options", + ( + ("not-loaded-engine", None, None), + (None, "unsupported-language", None), + (None, None, {"option": "not-supported"}), + ), +) +async def test_generate_media_source_id_invalid_options( + hass, setup_tts, engine, language, options +): + """Test generating a media source ID.""" + with pytest.raises(HomeAssistantError): + tts.generate_media_source_id(hass, "msg", engine, language, options, None)