mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Use HassKey in tts (#126327)
* Use HassKey in tts * Also migrate DATA_TTS_MANAGER
This commit is contained in:
parent
32f02aa3c6
commit
d40464e5d3
@ -62,6 +62,7 @@ from .const import (
|
|||||||
DEFAULT_CACHE_DIR,
|
DEFAULT_CACHE_DIR,
|
||||||
DEFAULT_TIME_MEMORY,
|
DEFAULT_TIME_MEMORY,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
DOMAIN_DATA,
|
||||||
TtsAudioType,
|
TtsAudioType,
|
||||||
)
|
)
|
||||||
from .helper import get_engine_instance
|
from .helper import get_engine_instance
|
||||||
@ -137,19 +138,16 @@ def async_default_engine(hass: HomeAssistant) -> str | None:
|
|||||||
|
|
||||||
Returns None if no engines found.
|
Returns None if no engines found.
|
||||||
"""
|
"""
|
||||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
|
||||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
|
||||||
|
|
||||||
default_entity_id: str | None = None
|
default_entity_id: str | None = None
|
||||||
|
|
||||||
for entity in component.entities:
|
for entity in hass.data[DOMAIN_DATA].entities:
|
||||||
if entity.platform and entity.platform.platform_name == "cloud":
|
if entity.platform and entity.platform.platform_name == "cloud":
|
||||||
return entity.entity_id
|
return entity.entity_id
|
||||||
|
|
||||||
if default_entity_id is None:
|
if default_entity_id is None:
|
||||||
default_entity_id = entity.entity_id
|
default_entity_id = entity.entity_id
|
||||||
|
|
||||||
return default_entity_id or next(iter(manager.providers), None)
|
return default_entity_id or next(iter(hass.data[DATA_TTS_MANAGER].providers), None)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -158,11 +156,11 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
|
|||||||
|
|
||||||
Returns None if no engines found or invalid engine passed in.
|
Returns None if no engines found or invalid engine passed in.
|
||||||
"""
|
"""
|
||||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
|
||||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
|
||||||
|
|
||||||
if engine is not None:
|
if engine is not None:
|
||||||
if not component.get_entity(engine) and engine not in manager.providers:
|
if (
|
||||||
|
not hass.data[DOMAIN_DATA].get_entity(engine)
|
||||||
|
and engine not in hass.data[DATA_TTS_MANAGER].providers
|
||||||
|
):
|
||||||
return None
|
return None
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
@ -179,10 +177,8 @@ async def async_support_options(
|
|||||||
if (engine_instance := get_engine_instance(hass, engine)) is None:
|
if (engine_instance := get_engine_instance(hass, engine)) is None:
|
||||||
raise HomeAssistantError(f"Provider {engine} not found")
|
raise HomeAssistantError(f"Provider {engine} not found")
|
||||||
|
|
||||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
manager.process_options(engine_instance, language, options)
|
hass.data[DATA_TTS_MANAGER].process_options(engine_instance, language, options)
|
||||||
except HomeAssistantError:
|
except HomeAssistantError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -194,8 +190,7 @@ async def async_get_media_source_audio(
|
|||||||
media_source_id: str,
|
media_source_id: str,
|
||||||
) -> tuple[str, bytes]:
|
) -> tuple[str, bytes]:
|
||||||
"""Get TTS audio as extension, data."""
|
"""Get TTS audio as extension, data."""
|
||||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
return await hass.data[DATA_TTS_MANAGER].async_get_tts_audio(
|
||||||
return await manager.async_get_tts_audio(
|
|
||||||
**media_source_id_to_kwargs(media_source_id),
|
**media_source_id_to_kwargs(media_source_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -205,14 +200,11 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
|
|||||||
"""Return a set with the union of languages supported by tts engines."""
|
"""Return a set with the union of languages supported by tts engines."""
|
||||||
languages = set()
|
languages = set()
|
||||||
|
|
||||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
for entity in hass.data[DOMAIN_DATA].entities:
|
||||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
|
||||||
|
|
||||||
for entity in component.entities:
|
|
||||||
for language_tag in entity.supported_languages:
|
for language_tag in entity.supported_languages:
|
||||||
languages.add(language_tag)
|
languages.add(language_tag)
|
||||||
|
|
||||||
for tts_engine in manager.providers.values():
|
for tts_engine in hass.data[DATA_TTS_MANAGER].providers.values():
|
||||||
for language_tag in tts_engine.supported_languages:
|
for language_tag in tts_engine.supported_languages:
|
||||||
languages.add(language_tag)
|
languages.add(language_tag)
|
||||||
|
|
||||||
@ -325,7 +317,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
hass.data[DATA_TTS_MANAGER] = tts
|
hass.data[DATA_TTS_MANAGER] = tts
|
||||||
component = hass.data[DOMAIN] = EntityComponent[TextToSpeechEntity](
|
component = hass.data[DOMAIN_DATA] = EntityComponent[TextToSpeechEntity](
|
||||||
_LOGGER, DOMAIN, hass
|
_LOGGER, DOMAIN, hass
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -373,14 +365,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Set up a config entry."""
|
"""Set up a config entry."""
|
||||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
return await hass.data[DOMAIN_DATA].async_setup_entry(entry)
|
||||||
return await component.async_setup_entry(entry)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Unload a config entry."""
|
"""Unload a config entry."""
|
||||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
return await hass.data[DOMAIN_DATA].async_unload_entry(entry)
|
||||||
return await component.async_unload_entry(entry)
|
|
||||||
|
|
||||||
|
|
||||||
CACHED_PROPERTIES_WITH_ATTR_ = {
|
CACHED_PROPERTIES_WITH_ATTR_ = {
|
||||||
@ -1105,16 +1095,13 @@ def websocket_list_engines(
|
|||||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List text to speech engines and, optionally, if they support a given language."""
|
"""List text to speech engines and, optionally, if they support a given language."""
|
||||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
|
||||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
|
||||||
|
|
||||||
country = msg.get("country")
|
country = msg.get("country")
|
||||||
language = msg.get("language")
|
language = msg.get("language")
|
||||||
providers = []
|
providers = []
|
||||||
provider_info: dict[str, Any]
|
provider_info: dict[str, Any]
|
||||||
entity_domains: set[str] = set()
|
entity_domains: set[str] = set()
|
||||||
|
|
||||||
for entity in component.entities:
|
for entity in hass.data[DOMAIN_DATA].entities:
|
||||||
provider_info = {
|
provider_info = {
|
||||||
"engine_id": entity.entity_id,
|
"engine_id": entity.entity_id,
|
||||||
"supported_languages": entity.supported_languages,
|
"supported_languages": entity.supported_languages,
|
||||||
@ -1126,7 +1113,7 @@ def websocket_list_engines(
|
|||||||
providers.append(provider_info)
|
providers.append(provider_info)
|
||||||
if entity.platform:
|
if entity.platform:
|
||||||
entity_domains.add(entity.platform.platform_name)
|
entity_domains.add(entity.platform.platform_name)
|
||||||
for engine_id, provider in manager.providers.items():
|
for engine_id, provider in hass.data[DATA_TTS_MANAGER].providers.items():
|
||||||
provider_info = {
|
provider_info = {
|
||||||
"engine_id": engine_id,
|
"engine_id": engine_id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
@ -1156,17 +1143,19 @@ def websocket_get_engine(
|
|||||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Get text to speech engine info."""
|
"""Get text to speech engine info."""
|
||||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
|
||||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
|
||||||
|
|
||||||
engine_id = msg["engine_id"]
|
engine_id = msg["engine_id"]
|
||||||
provider_info: dict[str, Any]
|
provider_info: dict[str, Any]
|
||||||
|
|
||||||
provider: TextToSpeechEntity | Provider | None = next(
|
provider: TextToSpeechEntity | Provider | None = next(
|
||||||
(entity for entity in component.entities if entity.entity_id == engine_id), None
|
(
|
||||||
|
entity
|
||||||
|
for entity in hass.data[DOMAIN_DATA].entities
|
||||||
|
if entity.entity_id == engine_id
|
||||||
|
),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
if not provider:
|
if not provider:
|
||||||
provider = manager.providers.get(engine_id)
|
provider = hass.data[DATA_TTS_MANAGER].providers.get(engine_id)
|
||||||
|
|
||||||
if not provider:
|
if not provider:
|
||||||
connection.send_error(
|
connection.send_error(
|
||||||
|
@ -1,5 +1,16 @@
|
|||||||
"""Text-to-speech constants."""
|
"""Text-to-speech constants."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from homeassistant.util.hass_dict import HassKey
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
|
||||||
|
from . import SpeechManager, TextToSpeechEntity
|
||||||
|
|
||||||
ATTR_CACHE = "cache"
|
ATTR_CACHE = "cache"
|
||||||
ATTR_LANGUAGE = "language"
|
ATTR_LANGUAGE = "language"
|
||||||
ATTR_MESSAGE = "message"
|
ATTR_MESSAGE = "message"
|
||||||
@ -15,7 +26,8 @@ DEFAULT_CACHE_DIR = "tts"
|
|||||||
DEFAULT_TIME_MEMORY = 300
|
DEFAULT_TIME_MEMORY = 300
|
||||||
|
|
||||||
DOMAIN = "tts"
|
DOMAIN = "tts"
|
||||||
|
DOMAIN_DATA: HassKey[EntityComponent[TextToSpeechEntity]] = HassKey(DOMAIN)
|
||||||
|
|
||||||
DATA_TTS_MANAGER = "tts_manager"
|
DATA_TTS_MANAGER: HassKey[SpeechManager] = HassKey("tts_manager")
|
||||||
|
|
||||||
type TtsAudioType = tuple[str | None, bytes | None]
|
type TtsAudioType = tuple[str | None, bytes | None]
|
||||||
|
@ -5,12 +5,11 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
|
||||||
|
|
||||||
from .const import DATA_TTS_MANAGER, DOMAIN
|
from .const import DATA_TTS_MANAGER, DOMAIN_DATA
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from . import SpeechManager, TextToSpeechEntity
|
from . import TextToSpeechEntity
|
||||||
from .legacy import Provider
|
from .legacy import Provider
|
||||||
|
|
||||||
|
|
||||||
@ -18,10 +17,7 @@ def get_engine_instance(
|
|||||||
hass: HomeAssistant, engine: str
|
hass: HomeAssistant, engine: str
|
||||||
) -> TextToSpeechEntity | Provider | None:
|
) -> TextToSpeechEntity | Provider | None:
|
||||||
"""Get engine instance."""
|
"""Get engine instance."""
|
||||||
component: EntityComponent[TextToSpeechEntity] = hass.data[DOMAIN]
|
if entity := hass.data[DOMAIN_DATA].get_entity(engine):
|
||||||
|
|
||||||
if entity := component.get_entity(engine):
|
|
||||||
return entity
|
return entity
|
||||||
|
|
||||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
return hass.data[DATA_TTS_MANAGER].providers.get(engine)
|
||||||
return manager.providers.get(engine)
|
|
||||||
|
@ -57,9 +57,6 @@ from .const import (
|
|||||||
from .media_source import generate_media_source_id
|
from .media_source import generate_media_source_id
|
||||||
from .models import Voice
|
from .models import Voice
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from . import SpeechManager
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
CONF_SERVICE_NAME = "service_name"
|
CONF_SERVICE_NAME = "service_name"
|
||||||
@ -105,8 +102,6 @@ async def async_setup_legacy(
|
|||||||
hass: HomeAssistant, config: ConfigType
|
hass: HomeAssistant, config: ConfigType
|
||||||
) -> list[Coroutine[Any, Any, None]]:
|
) -> list[Coroutine[Any, Any, None]]:
|
||||||
"""Set up legacy text-to-speech providers."""
|
"""Set up legacy text-to-speech providers."""
|
||||||
tts: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
|
||||||
|
|
||||||
# Load service descriptions from tts/services.yaml
|
# Load service descriptions from tts/services.yaml
|
||||||
services_yaml = Path(__file__).parent / "services.yaml"
|
services_yaml = Path(__file__).parent / "services.yaml"
|
||||||
services_dict = await hass.async_add_executor_job(
|
services_dict = await hass.async_add_executor_job(
|
||||||
@ -147,7 +142,9 @@ async def async_setup_legacy(
|
|||||||
_LOGGER.error("Error setting up platform: %s", p_type)
|
_LOGGER.error("Error setting up platform: %s", p_type)
|
||||||
return
|
return
|
||||||
|
|
||||||
tts.async_register_legacy_engine(p_type, provider, p_config)
|
hass.data[DATA_TTS_MANAGER].async_register_legacy_engine(
|
||||||
|
p_type, provider, p_config
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
_LOGGER.exception("Error setting up platform: %s", p_type)
|
_LOGGER.exception("Error setting up platform: %s", p_type)
|
||||||
return
|
return
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
from typing import TYPE_CHECKING, TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
@ -18,14 +18,10 @@ from homeassistant.components.media_source import (
|
|||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
|
||||||
|
|
||||||
from .const import DATA_TTS_MANAGER, DOMAIN
|
from .const import DATA_TTS_MANAGER, DOMAIN, DOMAIN_DATA
|
||||||
from .helper import get_engine_instance
|
from .helper import get_engine_instance
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from . import SpeechManager, TextToSpeechEntity
|
|
||||||
|
|
||||||
|
|
||||||
async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource:
|
async def async_get_media_source(hass: HomeAssistant) -> TTSMediaSource:
|
||||||
"""Set up tts media source."""
|
"""Set up tts media source."""
|
||||||
@ -44,8 +40,6 @@ def generate_media_source_id(
|
|||||||
"""Generate a media source ID for text-to-speech."""
|
"""Generate a media source ID for text-to-speech."""
|
||||||
from . import async_resolve_engine # pylint: disable=import-outside-toplevel
|
from . import async_resolve_engine # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
manager: SpeechManager = hass.data[DATA_TTS_MANAGER]
|
|
||||||
|
|
||||||
if (engine := async_resolve_engine(hass, engine)) is None:
|
if (engine := async_resolve_engine(hass, engine)) is None:
|
||||||
raise HomeAssistantError("Invalid TTS provider selected")
|
raise HomeAssistantError("Invalid TTS provider selected")
|
||||||
|
|
||||||
@ -53,7 +47,7 @@ def generate_media_source_id(
|
|||||||
# We raise above if the engine is not resolved, so engine_instance can't be None
|
# We raise above if the engine is not resolved, so engine_instance can't be None
|
||||||
assert engine_instance is not None
|
assert engine_instance is not None
|
||||||
|
|
||||||
manager.process_options(engine_instance, language, options)
|
hass.data[DATA_TTS_MANAGER].process_options(engine_instance, language, options)
|
||||||
params = {
|
params = {
|
||||||
"message": message,
|
"message": message,
|
||||||
}
|
}
|
||||||
@ -113,10 +107,8 @@ class TTSMediaSource(MediaSource):
|
|||||||
|
|
||||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||||
"""Resolve media to a url."""
|
"""Resolve media to a url."""
|
||||||
manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url = await manager.async_get_url_path(
|
url = await self.hass.data[DATA_TTS_MANAGER].async_get_url_path(
|
||||||
**media_source_id_to_kwargs(item.identifier)
|
**media_source_id_to_kwargs(item.identifier)
|
||||||
)
|
)
|
||||||
except HomeAssistantError as err:
|
except HomeAssistantError as err:
|
||||||
@ -136,10 +128,12 @@ class TTSMediaSource(MediaSource):
|
|||||||
return self._engine_item(engine, params)
|
return self._engine_item(engine, params)
|
||||||
|
|
||||||
# Root. List providers.
|
# Root. List providers.
|
||||||
manager: SpeechManager = self.hass.data[DATA_TTS_MANAGER]
|
children = [
|
||||||
component: EntityComponent[TextToSpeechEntity] = self.hass.data[DOMAIN]
|
self._engine_item(engine)
|
||||||
children = [self._engine_item(engine) for engine in manager.providers] + [
|
for engine in self.hass.data[DATA_TTS_MANAGER].providers
|
||||||
self._engine_item(entity.entity_id) for entity in component.entities
|
] + [
|
||||||
|
self._engine_item(entity.entity_id)
|
||||||
|
for entity in self.hass.data[DOMAIN_DATA].entities
|
||||||
]
|
]
|
||||||
return BrowseMediaSource(
|
return BrowseMediaSource(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user