Use HassKey in tts (#126327)

* Use HassKey in tts

* Also migrate DATA_TTS_MANAGER
This commit is contained in:
epenet 2024-09-21 13:14:27 +02:00 committed by GitHub
parent 32f02aa3c6
commit d40464e5d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 53 additions and 65 deletions

View File

@ -62,6 +62,7 @@ from .const import (
DEFAULT_CACHE_DIR,
DEFAULT_TIME_MEMORY,
DOMAIN,
DOMAIN_DATA,
TtsAudioType,
)
from .helper import get_engine_instance
@ -137,19 +138,16 @@ def async_default_engine(hass: HomeAssistant) -> str | None:
Returns None if no engines found.
"""
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
default_entity_id: str | None = None
for entity in component.entities:
for entity in hass.data[DOMAIN_DATA].entities:
if entity.platform and entity.platform.platform_name == "cloud":
return entity.entity_id
if default_entity_id is None:
default_entity_id = entity.entity_id
return default_entity_id or next(iter(manager.providers), None)
return default_entity_id or next(iter(hass.data[DATA_TTS_MANAGER].providers), None)
@callback
@ -158,11 +156,11 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
Returns None if no engines found or invalid engine passed in.
"""
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
if engine is not None:
if not component.get_entity(engine) and engine not in manager.providers:
if (
not hass.data[DOMAIN_DATA].get_entity(engine)
and engine not in hass.data[DATA_TTS_MANAGER].providers
):
return None
return engine
@ -179,10 +177,8 @@ async def async_support_options(
if (engine_instance := get_engine_instance(hass, engine)) is None:
raise HomeAssistantError(f"Provider {engine} not found")
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
try:
manager.process_options(engine_instance, language, options)
hass.data[DATA_TTS_MANAGER].process_options(engine_instance, language, options)
except HomeAssistantError:
return False
@ -194,8 +190,7 @@ async def async_get_media_source_audio(
media_source_id: str,
) -> tuple[str, bytes]:
"""Get TTS audio as extension, data."""
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
return await manager.async_get_tts_audio(
return await hass.data[DATA_TTS_MANAGER].async_get_tts_audio(
**media_source_id_to_kwargs(media_source_id),
)
@ -205,14 +200,11 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
"""Return a set with the union of languages supported by tts engines."""
languages = set()
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
for entity in component.entities:
for entity in hass.data[DOMAIN_DATA].entities:
for language_tag in entity.supported_languages:
languages.add(language_tag)
for tts_engine in manager.providers.values():
for tts_engine in hass.data[DATA_TTS_MANAGER].providers.values():
for language_tag in tts_engine.supported_languages:
languages.add(language_tag)
@ -325,7 +317,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return False
hass.data[DATA_TTS_MANAGER] = tts
component = hass.data[DOMAIN] = EntityComponent[TextToSpeechEntity](
component = hass.data[DOMAIN_DATA] = EntityComponent[TextToSpeechEntity](
_LOGGER, DOMAIN, hass
)
@ -373,14 +365,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
return await component.async_setup_entry(entry)
return await hass.data[DOMAIN_DATA].async_setup_entry(entry)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
return await component.async_unload_entry(entry)
return await hass.data[DOMAIN_DATA].async_unload_entry(entry)
CACHED_PROPERTIES_WITH_ATTR_ = {
@ -1105,16 +1095,13 @@ def websocket_list_engines(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List text to speech engines and, optionally, if they support a given language."""
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
country = msg.get("country")
language = msg.get("language")
providers = []
provider_info: dict[str, Any]
entity_domains: set[str] = set()
for entity in component.entities:
for entity in hass.data[DOMAIN_DATA].entities:
provider_info = {
"engine_id": entity.entity_id,
"supported_languages": entity.supported_languages,
@ -1126,7 +1113,7 @@ def websocket_list_engines(
providers.append(provider_info)
if entity.platform:
entity_domains.add(entity.platform.platform_name)
for engine_id, provider in manager.providers.items():
for engine_id, provider in hass.data[DATA_TTS_MANAGER].providers.items():
provider_info = {
"engine_id": engine_id,
"name": provider.name,
@ -1156,17 +1143,19 @@ def websocket_get_engine(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Get text to speech engine info."""
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
engine_id = msg["engine_id"]
provider_info: dict[str, Any]
provider: TextToSpeechEntity | Provider | None = next(
(entity for entity in component.entities if entity.entity_id == engine_id), None
(
entity
for entity in hass.data[DOMAIN_DATA].entities
if entity.entity_id == engine_id
),
None,
)
if not provider:
provider = manager.providers.get(engine_id)
provider = hass.data[DATA_TTS_MANAGER].providers.get(engine_id)
if not provider:
connection.send_error(

View File

@ -1,5 +1,16 @@
"""Text-to-speech constants."""
from __future__ import annotations
from typing import TYPE_CHECKING
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from homeassistant.helpers.entity_component import EntityComponent
from . import SpeechManager, TextToSpeechEntity
ATTR_CACHE = "cache"
ATTR_LANGUAGE = "language"
ATTR_MESSAGE = "message"
@ -15,7 +26,8 @@ DEFAULT_CACHE_DIR = "tts"
DEFAULT_TIME_MEMORY = 300
DOMAIN = "tts"
DOMAIN_DATA: HassKey[EntityComponent[TextToSpeechEntity]] = HassKey(DOMAIN)
DATA_TTS_MANAGER = "tts_manager"
DATA_TTS_MANAGER: HassKey[SpeechManager] = HassKey("tts_manager")
type TtsAudioType = tuple[str | None, bytes | None]

View File

@ -5,12 +5,11 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_component import EntityComponent
from .const import DATA_TTS_MANAGER, DOMAIN
from .const import DATA_TTS_MANAGER, DOMAIN_DATA
if TYPE_CHECKING:
from . import SpeechManager, TextToSpeechEntity
from . import TextToSpeechEntity
from .legacy import Provider
@ -18,10 +17,7 @@ def get_engine_instance(
hass: HomeAssistant, engine: str
) -> TextToSpeechEntity | Provider | None:
"""Get engine instance."""
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
if entity := component.get_entity(engine):
if entity := hass.data[DOMAIN_DATA].get_entity(engine):
return entity
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
return manager.providers.get(engine)
return hass.data[DATA_TTS_MANAGER].providers.get(engine)

View File

@ -57,9 +57,6 @@ from .const import (
from .media_source import generate_media_source_id
from .models import Voice
if TYPE_CHECKING:
from . import SpeechManager
_LOGGER = logging.getLogger(__name__)
CONF_SERVICE_NAME = "service_name"
@ -105,8 +102,6 @@ async def async_setup_legacy(
hass: HomeAssistant, config: ConfigType
) -> list[Coroutine[Any, Any, None]]:
"""Set up legacy text-to-speech providers."""
tts: SpeechManager = hass.data[DATA_TTS_MANAGER]
# Load service descriptions from tts/services.yaml
services_yaml = Path(__file__).parent / "services.yaml"
services_dict = await hass.async_add_executor_job(
@ -147,7 +142,9 @@ async def async_setup_legacy(
_LOGGER.error("Error setting up platform: %s", p_type)
return
tts.async_register_legacy_engine(p_type, provider, p_config)
hass.data[DATA_TTS_MANAGER].async_register_legacy_engine(
p_type, provider, p_config
)
except Exception:
_LOGGER.exception("Error setting up platform: %s", p_type)
return

View File

@ -3,7 +3,7 @@
from __future__ import annotations
import mimetypes
from typing import TYPE_CHECKING, TypedDict
from typing import TypedDict
from yarl import URL
@ -18,14 +18,10 @@ from homeassistant.components.media_source import (
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_component import EntityComponent
from .const import DATA_TTS_MANAGER, DOMAIN
from .const import DATA_TTS_MANAGER, DOMAIN, DOMAIN_DATA
from .helper import get_engine_instance
if TYPE_CHECKING:
from . import SpeechManager, TextToSpeechEntity
async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource:
"""Set up tts media source."""
@ -44,8 +40,6 @@ def generate_media_source_id(
"""Generate a media source ID for text-to-speech."""
from . import async_resolve_engine # pylint: disable=import-outside-toplevel
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
if (engine := async_resolve_engine(hass, engine)) is None:
raise HomeAssistantError("Invalid TTS provider selected")
@ -53,7 +47,7 @@ def generate_media_source_id(
# We raise above if the engine is not resolved, so engine_instance can't be None
assert engine_instance is not None
manager.process_options(engine_instance, language, options)
hass.data[DATA_TTS_MANAGER].process_options(engine_instance, language, options)
params = {
"message": message,
}
@ -113,10 +107,8 @@ class TTSMediaSource(MediaSource):
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve media to a url."""
manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
try:
url = await manager.async_get_url_path(
url = await self.hass.data[DATA_TTS_MANAGER].async_get_url_path(
**media_source_id_to_kwargs(item.identifier)
)
except HomeAssistantError as err:
@ -136,10 +128,12 @@ class TTSMediaSource(MediaSource):
return self._engine_item(engine, params)
# Root. List providers.
manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
component: EntityComponent[TextToSpeechEntity] = self.hass.data[DOMAIN]
children = [self._engine_item(engine) for engine in manager.providers] + [
self._engine_item(entity.entity_id) for entity in component.entities
children = [
self._engine_item(engine)
for engine in self.hass.data[DATA_TTS_MANAGER].providers
] + [
self._engine_item(entity.entity_id)
for entity in self.hass.data[DOMAIN_DATA].entities
]
return BrowseMediaSource(
domain=DOMAIN,