mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 19:27:45 +00:00
Use runtime_data in google_assistant_sdk (#144335)
This commit is contained in:
parent
2c34712069
commit
fbae79fab2
@ -10,7 +10,6 @@ from google.oauth2.credentials import Credentials
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.config_entries import ConfigEntry
|
|
||||||
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
|
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
|
||||||
from homeassistant.core import (
|
from homeassistant.core import (
|
||||||
HomeAssistant,
|
HomeAssistant,
|
||||||
@ -26,15 +25,11 @@ from homeassistant.helpers.config_entry_oauth2_flow import (
|
|||||||
)
|
)
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import (
|
from .const import CONF_LANGUAGE_CODE, DOMAIN, SUPPORTED_LANGUAGE_CODES
|
||||||
CONF_LANGUAGE_CODE,
|
|
||||||
DATA_MEM_STORAGE,
|
|
||||||
DATA_SESSION,
|
|
||||||
DOMAIN,
|
|
||||||
SUPPORTED_LANGUAGE_CODES,
|
|
||||||
)
|
|
||||||
from .helpers import (
|
from .helpers import (
|
||||||
GoogleAssistantSDKAudioView,
|
GoogleAssistantSDKAudioView,
|
||||||
|
GoogleAssistantSDKConfigEntry,
|
||||||
|
GoogleAssistantSDKRuntimeData,
|
||||||
InMemoryStorage,
|
InMemoryStorage,
|
||||||
async_send_text_commands,
|
async_send_text_commands,
|
||||||
best_matching_language_code,
|
best_matching_language_code,
|
||||||
@ -66,10 +61,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant, entry: GoogleAssistantSDKConfigEntry
|
||||||
|
) -> bool:
|
||||||
"""Set up Google Assistant SDK from a config entry."""
|
"""Set up Google Assistant SDK from a config entry."""
|
||||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
|
|
||||||
|
|
||||||
implementation = await async_get_config_entry_implementation(hass, entry)
|
implementation = await async_get_config_entry_implementation(hass, entry)
|
||||||
session = OAuth2Session(hass, entry, implementation)
|
session = OAuth2Session(hass, entry, implementation)
|
||||||
try:
|
try:
|
||||||
@ -82,23 +77,25 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
raise ConfigEntryNotReady from err
|
raise ConfigEntryNotReady from err
|
||||||
except aiohttp.ClientError as err:
|
except aiohttp.ClientError as err:
|
||||||
raise ConfigEntryNotReady from err
|
raise ConfigEntryNotReady from err
|
||||||
hass.data[DOMAIN][entry.entry_id][DATA_SESSION] = session
|
|
||||||
|
|
||||||
mem_storage = InMemoryStorage(hass)
|
mem_storage = InMemoryStorage(hass)
|
||||||
hass.data[DOMAIN][entry.entry_id][DATA_MEM_STORAGE] = mem_storage
|
|
||||||
hass.http.register_view(GoogleAssistantSDKAudioView(mem_storage))
|
hass.http.register_view(GoogleAssistantSDKAudioView(mem_storage))
|
||||||
|
|
||||||
await async_setup_service(hass)
|
await async_setup_service(hass)
|
||||||
|
|
||||||
|
entry.runtime_data = GoogleAssistantSDKRuntimeData(
|
||||||
|
session=session, mem_storage=mem_storage
|
||||||
|
)
|
||||||
agent = GoogleAssistantConversationAgent(hass, entry)
|
agent = GoogleAssistantConversationAgent(hass, entry)
|
||||||
conversation.async_set_agent(hass, entry, agent)
|
conversation.async_set_agent(hass, entry, agent)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_unload_entry(
|
||||||
|
hass: HomeAssistant, entry: GoogleAssistantSDKConfigEntry
|
||||||
|
) -> bool:
|
||||||
"""Unload a config entry."""
|
"""Unload a config entry."""
|
||||||
hass.data[DOMAIN].pop(entry.entry_id)
|
|
||||||
if not hass.config_entries.async_loaded_entries(DOMAIN):
|
if not hass.config_entries.async_loaded_entries(DOMAIN):
|
||||||
for service_name in hass.services.async_services_for_domain(DOMAIN):
|
for service_name in hass.services.async_services_for_domain(DOMAIN):
|
||||||
hass.services.async_remove(DOMAIN, service_name)
|
hass.services.async_remove(DOMAIN, service_name)
|
||||||
@ -141,7 +138,9 @@ async def async_setup_service(hass: HomeAssistant) -> None:
|
|||||||
class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
||||||
"""Google Assistant SDK conversation agent."""
|
"""Google Assistant SDK conversation agent."""
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
|
def __init__(
|
||||||
|
self, hass: HomeAssistant, entry: GoogleAssistantSDKConfigEntry
|
||||||
|
) -> None:
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.entry = entry
|
self.entry = entry
|
||||||
@ -161,7 +160,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
|||||||
if self.session:
|
if self.session:
|
||||||
session = self.session
|
session = self.session
|
||||||
else:
|
else:
|
||||||
session = self.hass.data[DOMAIN][self.entry.entry_id][DATA_SESSION]
|
session = self.entry.runtime_data.session
|
||||||
self.session = session
|
self.session = session
|
||||||
if not session.valid_token:
|
if not session.valid_token:
|
||||||
await session.async_ensure_token_valid()
|
await session.async_ensure_token_valid()
|
||||||
|
@ -8,17 +8,12 @@ from typing import Any
|
|||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.config_entries import (
|
from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult, OptionsFlow
|
||||||
SOURCE_REAUTH,
|
|
||||||
ConfigEntry,
|
|
||||||
ConfigFlowResult,
|
|
||||||
OptionsFlow,
|
|
||||||
)
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.helpers import config_entry_oauth2_flow
|
from homeassistant.helpers import config_entry_oauth2_flow
|
||||||
|
|
||||||
from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES
|
from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES
|
||||||
from .helpers import default_language_code
|
from .helpers import GoogleAssistantSDKConfigEntry, default_language_code
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -77,7 +72,7 @@ class OAuth2FlowHandler(
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
@callback
|
@callback
|
||||||
def async_get_options_flow(
|
def async_get_options_flow(
|
||||||
config_entry: ConfigEntry,
|
config_entry: GoogleAssistantSDKConfigEntry,
|
||||||
) -> OptionsFlow:
|
) -> OptionsFlow:
|
||||||
"""Create the options flow."""
|
"""Create the options flow."""
|
||||||
return OptionsFlowHandler()
|
return OptionsFlowHandler()
|
||||||
|
@ -8,9 +8,6 @@ DEFAULT_NAME: Final = "Google Assistant SDK"
|
|||||||
|
|
||||||
CONF_LANGUAGE_CODE: Final = "language_code"
|
CONF_LANGUAGE_CODE: Final = "language_code"
|
||||||
|
|
||||||
DATA_MEM_STORAGE: Final = "mem_storage"
|
|
||||||
DATA_SESSION: Final = "session"
|
|
||||||
|
|
||||||
# https://developers.google.com/assistant/sdk/reference/rpc/languages
|
# https://developers.google.com/assistant/sdk/reference/rpc/languages
|
||||||
SUPPORTED_LANGUAGE_CODES: Final = [
|
SUPPORTED_LANGUAGE_CODES: Final = [
|
||||||
"de-DE",
|
"de-DE",
|
||||||
|
@ -5,14 +5,15 @@ from __future__ import annotations
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.components.diagnostics import async_redact_data
|
from homeassistant.components.diagnostics import async_redact_data
|
||||||
from homeassistant.config_entries import ConfigEntry
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from .helpers import GoogleAssistantSDKConfigEntry
|
||||||
|
|
||||||
TO_REDACT = {"access_token", "refresh_token"}
|
TO_REDACT = {"access_token", "refresh_token"}
|
||||||
|
|
||||||
|
|
||||||
async def async_get_config_entry_diagnostics(
|
async def async_get_config_entry_diagnostics(
|
||||||
hass: HomeAssistant, entry: ConfigEntry
|
hass: HomeAssistant, entry: GoogleAssistantSDKConfigEntry
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Return diagnostics for a config entry."""
|
"""Return diagnostics for a config entry."""
|
||||||
return async_redact_data(
|
return async_redact_data(
|
||||||
|
@ -28,13 +28,7 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session
|
from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session
|
||||||
from homeassistant.helpers.event import async_call_later
|
from homeassistant.helpers.event import async_call_later
|
||||||
|
|
||||||
from .const import (
|
from .const import CONF_LANGUAGE_CODE, DOMAIN, SUPPORTED_LANGUAGE_CODES
|
||||||
CONF_LANGUAGE_CODE,
|
|
||||||
DATA_MEM_STORAGE,
|
|
||||||
DATA_SESSION,
|
|
||||||
DOMAIN,
|
|
||||||
SUPPORTED_LANGUAGE_CODES,
|
|
||||||
)
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -49,6 +43,16 @@ DEFAULT_LANGUAGE_CODES = {
|
|||||||
"pt": "pt-BR",
|
"pt": "pt-BR",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GoogleAssistantSDKConfigEntry = ConfigEntry[GoogleAssistantSDKRuntimeData]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GoogleAssistantSDKRuntimeData:
|
||||||
|
"""Runtime data for Google Assistant SDK."""
|
||||||
|
|
||||||
|
session: OAuth2Session
|
||||||
|
mem_storage: InMemoryStorage
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CommandResponse:
|
class CommandResponse:
|
||||||
@ -62,9 +66,9 @@ async def async_send_text_commands(
|
|||||||
) -> list[CommandResponse]:
|
) -> list[CommandResponse]:
|
||||||
"""Send text commands to Google Assistant Service."""
|
"""Send text commands to Google Assistant Service."""
|
||||||
# There can only be 1 entry (config_flow has single_instance_allowed)
|
# There can only be 1 entry (config_flow has single_instance_allowed)
|
||||||
entry: ConfigEntry = hass.config_entries.async_entries(DOMAIN)[0]
|
entry: GoogleAssistantSDKConfigEntry = hass.config_entries.async_entries(DOMAIN)[0]
|
||||||
|
|
||||||
session: OAuth2Session = hass.data[DOMAIN][entry.entry_id][DATA_SESSION]
|
session = entry.runtime_data.session
|
||||||
try:
|
try:
|
||||||
await session.async_ensure_token_valid()
|
await session.async_ensure_token_valid()
|
||||||
except aiohttp.ClientResponseError as err:
|
except aiohttp.ClientResponseError as err:
|
||||||
@ -84,11 +88,10 @@ async def async_send_text_commands(
|
|||||||
_LOGGER.debug("command: %s\nresponse: %s", command, text_response)
|
_LOGGER.debug("command: %s\nresponse: %s", command, text_response)
|
||||||
audio_response = resp[2]
|
audio_response = resp[2]
|
||||||
if media_players and audio_response:
|
if media_players and audio_response:
|
||||||
mem_storage: InMemoryStorage = hass.data[DOMAIN][entry.entry_id][
|
|
||||||
DATA_MEM_STORAGE
|
|
||||||
]
|
|
||||||
audio_url = GoogleAssistantSDKAudioView.url.format(
|
audio_url = GoogleAssistantSDKAudioView.url.format(
|
||||||
filename=mem_storage.store_and_get_identifier(audio_response)
|
filename=entry.runtime_data.mem_storage.store_and_get_identifier(
|
||||||
|
audio_response
|
||||||
|
)
|
||||||
)
|
)
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
DOMAIN_MP,
|
DOMAIN_MP,
|
||||||
|
@ -5,12 +5,15 @@ from __future__ import annotations
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.components.notify import ATTR_TARGET, BaseNotificationService
|
from homeassistant.components.notify import ATTR_TARGET, BaseNotificationService
|
||||||
from homeassistant.config_entries import ConfigEntry
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
from .const import CONF_LANGUAGE_CODE, DOMAIN
|
from .const import CONF_LANGUAGE_CODE, DOMAIN
|
||||||
from .helpers import async_send_text_commands, default_language_code
|
from .helpers import (
|
||||||
|
GoogleAssistantSDKConfigEntry,
|
||||||
|
async_send_text_commands,
|
||||||
|
default_language_code,
|
||||||
|
)
|
||||||
|
|
||||||
# https://support.google.com/assistant/answer/9071582?hl=en
|
# https://support.google.com/assistant/answer/9071582?hl=en
|
||||||
LANG_TO_BROADCAST_COMMAND = {
|
LANG_TO_BROADCAST_COMMAND = {
|
||||||
@ -59,7 +62,9 @@ class BroadcastNotificationService(BaseNotificationService):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# There can only be 1 entry (config_flow has single_instance_allowed)
|
# There can only be 1 entry (config_flow has single_instance_allowed)
|
||||||
entry: ConfigEntry = self.hass.config_entries.async_entries(DOMAIN)[0]
|
entry: GoogleAssistantSDKConfigEntry = self.hass.config_entries.async_entries(
|
||||||
|
DOMAIN
|
||||||
|
)[0]
|
||||||
language_code = entry.options.get(
|
language_code = entry.options.get(
|
||||||
CONF_LANGUAGE_CODE, default_language_code(self.hass)
|
CONF_LANGUAGE_CODE, default_language_code(self.hass)
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user