mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Use HassKey in tts (#126327)
* Use HassKey in tts * Also migrate DATA_TTS_MANAGER
This commit is contained in:
parent
32f02aa3c6
commit
d40464e5d3
@ -62,6 +62,7 @@ from .const import (
|
||||
DEFAULT_CACHE_DIR,
|
||||
DEFAULT_TIME_MEMORY,
|
||||
DOMAIN,
|
||||
DOMAIN_DATA,
|
||||
TtsAudioType,
|
||||
)
|
||||
from .helper import get_engine_instance
|
||||
@ -137,19 +138,16 @@ def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||
|
||||
Returns None if no engines found.
|
||||
"""
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
default_entity_id: str | None = None
|
||||
|
||||
for entity in component.entities:
|
||||
for entity in hass.data[DOMAIN_DATA].entities:
|
||||
if entity.platform and entity.platform.platform_name == "cloud":
|
||||
return entity.entity_id
|
||||
|
||||
if default_entity_id is None:
|
||||
default_entity_id = entity.entity_id
|
||||
|
||||
return default_entity_id or next(iter(manager.providers), None)
|
||||
return default_entity_id or next(iter(hass.data[DATA_TTS_MANAGER].providers), None)
|
||||
|
||||
|
||||
@callback
|
||||
@ -158,11 +156,11 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
|
||||
|
||||
Returns None if no engines found or invalid engine passed in.
|
||||
"""
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
if engine is not None:
|
||||
if not component.get_entity(engine) and engine not in manager.providers:
|
||||
if (
|
||||
not hass.data[DOMAIN_DATA].get_entity(engine)
|
||||
and engine not in hass.data[DATA_TTS_MANAGER].providers
|
||||
):
|
||||
return None
|
||||
return engine
|
||||
|
||||
@ -179,10 +177,8 @@ async def async_support_options(
|
||||
if (engine_instance := get_engine_instance(hass, engine)) is None:
|
||||
raise HomeAssistantError(f"Provider {engine} not found")
|
||||
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
try:
|
||||
manager.process_options(engine_instance, language, options)
|
||||
hass.data[DATA_TTS_MANAGER].process_options(engine_instance, language, options)
|
||||
except HomeAssistantError:
|
||||
return False
|
||||
|
||||
@ -194,8 +190,7 @@ async def async_get_media_source_audio(
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
"""Get TTS audio as extension, data."""
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
return await manager.async_get_tts_audio(
|
||||
return await hass.data[DATA_TTS_MANAGER].async_get_tts_audio(
|
||||
**media_source_id_to_kwargs(media_source_id),
|
||||
)
|
||||
|
||||
@ -205,14 +200,11 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
|
||||
"""Return a set with the union of languages supported by tts engines."""
|
||||
languages = set()
|
||||
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
for entity in component.entities:
|
||||
for entity in hass.data[DOMAIN_DATA].entities:
|
||||
for language_tag in entity.supported_languages:
|
||||
languages.add(language_tag)
|
||||
|
||||
for tts_engine in manager.providers.values():
|
||||
for tts_engine in hass.data[DATA_TTS_MANAGER].providers.values():
|
||||
for language_tag in tts_engine.supported_languages:
|
||||
languages.add(language_tag)
|
||||
|
||||
@ -325,7 +317,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
return False
|
||||
|
||||
hass.data[DATA_TTS_MANAGER] = tts
|
||||
component = hass.data[DOMAIN] = EntityComponent[TextToSpeechEntity](
|
||||
component = hass.data[DOMAIN_DATA] = EntityComponent[TextToSpeechEntity](
|
||||
_LOGGER, DOMAIN, hass
|
||||
)
|
||||
|
||||
@ -373,14 +365,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
return await hass.data[DOMAIN_DATA].async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
||||
return await hass.data[DOMAIN_DATA].async_unload_entry(entry)
|
||||
|
||||
|
||||
CACHED_PROPERTIES_WITH_ATTR_ = {
|
||||
@ -1105,16 +1095,13 @@ def websocket_list_engines(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""List text to speech engines and, optionally, if they support a given language."""
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
country = msg.get("country")
|
||||
language = msg.get("language")
|
||||
providers = []
|
||||
provider_info: dict[str, Any]
|
||||
entity_domains: set[str] = set()
|
||||
|
||||
for entity in component.entities:
|
||||
for entity in hass.data[DOMAIN_DATA].entities:
|
||||
provider_info = {
|
||||
"engine_id": entity.entity_id,
|
||||
"supported_languages": entity.supported_languages,
|
||||
@ -1126,7 +1113,7 @@ def websocket_list_engines(
|
||||
providers.append(provider_info)
|
||||
if entity.platform:
|
||||
entity_domains.add(entity.platform.platform_name)
|
||||
for engine_id, provider in manager.providers.items():
|
||||
for engine_id, provider in hass.data[DATA_TTS_MANAGER].providers.items():
|
||||
provider_info = {
|
||||
"engine_id": engine_id,
|
||||
"name": provider.name,
|
||||
@ -1156,17 +1143,19 @@ def websocket_get_engine(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""Get text to speech engine info."""
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
engine_id = msg["engine_id"]
|
||||
provider_info: dict[str, Any]
|
||||
|
||||
provider: TextToSpeechEntity | Provider | None = next(
|
||||
(entity for entity in component.entities if entity.entity_id == engine_id), None
|
||||
(
|
||||
entity
|
||||
for entity in hass.data[DOMAIN_DATA].entities
|
||||
if entity.entity_id == engine_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not provider:
|
||||
provider = manager.providers.get(engine_id)
|
||||
provider = hass.data[DATA_TTS_MANAGER].providers.get(engine_id)
|
||||
|
||||
if not provider:
|
||||
connection.send_error(
|
||||
|
@ -1,5 +1,16 @@
|
||||
"""Text-to-speech constants."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from . import SpeechManager, TextToSpeechEntity
|
||||
|
||||
ATTR_CACHE = "cache"
|
||||
ATTR_LANGUAGE = "language"
|
||||
ATTR_MESSAGE = "message"
|
||||
@ -15,7 +26,8 @@ DEFAULT_CACHE_DIR = "tts"
|
||||
DEFAULT_TIME_MEMORY = 300
|
||||
|
||||
DOMAIN = "tts"
|
||||
DOMAIN_DATA: HassKey[EntityComponent[TextToSpeechEntity]] = HassKey(DOMAIN)
|
||||
|
||||
DATA_TTS_MANAGER = "tts_manager"
|
||||
DATA_TTS_MANAGER: HassKey[SpeechManager] = HassKey("tts_manager")
|
||||
|
||||
type TtsAudioType = tuple[str | None, bytes | None]
|
||||
|
@ -5,12 +5,11 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from .const import DATA_TTS_MANAGER, DOMAIN
|
||||
from .const import DATA_TTS_MANAGER, DOMAIN_DATA
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import SpeechManager, TextToSpeechEntity
|
||||
from . import TextToSpeechEntity
|
||||
from .legacy import Provider
|
||||
|
||||
|
||||
@ -18,10 +17,7 @@ def get_engine_instance(
|
||||
hass: HomeAssistant, engine: str
|
||||
) -> TextToSpeechEntity | Provider | None:
|
||||
"""Get engine instance."""
|
||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
||||
|
||||
if entity := component.get_entity(engine):
|
||||
if entity := hass.data[DOMAIN_DATA].get_entity(engine):
|
||||
return entity
|
||||
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
return manager.providers.get(engine)
|
||||
return hass.data[DATA_TTS_MANAGER].providers.get(engine)
|
||||
|
@ -57,9 +57,6 @@ from .const import (
|
||||
from .media_source import generate_media_source_id
|
||||
from .models import Voice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import SpeechManager
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONF_SERVICE_NAME = "service_name"
|
||||
@ -105,8 +102,6 @@ async def async_setup_legacy(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> list[Coroutine[Any, Any, None]]:
|
||||
"""Set up legacy text-to-speech providers."""
|
||||
tts: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
# Load service descriptions from tts/services.yaml
|
||||
services_yaml = Path(__file__).parent / "services.yaml"
|
||||
services_dict = await hass.async_add_executor_job(
|
||||
@ -147,7 +142,9 @@ async def async_setup_legacy(
|
||||
_LOGGER.error("Error setting up platform: %s", p_type)
|
||||
return
|
||||
|
||||
tts.async_register_legacy_engine(p_type, provider, p_config)
|
||||
hass.data[DATA_TTS_MANAGER].async_register_legacy_engine(
|
||||
p_type, provider, p_config
|
||||
)
|
||||
except Exception:
|
||||
_LOGGER.exception("Error setting up platform: %s", p_type)
|
||||
return
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import mimetypes
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
from typing import TypedDict
|
||||
|
||||
from yarl import URL
|
||||
|
||||
@ -18,14 +18,10 @@ from homeassistant.components.media_source import (
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from .const import DATA_TTS_MANAGER, DOMAIN
|
||||
from .const import DATA_TTS_MANAGER, DOMAIN, DOMAIN_DATA
|
||||
from .helper import get_engine_instance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import SpeechManager, TextToSpeechEntity
|
||||
|
||||
|
||||
async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource:
|
||||
"""Set up tts media source."""
|
||||
@ -44,8 +40,6 @@ def generate_media_source_id(
|
||||
"""Generate a media source ID for text-to-speech."""
|
||||
from . import async_resolve_engine # pylint: disable=import-outside-toplevel
|
||||
|
||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
if (engine := async_resolve_engine(hass, engine)) is None:
|
||||
raise HomeAssistantError("Invalid TTS provider selected")
|
||||
|
||||
@ -53,7 +47,7 @@ def generate_media_source_id(
|
||||
# We raise above if the engine is not resolved, so engine_instance can't be None
|
||||
assert engine_instance is not None
|
||||
|
||||
manager.process_options(engine_instance, language, options)
|
||||
hass.data[DATA_TTS_MANAGER].process_options(engine_instance, language, options)
|
||||
params = {
|
||||
"message": message,
|
||||
}
|
||||
@ -113,10 +107,8 @@ class TTSMediaSource(MediaSource):
|
||||
|
||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||
"""Resolve media to a url."""
|
||||
manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
|
||||
|
||||
try:
|
||||
url = await manager.async_get_url_path(
|
||||
url = await self.hass.data[DATA_TTS_MANAGER].async_get_url_path(
|
||||
**media_source_id_to_kwargs(item.identifier)
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
@ -136,10 +128,12 @@ class TTSMediaSource(MediaSource):
|
||||
return self._engine_item(engine, params)
|
||||
|
||||
# Root. List providers.
|
||||
manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
|
||||
component: EntityComponent[TextToSpeechEntity] = self.hass.data[DOMAIN]
|
||||
children = [self._engine_item(engine) for engine in manager.providers] + [
|
||||
self._engine_item(entity.entity_id) for entity in component.entities
|
||||
children = [
|
||||
self._engine_item(engine)
|
||||
for engine in self.hass.data[DATA_TTS_MANAGER].providers
|
||||
] + [
|
||||
self._engine_item(entity.entity_id)
|
||||
for entity in self.hass.data[DOMAIN_DATA].entities
|
||||
]
|
||||
return BrowseMediaSource(
|
||||
domain=DOMAIN,
|
||||
|
Loading…
x
Reference in New Issue
Block a user