Use HassKey in stt (#126335)

This commit is contained in:
epenet 2024-09-21 11:34:28 +02:00 committed by GitHub
parent 91c1e75c00
commit 0299fa1b68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 23 deletions

View File

@ -32,6 +32,7 @@ from homeassistant.util import dt as dt_util, language as language_util
from .const import (
DATA_PROVIDERS,
DOMAIN,
DOMAIN_DATA,
AudioBitRates,
AudioChannels,
AudioCodecs,
@ -72,11 +73,9 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@callback
def async_default_engine(hass: HomeAssistant) -> str | None:
"""Return the domain or entity id of the default engine."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
default_entity_id: str | None = None
for entity in component.entities:
for entity in hass.data[DOMAIN_DATA].entities:
if entity.platform and entity.platform.platform_name == "cloud":
return entity.entity_id
@ -91,9 +90,7 @@ def async_get_speech_to_text_entity(
hass: HomeAssistant, entity_id: str
) -> SpeechToTextEntity | None:
"""Return stt entity."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
return component.get_entity(entity_id)
return hass.data[DOMAIN_DATA].get_entity(entity_id)
@callback
@ -111,13 +108,11 @@ def async_get_speech_to_text_languages(hass: HomeAssistant) -> set[str]:
"""Return a set with the union of languages supported by stt engines."""
languages = set()
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]
for entity in component.entities:
for entity in hass.data[DOMAIN_DATA].entities:
for language_tag in entity.supported_languages:
languages.add(language_tag)
for engine in legacy_providers.values():
for engine in hass.data[DATA_PROVIDERS].values():
for language_tag in engine.supported_languages:
languages.add(language_tag)
@ -128,7 +123,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up STT."""
websocket_api.async_register_command(hass, websocket_list_engines)
component = hass.data[DOMAIN] = EntityComponent[SpeechToTextEntity](
component = hass.data[DOMAIN_DATA] = EntityComponent[SpeechToTextEntity](
_LOGGER, DOMAIN, hass
)
@ -150,14 +145,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
return await component.async_setup_entry(entry)
return await hass.data[DOMAIN_DATA].async_setup_entry(entry)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
return await component.async_unload_entry(entry)
return await hass.data[DOMAIN_DATA].async_unload_entry(entry)
class SpeechToTextEntity(RestoreEntity):
@ -426,15 +419,12 @@ def websocket_list_engines(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List speech-to-text engines and, optionally, if they support a given language."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]
country = msg.get("country")
language = msg.get("language")
providers = []
provider_info: dict[str, Any]
for entity in component.entities:
for entity in hass.data[DOMAIN_DATA].entities:
provider_info = {
"engine_id": entity.entity_id,
"supported_languages": entity.supported_languages,
@ -445,7 +435,7 @@ def websocket_list_engines(
)
providers.append(provider_info)
for engine_id, provider in legacy_providers.items():
for engine_id, provider in hass.data[DATA_PROVIDERS].items():
provider_info = {
"engine_id": engine_id,
"name": provider.name,

View File

@ -1,9 +1,21 @@
"""STT constante."""
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from homeassistant.helpers.entity_component import EntityComponent
from . import SpeechToTextEntity
from .legacy import Provider
DOMAIN = "stt"
DATA_PROVIDERS = f"{DOMAIN}_providers"
DOMAIN_DATA: HassKey[EntityComponent[SpeechToTextEntity]] = HassKey(DOMAIN)
DATA_PROVIDERS: HassKey[dict[str, Provider]] = HassKey(f"{DOMAIN}_providers")
class AudioCodecs(str, Enum):

View File

@ -34,7 +34,8 @@ _LOGGER = logging.getLogger(__name__)
@callback
def async_default_provider(hass: HomeAssistant) -> str | None:
"""Return the domain of the default provider."""
return next(iter(hass.data[DATA_PROVIDERS]), None)
providers = hass.data[DATA_PROVIDERS]
return next(iter(providers), None)
@callback
@ -42,7 +43,7 @@ def async_get_provider(
hass: HomeAssistant, domain: str | None = None
) -> Provider | None:
"""Return provider."""
providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]
providers = hass.data[DATA_PROVIDERS]
if domain:
return providers.get(domain)