mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 17:27:10 +00:00
Use HassKey in stt (#126335)
This commit is contained in:
parent
91c1e75c00
commit
0299fa1b68
@ -32,6 +32,7 @@ from homeassistant.util import dt as dt_util, language as language_util
|
||||
from .const import (
|
||||
DATA_PROVIDERS,
|
||||
DOMAIN,
|
||||
DOMAIN_DATA,
|
||||
AudioBitRates,
|
||||
AudioChannels,
|
||||
AudioCodecs,
|
||||
@ -72,11 +73,9 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||
@callback
|
||||
def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||
"""Return the domain or entity id of the default engine."""
|
||||
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||
|
||||
default_entity_id: str | None = None
|
||||
|
||||
for entity in component.entities:
|
||||
for entity in hass.data[DOMAIN_DATA].entities:
|
||||
if entity.platform and entity.platform.platform_name == "cloud":
|
||||
return entity.entity_id
|
||||
|
||||
@ -91,9 +90,7 @@ def async_get_speech_to_text_entity(
|
||||
hass: HomeAssistant, entity_id: str
|
||||
) -> SpeechToTextEntity | None:
|
||||
"""Return stt entity."""
|
||||
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||
|
||||
return component.get_entity(entity_id)
|
||||
return hass.data[DOMAIN_DATA].get_entity(entity_id)
|
||||
|
||||
|
||||
@callback
|
||||
@ -111,13 +108,11 @@ def async_get_speech_to_text_languages(hass: HomeAssistant) -> set[str]:
|
||||
"""Return a set with the union of languages supported by stt engines."""
|
||||
languages = set()
|
||||
|
||||
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||
legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]
|
||||
for entity in component.entities:
|
||||
for entity in hass.data[DOMAIN_DATA].entities:
|
||||
for language_tag in entity.supported_languages:
|
||||
languages.add(language_tag)
|
||||
|
||||
for engine in legacy_providers.values():
|
||||
for engine in hass.data[DATA_PROVIDERS].values():
|
||||
for language_tag in engine.supported_languages:
|
||||
languages.add(language_tag)
|
||||
|
||||
@ -128,7 +123,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up STT."""
|
||||
websocket_api.async_register_command(hass, websocket_list_engines)
|
||||
|
||||
component = hass.data[DOMAIN] = EntityComponent[SpeechToTextEntity](
|
||||
component = hass.data[DOMAIN_DATA] = EntityComponent[SpeechToTextEntity](
|
||||
_LOGGER, DOMAIN, hass
|
||||
)
|
||||
|
||||
@ -150,14 +145,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
return await hass.data[DOMAIN_DATA].async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
||||
return await hass.data[DOMAIN_DATA].async_unload_entry(entry)
|
||||
|
||||
|
||||
class SpeechToTextEntity(RestoreEntity):
|
||||
@ -426,15 +419,12 @@ def websocket_list_engines(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""List speech-to-text engines and, optionally, if they support a given language."""
|
||||
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
|
||||
legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]
|
||||
|
||||
country = msg.get("country")
|
||||
language = msg.get("language")
|
||||
providers = []
|
||||
provider_info: dict[str, Any]
|
||||
|
||||
for entity in component.entities:
|
||||
for entity in hass.data[DOMAIN_DATA].entities:
|
||||
provider_info = {
|
||||
"engine_id": entity.entity_id,
|
||||
"supported_languages": entity.supported_languages,
|
||||
@ -445,7 +435,7 @@ def websocket_list_engines(
|
||||
)
|
||||
providers.append(provider_info)
|
||||
|
||||
for engine_id, provider in legacy_providers.items():
|
||||
for engine_id, provider in hass.data[DATA_PROVIDERS].items():
|
||||
provider_info = {
|
||||
"engine_id": engine_id,
|
||||
"name": provider.name,
|
||||
|
@ -1,9 +1,21 @@
|
||||
"""STT constante."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from . import SpeechToTextEntity
|
||||
from .legacy import Provider
|
||||
|
||||
DOMAIN = "stt"
|
||||
DATA_PROVIDERS = f"{DOMAIN}_providers"
|
||||
DOMAIN_DATA: HassKey[EntityComponent[SpeechToTextEntity]] = HassKey(DOMAIN)
|
||||
DATA_PROVIDERS: HassKey[dict[str, Provider]] = HassKey(f"{DOMAIN}_providers")
|
||||
|
||||
|
||||
class AudioCodecs(str, Enum):
|
||||
|
@ -34,7 +34,8 @@ _LOGGER = logging.getLogger(__name__)
|
||||
@callback
|
||||
def async_default_provider(hass: HomeAssistant) -> str | None:
|
||||
"""Return the domain of the default provider."""
|
||||
return next(iter(hass.data[DATA_PROVIDERS]), None)
|
||||
providers = hass.data[DATA_PROVIDERS]
|
||||
return next(iter(providers), None)
|
||||
|
||||
|
||||
@callback
|
||||
@ -42,7 +43,7 @@ def async_get_provider(
|
||||
hass: HomeAssistant, domain: str | None = None
|
||||
) -> Provider | None:
|
||||
"""Return provider."""
|
||||
providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]
|
||||
providers = hass.data[DATA_PROVIDERS]
|
||||
if domain:
|
||||
return providers.get(domain)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user