From d40464e5d305f214aa6dbec912bcec8b8f8ac4bf Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Sat, 21 Sep 2024 13:14:27 +0200 Subject: [PATCH] Use HassKey in tts (#126327) * Use HassKey in tts * Also migrate DATA_TTS_MANAGER --- homeassistant/components/tts/__init__.py | 57 ++++++++------------ homeassistant/components/tts/const.py | 14 ++++- homeassistant/components/tts/helper.py | 12 ++--- homeassistant/components/tts/legacy.py | 9 ++-- homeassistant/components/tts/media_source.py | 26 ++++----- 5 files changed, 53 insertions(+), 65 deletions(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 9e3d9f65a76..5ecbe15601d 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -62,6 +62,7 @@ from .const import ( DEFAULT_CACHE_DIR, DEFAULT_TIME_MEMORY, DOMAIN, + DOMAIN_DATA, TtsAudioType, ) from .helper import get_engine_instance @@ -137,19 +138,16 @@ def async_default_engine(hass: HomeAssistant) -> str | None: Returns None if no engines found. """ - component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN] - manager: SpeechManager = hass.data[DATA_TTS_MANAGER] - default_entity_id: str | None = None - for entity in component.entities: + for entity in hass.data[DOMAIN_DATA].entities: if entity.platform and entity.platform.platform_name == "cloud": return entity.entity_id if default_entity_id is None: default_entity_id = entity.entity_id - return default_entity_id or next(iter(manager.providers), None) + return default_entity_id or next(iter(hass.data[DATA_TTS_MANAGER].providers), None) @callback @@ -158,11 +156,11 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None: Returns None if no engines found or invalid engine passed in. """ - component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN] - manager: SpeechManager = hass.data[DATA_TTS_MANAGER] - if engine is not None: - if not component.get_entity(engine) and engine not in manager.providers: + if ( + not hass.data[DOMAIN_DATA].get_entity(engine) + and engine not in hass.data[DATA_TTS_MANAGER].providers + ): return None return engine @@ -179,10 +177,8 @@ async def async_support_options( if (engine_instance := get_engine_instance(hass, engine)) is None: raise HomeAssistantError(f"Provider {engine} not found") - manager: SpeechManager = hass.data[DATA_TTS_MANAGER] - try: - manager.process_options(engine_instance, language, options) + hass.data[DATA_TTS_MANAGER].process_options(engine_instance, language, options) except HomeAssistantError: return False @@ -194,8 +190,7 @@ async def async_get_media_source_audio( media_source_id: str, ) -> tuple[str, bytes]: """Get TTS audio as extension, data.""" - manager: SpeechManager = hass.data[DATA_TTS_MANAGER] - return await manager.async_get_tts_audio( + return await hass.data[DATA_TTS_MANAGER].async_get_tts_audio( **media_source_id_to_kwargs(media_source_id), ) @@ -205,14 +200,11 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]: """Return a set with the union of languages supported by tts engines.""" languages = set() - component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN] - manager: SpeechManager = hass.data[DATA_TTS_MANAGER] - - for entity in component.entities: + for entity in hass.data[DOMAIN_DATA].entities: for language_tag in entity.supported_languages: languages.add(language_tag) - for tts_engine in manager.providers.values(): + for tts_engine in hass.data[DATA_TTS_MANAGER].providers.values(): for language_tag in tts_engine.supported_languages: languages.add(language_tag) @@ -325,7 +317,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return False hass.data[DATA_TTS_MANAGER] = tts - component = hass.data[DOMAIN] = EntityComponent[TextToSpeechEntity]( + component = hass.data[DOMAIN_DATA] = EntityComponent[TextToSpeechEntity]( _LOGGER, DOMAIN, hass ) @@ -373,14 +365,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up a config entry.""" - component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN] - return await component.async_setup_entry(entry) + return await hass.data[DOMAIN_DATA].async_setup_entry(entry) async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" - component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN] - return await component.async_unload_entry(entry) + return await hass.data[DOMAIN_DATA].async_unload_entry(entry) CACHED_PROPERTIES_WITH_ATTR_ = { @@ -1105,16 +1095,13 @@ def websocket_list_engines( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """List text to speech engines and, optionally, if they support a given language.""" - component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN] - manager: SpeechManager = hass.data[DATA_TTS_MANAGER] - country = msg.get("country") language = msg.get("language") providers = [] provider_info: dict[str, Any] entity_domains: set[str] = set() - for entity in component.entities: + for entity in hass.data[DOMAIN_DATA].entities: provider_info = { "engine_id": entity.entity_id, "supported_languages": entity.supported_languages, @@ -1126,7 +1113,7 @@ def websocket_list_engines( providers.append(provider_info) if entity.platform: entity_domains.add(entity.platform.platform_name) - for engine_id, provider in manager.providers.items(): + for engine_id, provider in hass.data[DATA_TTS_MANAGER].providers.items(): provider_info = { "engine_id": engine_id, "name": provider.name, @@ -1156,17 +1143,19 @@ def websocket_get_engine( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: """Get text to speech engine info.""" - component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN] - manager: SpeechManager = hass.data[DATA_TTS_MANAGER] - engine_id = msg["engine_id"] provider_info: dict[str, Any] provider: TextToSpeechEntity | Provider | None = next( - (entity for entity in component.entities if entity.entity_id == engine_id), None + ( + entity + for entity in hass.data[DOMAIN_DATA].entities + if entity.entity_id == engine_id + ), + None, ) if not provider: - provider = manager.providers.get(engine_id) + provider = hass.data[DATA_TTS_MANAGER].providers.get(engine_id) if not provider: connection.send_error( diff --git a/homeassistant/components/tts/const.py b/homeassistant/components/tts/const.py index ab22a44cab6..b465dfb15dd 100644 --- a/homeassistant/components/tts/const.py +++ b/homeassistant/components/tts/const.py @@ -1,5 +1,16 @@ """Text-to-speech constants.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from homeassistant.util.hass_dict import HassKey + +if TYPE_CHECKING: + from homeassistant.helpers.entity_component import EntityComponent + + from . import SpeechManager, TextToSpeechEntity + ATTR_CACHE = "cache" ATTR_LANGUAGE = "language" ATTR_MESSAGE = "message" @@ -15,7 +26,8 @@ DEFAULT_CACHE_DIR = "tts" DEFAULT_TIME_MEMORY = 300 DOMAIN = "tts" +DOMAIN_DATA: HassKey[EntityComponent[TextToSpeechEntity]] = HassKey(DOMAIN) -DATA_TTS_MANAGER = "tts_manager" +DATA_TTS_MANAGER: HassKey[SpeechManager] = HassKey("tts_manager") type TtsAudioType = tuple[str | None, bytes | None] diff --git a/homeassistant/components/tts/helper.py b/homeassistant/components/tts/helper.py index 4b5ef168550..41b938f7e0b 100644 --- a/homeassistant/components/tts/helper.py +++ b/homeassistant/components/tts/helper.py @@ -5,12 +5,11 @@ from __future__ import annotations from typing import TYPE_CHECKING from homeassistant.core import HomeAssistant -from homeassistant.helpers.entity_component import EntityComponent -from .const import DATA_TTS_MANAGER, DOMAIN +from .const import DATA_TTS_MANAGER, DOMAIN_DATA if TYPE_CHECKING: - from . import SpeechManager, TextToSpeechEntity + from . import TextToSpeechEntity from .legacy import Provider @@ -18,10 +17,7 @@ def get_engine_instance( hass: HomeAssistant, engine: str ) -> TextToSpeechEntity | Provider | None: """Get engine instance.""" - component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN] - - if entity := component.get_entity(engine): + if entity := hass.data[DOMAIN_DATA].get_entity(engine): return entity - manager: SpeechManager = hass.data[DATA_TTS_MANAGER] - return manager.providers.get(engine) + return hass.data[DATA_TTS_MANAGER].providers.get(engine) diff --git a/homeassistant/components/tts/legacy.py b/homeassistant/components/tts/legacy.py index e36a1227603..54ea89cb674 100644 --- a/homeassistant/components/tts/legacy.py +++ b/homeassistant/components/tts/legacy.py @@ -57,9 +57,6 @@ from .const import ( from .media_source import generate_media_source_id from .models import Voice -if TYPE_CHECKING: - from . import SpeechManager - _LOGGER = logging.getLogger(__name__) CONF_SERVICE_NAME = "service_name" @@ -105,8 +102,6 @@ async def async_setup_legacy( hass: HomeAssistant, config: ConfigType ) -> list[Coroutine[Any, Any, None]]: """Set up legacy text-to-speech providers.""" - tts: SpeechManager = hass.data[DATA_TTS_MANAGER] - # Load service descriptions from tts/services.yaml services_yaml = Path(__file__).parent / "services.yaml" services_dict = await hass.async_add_executor_job( @@ -147,7 +142,9 @@ async def async_setup_legacy( _LOGGER.error("Error setting up platform: %s", p_type) return - tts.async_register_legacy_engine(p_type, provider, p_config) + hass.data[DATA_TTS_MANAGER].async_register_legacy_engine( + p_type, provider, p_config + ) except Exception: _LOGGER.exception("Error setting up platform: %s", p_type) return diff --git a/homeassistant/components/tts/media_source.py b/homeassistant/components/tts/media_source.py index a907fc485c9..13c37681259 100644 --- a/homeassistant/components/tts/media_source.py +++ b/homeassistant/components/tts/media_source.py @@ -3,7 +3,7 @@ from __future__ import annotations import mimetypes -from typing import TYPE_CHECKING, TypedDict +from typing import TypedDict from yarl import URL @@ -18,14 +18,10 @@ from homeassistant.components.media_source import ( ) from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers.entity_component import EntityComponent -from .const import DATA_TTS_MANAGER, DOMAIN +from .const import DATA_TTS_MANAGER, DOMAIN, DOMAIN_DATA from .helper import get_engine_instance -if TYPE_CHECKING: - from . import SpeechManager, TextToSpeechEntity - async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource: """Set up tts media source.""" @@ -44,8 +40,6 @@ def generate_media_source_id( """Generate a media source ID for text-to-speech.""" from . import async_resolve_engine # pylint: disable=import-outside-toplevel - manager: SpeechManager = hass.data[DATA_TTS_MANAGER] - if (engine := async_resolve_engine(hass, engine)) is None: raise HomeAssistantError("Invalid TTS provider selected") @@ -53,7 +47,7 @@ def generate_media_source_id( # We raise above if the engine is not resolved, so engine_instance can't be None assert engine_instance is not None - manager.process_options(engine_instance, language, options) + hass.data[DATA_TTS_MANAGER].process_options(engine_instance, language, options) params = { "message": message, } @@ -113,10 +107,8 @@ class TTSMediaSource(MediaSource): async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: """Resolve media to a url.""" - manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER] - try: - url = await manager.async_get_url_path( + url = await self.hass.data[DATA_TTS_MANAGER].async_get_url_path( **media_source_id_to_kwargs(item.identifier) ) except HomeAssistantError as err: @@ -136,10 +128,12 @@ class TTSMediaSource(MediaSource): return self._engine_item(engine, params) # Root. List providers. - manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER] - component: EntityComponent[TextToSpeechEntity] = self.hass.data[DOMAIN] - children = [self._engine_item(engine) for engine in manager.providers] + [ - self._engine_item(entity.entity_id) for entity in component.entities + children = [ + self._engine_item(engine) + for engine in self.hass.data[DATA_TTS_MANAGER].providers + ] + [ + self._engine_item(entity.entity_id) + for entity in self.hass.data[DOMAIN_DATA].entities ] return BrowseMediaSource( domain=DOMAIN,