Use runtime_data in google_assistant_sdk (#144335)

This commit is contained in:
epenet 2025-05-06 15:52:00 +02:00 committed by GitHub
parent 2c34712069
commit fbae79fab2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 46 additions and 46 deletions

View File

@ -10,7 +10,6 @@ from google.oauth2.credentials import Credentials
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
from homeassistant.core import ( from homeassistant.core import (
HomeAssistant, HomeAssistant,
@ -26,15 +25,11 @@ from homeassistant.helpers.config_entry_oauth2_flow import (
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import CONF_LANGUAGE_CODE, DOMAIN, SUPPORTED_LANGUAGE_CODES
CONF_LANGUAGE_CODE,
DATA_MEM_STORAGE,
DATA_SESSION,
DOMAIN,
SUPPORTED_LANGUAGE_CODES,
)
from .helpers import ( from .helpers import (
GoogleAssistantSDKAudioView, GoogleAssistantSDKAudioView,
GoogleAssistantSDKConfigEntry,
GoogleAssistantSDKRuntimeData,
InMemoryStorage, InMemoryStorage,
async_send_text_commands, async_send_text_commands,
best_matching_language_code, best_matching_language_code,
@ -66,10 +61,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(
hass: HomeAssistant, entry: GoogleAssistantSDKConfigEntry
) -> bool:
"""Set up Google Assistant SDK from a config entry.""" """Set up Google Assistant SDK from a config entry."""
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
implementation = await async_get_config_entry_implementation(hass, entry) implementation = await async_get_config_entry_implementation(hass, entry)
session = OAuth2Session(hass, entry, implementation) session = OAuth2Session(hass, entry, implementation)
try: try:
@ -82,23 +77,25 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
raise ConfigEntryNotReady from err raise ConfigEntryNotReady from err
except aiohttp.ClientError as err: except aiohttp.ClientError as err:
raise ConfigEntryNotReady from err raise ConfigEntryNotReady from err
hass.data[DOMAIN][entry.entry_id][DATA_SESSION] = session
mem_storage = InMemoryStorage(hass) mem_storage = InMemoryStorage(hass)
hass.data[DOMAIN][entry.entry_id][DATA_MEM_STORAGE] = mem_storage
hass.http.register_view(GoogleAssistantSDKAudioView(mem_storage)) hass.http.register_view(GoogleAssistantSDKAudioView(mem_storage))
await async_setup_service(hass) await async_setup_service(hass)
entry.runtime_data = GoogleAssistantSDKRuntimeData(
session=session, mem_storage=mem_storage
)
agent = GoogleAssistantConversationAgent(hass, entry) agent = GoogleAssistantConversationAgent(hass, entry)
conversation.async_set_agent(hass, entry, agent) conversation.async_set_agent(hass, entry, agent)
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(
hass: HomeAssistant, entry: GoogleAssistantSDKConfigEntry
) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
hass.data[DOMAIN].pop(entry.entry_id)
if not hass.config_entries.async_loaded_entries(DOMAIN): if not hass.config_entries.async_loaded_entries(DOMAIN):
for service_name in hass.services.async_services_for_domain(DOMAIN): for service_name in hass.services.async_services_for_domain(DOMAIN):
hass.services.async_remove(DOMAIN, service_name) hass.services.async_remove(DOMAIN, service_name)
@ -141,7 +138,9 @@ async def async_setup_service(hass: HomeAssistant) -> None:
class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent): class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
"""Google Assistant SDK conversation agent.""" """Google Assistant SDK conversation agent."""
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: def __init__(
self, hass: HomeAssistant, entry: GoogleAssistantSDKConfigEntry
) -> None:
"""Initialize the agent.""" """Initialize the agent."""
self.hass = hass self.hass = hass
self.entry = entry self.entry = entry
@ -161,7 +160,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
if self.session: if self.session:
session = self.session session = self.session
else: else:
session = self.hass.data[DOMAIN][self.entry.entry_id][DATA_SESSION] session = self.entry.runtime_data.session
self.session = session self.session = session
if not session.valid_token: if not session.valid_token:
await session.async_ensure_token_valid() await session.async_ensure_token_valid()

View File

@ -8,17 +8,12 @@ from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ( from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult, OptionsFlow
SOURCE_REAUTH,
ConfigEntry,
ConfigFlowResult,
OptionsFlow,
)
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers import config_entry_oauth2_flow
from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES
from .helpers import default_language_code from .helpers import GoogleAssistantSDKConfigEntry, default_language_code
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -77,7 +72,7 @@ class OAuth2FlowHandler(
@staticmethod @staticmethod
@callback @callback
def async_get_options_flow( def async_get_options_flow(
config_entry: ConfigEntry, config_entry: GoogleAssistantSDKConfigEntry,
) -> OptionsFlow: ) -> OptionsFlow:
"""Create the options flow.""" """Create the options flow."""
return OptionsFlowHandler() return OptionsFlowHandler()

View File

@ -8,9 +8,6 @@ DEFAULT_NAME: Final = "Google Assistant SDK"
CONF_LANGUAGE_CODE: Final = "language_code" CONF_LANGUAGE_CODE: Final = "language_code"
DATA_MEM_STORAGE: Final = "mem_storage"
DATA_SESSION: Final = "session"
# https://developers.google.com/assistant/sdk/reference/rpc/languages # https://developers.google.com/assistant/sdk/reference/rpc/languages
SUPPORTED_LANGUAGE_CODES: Final = [ SUPPORTED_LANGUAGE_CODES: Final = [
"de-DE", "de-DE",

View File

@ -5,14 +5,15 @@ from __future__ import annotations
from typing import Any from typing import Any
from homeassistant.components.diagnostics import async_redact_data from homeassistant.components.diagnostics import async_redact_data
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .helpers import GoogleAssistantSDKConfigEntry
TO_REDACT = {"access_token", "refresh_token"} TO_REDACT = {"access_token", "refresh_token"}
async def async_get_config_entry_diagnostics( async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: GoogleAssistantSDKConfigEntry
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
return async_redact_data( return async_redact_data(

View File

@ -28,13 +28,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from .const import ( from .const import CONF_LANGUAGE_CODE, DOMAIN, SUPPORTED_LANGUAGE_CODES
CONF_LANGUAGE_CODE,
DATA_MEM_STORAGE,
DATA_SESSION,
DOMAIN,
SUPPORTED_LANGUAGE_CODES,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -49,6 +43,16 @@ DEFAULT_LANGUAGE_CODES = {
"pt": "pt-BR", "pt": "pt-BR",
} }
type GoogleAssistantSDKConfigEntry = ConfigEntry[GoogleAssistantSDKRuntimeData]
@dataclass
class GoogleAssistantSDKRuntimeData:
"""Runtime data for Google Assistant SDK."""
session: OAuth2Session
mem_storage: InMemoryStorage
@dataclass @dataclass
class CommandResponse: class CommandResponse:
@ -62,9 +66,9 @@ async def async_send_text_commands(
) -> list[CommandResponse]: ) -> list[CommandResponse]:
"""Send text commands to Google Assistant Service.""" """Send text commands to Google Assistant Service."""
# There can only be 1 entry (config_flow has single_instance_allowed) # There can only be 1 entry (config_flow has single_instance_allowed)
entry: ConfigEntry = hass.config_entries.async_entries(DOMAIN)[0] entry: GoogleAssistantSDKConfigEntry = hass.config_entries.async_entries(DOMAIN)[0]
session: OAuth2Session = hass.data[DOMAIN][entry.entry_id][DATA_SESSION] session = entry.runtime_data.session
try: try:
await session.async_ensure_token_valid() await session.async_ensure_token_valid()
except aiohttp.ClientResponseError as err: except aiohttp.ClientResponseError as err:
@ -84,11 +88,10 @@ async def async_send_text_commands(
_LOGGER.debug("command: %s\nresponse: %s", command, text_response) _LOGGER.debug("command: %s\nresponse: %s", command, text_response)
audio_response = resp[2] audio_response = resp[2]
if media_players and audio_response: if media_players and audio_response:
mem_storage: InMemoryStorage = hass.data[DOMAIN][entry.entry_id][
DATA_MEM_STORAGE
]
audio_url = GoogleAssistantSDKAudioView.url.format( audio_url = GoogleAssistantSDKAudioView.url.format(
filename=mem_storage.store_and_get_identifier(audio_response) filename=entry.runtime_data.mem_storage.store_and_get_identifier(
audio_response
)
) )
await hass.services.async_call( await hass.services.async_call(
DOMAIN_MP, DOMAIN_MP,

View File

@ -5,12 +5,15 @@ from __future__ import annotations
from typing import Any from typing import Any
from homeassistant.components.notify import ATTR_TARGET, BaseNotificationService from homeassistant.components.notify import ATTR_TARGET, BaseNotificationService
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import CONF_LANGUAGE_CODE, DOMAIN from .const import CONF_LANGUAGE_CODE, DOMAIN
from .helpers import async_send_text_commands, default_language_code from .helpers import (
GoogleAssistantSDKConfigEntry,
async_send_text_commands,
default_language_code,
)
# https://support.google.com/assistant/answer/9071582?hl=en # https://support.google.com/assistant/answer/9071582?hl=en
LANG_TO_BROADCAST_COMMAND = { LANG_TO_BROADCAST_COMMAND = {
@ -59,7 +62,9 @@ class BroadcastNotificationService(BaseNotificationService):
return return
# There can only be 1 entry (config_flow has single_instance_allowed) # There can only be 1 entry (config_flow has single_instance_allowed)
entry: ConfigEntry = self.hass.config_entries.async_entries(DOMAIN)[0] entry: GoogleAssistantSDKConfigEntry = self.hass.config_entries.async_entries(
DOMAIN
)[0]
language_code = entry.options.get( language_code = entry.options.get(
CONF_LANGUAGE_CODE, default_language_code(self.hass) CONF_LANGUAGE_CODE, default_language_code(self.hass)
) )