Remove store from google_assistant AbstractConfig (#109877)

* Remove store from google_assistant AbstractConfig

* Bump minor version of google_assistant store

* Fix test

* Improve comments

* Fix typo

* Refactor

* Update homeassistant/components/google_assistant/http.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Fix bug, add tests

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erik Montnemery 2024-02-12 19:24:21 +01:00 committed by GitHub
parent d78bb3894c
commit a51d3b4286
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 416 additions and 176 deletions

View File

@ -30,6 +30,7 @@ PREF_GOOGLE_DEFAULT_EXPOSE = "google_default_expose"
PREF_ALEXA_SETTINGS_VERSION = "alexa_settings_version" PREF_ALEXA_SETTINGS_VERSION = "alexa_settings_version"
PREF_GOOGLE_SETTINGS_VERSION = "google_settings_version" PREF_GOOGLE_SETTINGS_VERSION = "google_settings_version"
PREF_TTS_DEFAULT_VOICE = "tts_default_voice" PREF_TTS_DEFAULT_VOICE = "tts_default_voice"
PREF_GOOGLE_CONNECTED = "google_connected"
DEFAULT_TTS_DEFAULT_VOICE = ("en-US", "female") DEFAULT_TTS_DEFAULT_VOICE = ("en-US", "female")
DEFAULT_DISABLE_2FA = False DEFAULT_DISABLE_2FA = False
DEFAULT_ALEXA_REPORT_STATE = True DEFAULT_ALEXA_REPORT_STATE = True

View File

@ -258,17 +258,6 @@ class CloudGoogleConfig(AbstractConfig):
self._on_deinitialize.append(start.async_at_start(self.hass, on_hass_start)) self._on_deinitialize.append(start.async_at_start(self.hass, on_hass_start))
self._on_deinitialize.append(start.async_at_started(self.hass, on_hass_started)) self._on_deinitialize.append(start.async_at_started(self.hass, on_hass_started))
# Remove any stored user agent id that is not ours
remove_agent_user_ids = []
for agent_user_id in self._store.agent_user_ids:
if agent_user_id != self.agent_user_id:
remove_agent_user_ids.append(agent_user_id)
if remove_agent_user_ids:
_LOGGER.debug("remove non cloud agent_user_ids: %s", remove_agent_user_ids)
for agent_user_id in remove_agent_user_ids:
await self.async_disconnect_agent_user(agent_user_id)
self._on_deinitialize.append( self._on_deinitialize.append(
self._prefs.async_listen_updates(self._async_prefs_updated) self._prefs.async_listen_updates(self._async_prefs_updated)
) )
@ -339,7 +328,7 @@ class CloudGoogleConfig(AbstractConfig):
@property @property
def has_registered_user_agent(self) -> bool: def has_registered_user_agent(self) -> bool:
"""Return if we have a Agent User Id registered.""" """Return if we have a Agent User Id registered."""
return len(self._store.agent_user_ids) > 0 return len(self.async_get_agent_users()) > 0
def get_agent_user_id(self, context: Any) -> str: def get_agent_user_id(self, context: Any) -> str:
"""Get agent user ID making request.""" """Get agent user ID making request."""
@ -380,6 +369,30 @@ class CloudGoogleConfig(AbstractConfig):
resp = await cloud_api.async_google_actions_request_sync(self._cloud) resp = await cloud_api.async_google_actions_request_sync(self._cloud)
return resp.status return resp.status
async def async_connect_agent_user(self, agent_user_id: str) -> None:
"""Add a synced and known agent_user_id.
Called before sending a sync response to Google.
"""
await self._prefs.async_update(google_connected=True)
async def async_disconnect_agent_user(self, agent_user_id: str) -> None:
"""Turn off report state and disable further state reporting.
Called when:
- The user disconnects their account from Google.
- When the cloud configuration is initialized
- When sync entities fails with 404
"""
await self._prefs.async_update(google_connected=False)
@callback
def async_get_agent_users(self) -> tuple:
"""Return known agent users."""
if not self._prefs.google_connected or not self._cloud.username:
return ()
return (self._cloud.username,)
async def _async_prefs_updated(self, prefs: CloudPreferences) -> None: async def _async_prefs_updated(self, prefs: CloudPreferences) -> None:
"""Handle updated preferences.""" """Handle updated preferences."""
_LOGGER.debug("_async_prefs_updated") _LOGGER.debug("_async_prefs_updated")

View File

@ -8,6 +8,9 @@ import uuid
from homeassistant.auth.const import GROUP_ID_ADMIN from homeassistant.auth.const import GROUP_ID_ADMIN
from homeassistant.auth.models import User from homeassistant.auth.models import User
from homeassistant.components import webhook from homeassistant.components import webhook
from homeassistant.components.google_assistant.http import (
async_get_users as async_get_google_assistant_users,
)
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.helpers.typing import UNDEFINED, UndefinedType from homeassistant.helpers.typing import UNDEFINED, UndefinedType
@ -28,6 +31,7 @@ from .const import (
PREF_ENABLE_ALEXA, PREF_ENABLE_ALEXA,
PREF_ENABLE_GOOGLE, PREF_ENABLE_GOOGLE,
PREF_ENABLE_REMOTE, PREF_ENABLE_REMOTE,
PREF_GOOGLE_CONNECTED,
PREF_GOOGLE_DEFAULT_EXPOSE, PREF_GOOGLE_DEFAULT_EXPOSE,
PREF_GOOGLE_ENTITY_CONFIGS, PREF_GOOGLE_ENTITY_CONFIGS,
PREF_GOOGLE_LOCAL_WEBHOOK_ID, PREF_GOOGLE_LOCAL_WEBHOOK_ID,
@ -42,7 +46,7 @@ from .const import (
STORAGE_KEY = DOMAIN STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1 STORAGE_VERSION = 1
STORAGE_VERSION_MINOR = 2 STORAGE_VERSION_MINOR = 3
ALEXA_SETTINGS_VERSION = 3 ALEXA_SETTINGS_VERSION = 3
GOOGLE_SETTINGS_VERSION = 3 GOOGLE_SETTINGS_VERSION = 3
@ -55,10 +59,27 @@ class CloudPreferencesStore(Store):
self, old_major_version: int, old_minor_version: int, old_data: dict[str, Any] self, old_major_version: int, old_minor_version: int, old_data: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Migrate to the new version.""" """Migrate to the new version."""
async def google_connected() -> bool:
"""Return True if our user is preset in the google_assistant store."""
# If we don't have a user, we can't be connected to Google
if not (cur_username := old_data.get(PREF_USERNAME)):
return False
# If our user is in the Google store, we're connected
return cur_username in await async_get_google_assistant_users(self.hass)
if old_major_version == 1: if old_major_version == 1:
if old_minor_version < 2: if old_minor_version < 2:
old_data.setdefault(PREF_ALEXA_SETTINGS_VERSION, 1) old_data.setdefault(PREF_ALEXA_SETTINGS_VERSION, 1)
old_data.setdefault(PREF_GOOGLE_SETTINGS_VERSION, 1) old_data.setdefault(PREF_GOOGLE_SETTINGS_VERSION, 1)
if old_minor_version < 3:
# Import settings from the google_assistant store which was previously
# shared between the cloud integration and manually configured Google
# assistant.
# In HA Core 2024.9, remove the import and also remove the Google
# assistant store if it's not been migrated by manual Google assistant
old_data.setdefault(PREF_GOOGLE_CONNECTED, await google_connected())
return old_data return old_data
@ -131,6 +152,7 @@ class CloudPreferences:
remote_domain: str | None | UndefinedType = UNDEFINED, remote_domain: str | None | UndefinedType = UNDEFINED,
alexa_settings_version: int | UndefinedType = UNDEFINED, alexa_settings_version: int | UndefinedType = UNDEFINED,
google_settings_version: int | UndefinedType = UNDEFINED, google_settings_version: int | UndefinedType = UNDEFINED,
google_connected: bool | UndefinedType = UNDEFINED,
) -> None: ) -> None:
"""Update user preferences.""" """Update user preferences."""
prefs = {**self._prefs} prefs = {**self._prefs}
@ -148,6 +170,7 @@ class CloudPreferences:
(PREF_GOOGLE_SETTINGS_VERSION, google_settings_version), (PREF_GOOGLE_SETTINGS_VERSION, google_settings_version),
(PREF_TTS_DEFAULT_VOICE, tts_default_voice), (PREF_TTS_DEFAULT_VOICE, tts_default_voice),
(PREF_REMOTE_DOMAIN, remote_domain), (PREF_REMOTE_DOMAIN, remote_domain),
(PREF_GOOGLE_CONNECTED, google_connected),
): ):
if value is not UNDEFINED: if value is not UNDEFINED:
prefs[key] = value prefs[key] = value
@ -241,6 +264,12 @@ class CloudPreferences:
google_enabled: bool = self._prefs[PREF_ENABLE_GOOGLE] google_enabled: bool = self._prefs[PREF_ENABLE_GOOGLE]
return google_enabled return google_enabled
@property
def google_connected(self) -> bool:
"""Return if Google is connected."""
google_connected: bool = self._prefs[PREF_GOOGLE_CONNECTED]
return google_connected
@property @property
def google_report_state(self) -> bool: def google_report_state(self) -> bool:
"""Return if Google report state is enabled.""" """Return if Google report state is enabled."""
@ -338,6 +367,7 @@ class CloudPreferences:
PREF_ENABLE_ALEXA: True, PREF_ENABLE_ALEXA: True,
PREF_ENABLE_GOOGLE: True, PREF_ENABLE_GOOGLE: True,
PREF_ENABLE_REMOTE: False, PREF_ENABLE_REMOTE: False,
PREF_GOOGLE_CONNECTED: False,
PREF_GOOGLE_DEFAULT_EXPOSE: DEFAULT_EXPOSED_DOMAINS, PREF_GOOGLE_DEFAULT_EXPOSE: DEFAULT_EXPOSED_DOMAINS,
PREF_GOOGLE_ENTITY_CONFIGS: {}, PREF_GOOGLE_ENTITY_CONFIGS: {},
PREF_GOOGLE_SETTINGS_VERSION: GOOGLE_SETTINGS_VERSION, PREF_GOOGLE_SETTINGS_VERSION: GOOGLE_SETTINGS_VERSION,

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio import gather from asyncio import gather
from collections.abc import Callable, Mapping from collections.abc import Callable, Collection, Mapping
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import lru_cache from functools import lru_cache
from http import HTTPStatus from http import HTTPStatus
@ -33,7 +33,6 @@ from homeassistant.helpers import (
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.network import get_url from homeassistant.helpers.network import get_url
from homeassistant.helpers.redact import partial_redact from homeassistant.helpers.redact import partial_redact
from homeassistant.helpers.storage import Store
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from . import trait from . import trait
@ -46,8 +45,6 @@ from .const import (
ERR_FUNCTION_NOT_SUPPORTED, ERR_FUNCTION_NOT_SUPPORTED,
NOT_EXPOSE_LOCAL, NOT_EXPOSE_LOCAL,
SOURCE_LOCAL, SOURCE_LOCAL,
STORE_AGENT_USER_IDS,
STORE_GOOGLE_LOCAL_WEBHOOK_ID,
) )
from .data_redaction import async_redact_request_msg, async_redact_response_msg from .data_redaction import async_redact_request_msg, async_redact_response_msg
from .error import SmartHomeError from .error import SmartHomeError
@ -94,7 +91,6 @@ def _get_registry_entries(
class AbstractConfig(ABC): class AbstractConfig(ABC):
"""Hold the configuration for Google Assistant.""" """Hold the configuration for Google Assistant."""
_store: GoogleConfigStore
_unsub_report_state: Callable[[], None] | None = None _unsub_report_state: Callable[[], None] | None = None
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
@ -109,9 +105,6 @@ class AbstractConfig(ABC):
async def async_initialize(self) -> None: async def async_initialize(self) -> None:
"""Perform async initialization of config.""" """Perform async initialization of config."""
self._store = GoogleConfigStore(self.hass)
await self._store.async_initialize()
if not self.enabled: if not self.enabled:
return return
@ -203,7 +196,7 @@ class AbstractConfig(ABC):
"""Send a state report to Google for all previously synced users.""" """Send a state report to Google for all previously synced users."""
jobs = [ jobs = [
self.async_report_state(message, agent_user_id) self.async_report_state(message, agent_user_id)
for agent_user_id in self._store.agent_user_ids for agent_user_id in self.async_get_agent_users()
] ]
await gather(*jobs) await gather(*jobs)
@ -235,13 +228,13 @@ class AbstractConfig(ABC):
async def async_sync_entities_all(self) -> int: async def async_sync_entities_all(self) -> int:
"""Sync all entities to Google for all registered agents.""" """Sync all entities to Google for all registered agents."""
if not self._store.agent_user_ids: if not self.async_get_agent_users():
return 204 return 204
res = await gather( res = await gather(
*( *(
self.async_sync_entities(agent_user_id) self.async_sync_entities(agent_user_id)
for agent_user_id in self._store.agent_user_ids for agent_user_id in self.async_get_agent_users()
) )
) )
return max(res, default=204) return max(res, default=204)
@ -262,13 +255,13 @@ class AbstractConfig(ABC):
self, event_id: str, payload: dict[str, Any] self, event_id: str, payload: dict[str, Any]
) -> HTTPStatus: ) -> HTTPStatus:
"""Sync notification to Google for all registered agents.""" """Sync notification to Google for all registered agents."""
if not self._store.agent_user_ids: if not self.async_get_agent_users():
return HTTPStatus.NO_CONTENT return HTTPStatus.NO_CONTENT
res = await gather( res = await gather(
*( *(
self.async_sync_notification(agent_user_id, event_id, payload) self.async_sync_notification(agent_user_id, event_id, payload)
for agent_user_id in self._store.agent_user_ids for agent_user_id in self.async_get_agent_users()
) )
) )
return max(res, default=HTTPStatus.NO_CONTENT) return max(res, default=HTTPStatus.NO_CONTENT)
@ -291,7 +284,7 @@ class AbstractConfig(ABC):
@callback @callback
def async_schedule_google_sync_all(self) -> None: def async_schedule_google_sync_all(self) -> None:
"""Schedule a sync for all registered agents.""" """Schedule a sync for all registered agents."""
for agent_user_id in self._store.agent_user_ids: for agent_user_id in self.async_get_agent_users():
self.async_schedule_google_sync(agent_user_id) self.async_schedule_google_sync(agent_user_id)
async def _async_request_sync_devices(self, agent_user_id: str) -> int: async def _async_request_sync_devices(self, agent_user_id: str) -> int:
@ -301,13 +294,14 @@ class AbstractConfig(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def async_connect_agent_user(self, agent_user_id: str): async def async_connect_agent_user(self, agent_user_id: str):
"""Add a synced and known agent_user_id. """Add a synced and known agent_user_id.
Called before sending a sync response to Google. Called before sending a sync response to Google.
""" """
self._store.add_agent_user_id(agent_user_id)
@abstractmethod
async def async_disconnect_agent_user(self, agent_user_id: str): async def async_disconnect_agent_user(self, agent_user_id: str):
"""Turn off report state and disable further state reporting. """Turn off report state and disable further state reporting.
@ -316,7 +310,11 @@ class AbstractConfig(ABC):
- When the cloud configuration is initialized - When the cloud configuration is initialized
- When sync entities fails with 404 - When sync entities fails with 404
""" """
self._store.pop_agent_user_id(agent_user_id)
@callback
@abstractmethod
def async_get_agent_users(self) -> Collection[str]:
"""Return known agent users."""
@callback @callback
def async_enable_local_sdk(self) -> None: def async_enable_local_sdk(self) -> None:
@ -330,7 +328,7 @@ class AbstractConfig(ABC):
self._local_sdk_active = False self._local_sdk_active = False
return return
for user_agent_id in self._store.agent_user_ids: for user_agent_id in self.async_get_agent_users():
if (webhook_id := self.get_local_webhook_id(user_agent_id)) is None: if (webhook_id := self.get_local_webhook_id(user_agent_id)) is None:
setup_successful = False setup_successful = False
break break
@ -375,7 +373,7 @@ class AbstractConfig(ABC):
if not self._local_sdk_active: if not self._local_sdk_active:
return return
for agent_user_id in self._store.agent_user_ids: for agent_user_id in self.async_get_agent_users():
webhook_id = self.get_local_webhook_id(agent_user_id) webhook_id = self.get_local_webhook_id(agent_user_id)
_LOGGER.debug( _LOGGER.debug(
"Unregister webhook handler %s for agent user id %s", "Unregister webhook handler %s for agent user id %s",
@ -454,65 +452,6 @@ class AbstractConfig(ABC):
return json_response(result) return json_response(result)
class GoogleConfigStore:
"""A configuration store for google assistant."""
_STORAGE_VERSION = 1
_STORAGE_KEY = DOMAIN
def __init__(self, hass):
"""Initialize a configuration store."""
self._hass = hass
self._store = Store(hass, self._STORAGE_VERSION, self._STORAGE_KEY)
self._data = None
async def async_initialize(self):
"""Finish initializing the ConfigStore."""
should_save_data = False
if (data := await self._store.async_load()) is None:
# if the store is not found create an empty one
# Note that the first request is always a cloud request,
# and that will store the correct agent user id to be used for local requests
data = {
STORE_AGENT_USER_IDS: {},
}
should_save_data = True
for agent_user_id, agent_user_data in data[STORE_AGENT_USER_IDS].items():
if STORE_GOOGLE_LOCAL_WEBHOOK_ID not in agent_user_data:
data[STORE_AGENT_USER_IDS][agent_user_id] = {
**agent_user_data,
STORE_GOOGLE_LOCAL_WEBHOOK_ID: webhook.async_generate_id(),
}
should_save_data = True
if should_save_data:
await self._store.async_save(data)
self._data = data
@property
def agent_user_ids(self):
"""Return a list of connected agent user_ids."""
return self._data[STORE_AGENT_USER_IDS]
@callback
def add_agent_user_id(self, agent_user_id):
"""Add an agent user id to store."""
if agent_user_id not in self._data[STORE_AGENT_USER_IDS]:
self._data[STORE_AGENT_USER_IDS][agent_user_id] = {
STORE_GOOGLE_LOCAL_WEBHOOK_ID: webhook.async_generate_id(),
}
self._store.async_delay_save(lambda: self._data, 1.0)
@callback
def pop_agent_user_id(self, agent_user_id):
"""Remove agent user id from store."""
if agent_user_id in self._data[STORE_AGENT_USER_IDS]:
self._data[STORE_AGENT_USER_IDS].pop(agent_user_id, None)
self._store.async_delay_save(lambda: self._data, 1.0)
class RequestData: class RequestData:
"""Hold data associated with a particular request.""" """Hold data associated with a particular request."""

View File

@ -11,14 +11,15 @@ from aiohttp import ClientError, ClientResponseError
from aiohttp.web import Request, Response from aiohttp.web import Request, Response
import jwt import jwt
from homeassistant.components import webhook
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
from homeassistant.core import HomeAssistant, callback
# Typing imports from homeassistant.exceptions import HomeAssistantError
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.util import dt as dt_util from homeassistant.helpers.storage import STORAGE_DIR, Store
from homeassistant.util import dt as dt_util, json as json_util
from .const import ( from .const import (
CONF_CLIENT_EMAIL, CONF_CLIENT_EMAIL,
@ -30,12 +31,14 @@ from .const import (
CONF_REPORT_STATE, CONF_REPORT_STATE,
CONF_SECURE_DEVICES_PIN, CONF_SECURE_DEVICES_PIN,
CONF_SERVICE_ACCOUNT, CONF_SERVICE_ACCOUNT,
DOMAIN,
GOOGLE_ASSISTANT_API_ENDPOINT, GOOGLE_ASSISTANT_API_ENDPOINT,
HOMEGRAPH_SCOPE, HOMEGRAPH_SCOPE,
HOMEGRAPH_TOKEN_URL, HOMEGRAPH_TOKEN_URL,
REPORT_STATE_BASE_URL, REPORT_STATE_BASE_URL,
REQUEST_SYNC_BASE_URL, REQUEST_SYNC_BASE_URL,
SOURCE_CLOUD, SOURCE_CLOUD,
STORE_AGENT_USER_IDS,
STORE_GOOGLE_LOCAL_WEBHOOK_ID, STORE_GOOGLE_LOCAL_WEBHOOK_ID,
) )
from .helpers import AbstractConfig from .helpers import AbstractConfig
@ -78,6 +81,8 @@ async def _get_homegraph_token(
class GoogleConfig(AbstractConfig): class GoogleConfig(AbstractConfig):
"""Config for manual setup of Google.""" """Config for manual setup of Google."""
_store: GoogleConfigStore
def __init__(self, hass, config): def __init__(self, hass, config):
"""Initialize the config.""" """Initialize the config."""
super().__init__(hass) super().__init__(hass)
@ -87,6 +92,10 @@ class GoogleConfig(AbstractConfig):
async def async_initialize(self): async def async_initialize(self):
"""Perform async initialization of config.""" """Perform async initialization of config."""
# We need to initialize the store before calling super
self._store = GoogleConfigStore(self.hass)
await self._store.async_initialize()
await super().async_initialize() await super().async_initialize()
self.async_enable_local_sdk() self.async_enable_local_sdk()
@ -191,6 +200,28 @@ class GoogleConfig(AbstractConfig):
_LOGGER.error("No configuration for request_sync available") _LOGGER.error("No configuration for request_sync available")
return HTTPStatus.INTERNAL_SERVER_ERROR return HTTPStatus.INTERNAL_SERVER_ERROR
async def async_connect_agent_user(self, agent_user_id: str):
"""Add a synced and known agent_user_id.
Called before sending a sync response to Google.
"""
self._store.add_agent_user_id(agent_user_id)
async def async_disconnect_agent_user(self, agent_user_id: str):
"""Turn off report state and disable further state reporting.
Called when:
- The user disconnects their account from Google.
- When the cloud configuration is initialized
- When sync entities fails with 404
"""
self._store.pop_agent_user_id(agent_user_id)
@callback
def async_get_agent_users(self):
"""Return known agent users."""
return self._store.agent_user_ids
async def _async_update_token(self, force=False): async def _async_update_token(self, force=False):
if CONF_SERVICE_ACCOUNT not in self._config: if CONF_SERVICE_ACCOUNT not in self._config:
_LOGGER.error("Trying to get homegraph api token without service account") _LOGGER.error("Trying to get homegraph api token without service account")
@ -258,6 +289,71 @@ class GoogleConfig(AbstractConfig):
return await self.async_call_homegraph_api(REPORT_STATE_BASE_URL, data) return await self.async_call_homegraph_api(REPORT_STATE_BASE_URL, data)
class GoogleConfigStore:
"""A configuration store for google assistant."""
_STORAGE_VERSION = 1
_STORAGE_VERSION_MINOR = 2
_STORAGE_KEY = DOMAIN
_data: dict[str, Any]
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a configuration store."""
self._hass = hass
self._store: Store[dict[str, Any]] = Store(
hass,
self._STORAGE_VERSION,
self._STORAGE_KEY,
minor_version=self._STORAGE_VERSION_MINOR,
)
async def async_initialize(self) -> None:
"""Finish initializing the ConfigStore."""
should_save_data = False
if (data := await self._store.async_load()) is None:
# if the store is not found create an empty one
# Note that the first request is always a cloud request,
# and that will store the correct agent user id to be used for local requests
data = {
STORE_AGENT_USER_IDS: {},
}
should_save_data = True
for agent_user_id, agent_user_data in data[STORE_AGENT_USER_IDS].items():
if STORE_GOOGLE_LOCAL_WEBHOOK_ID not in agent_user_data:
data[STORE_AGENT_USER_IDS][agent_user_id] = {
**agent_user_data,
STORE_GOOGLE_LOCAL_WEBHOOK_ID: webhook.async_generate_id(),
}
should_save_data = True
if should_save_data:
await self._store.async_save(data)
self._data = data
@property
def agent_user_ids(self) -> dict[str, Any]:
"""Return a list of connected agent user_ids."""
return self._data[STORE_AGENT_USER_IDS]
@callback
def add_agent_user_id(self, agent_user_id: str) -> None:
"""Add an agent user id to store."""
if agent_user_id not in self._data[STORE_AGENT_USER_IDS]:
self._data[STORE_AGENT_USER_IDS][agent_user_id] = {
STORE_GOOGLE_LOCAL_WEBHOOK_ID: webhook.async_generate_id(),
}
self._store.async_delay_save(lambda: self._data, 1.0)
@callback
def pop_agent_user_id(self, agent_user_id: str) -> None:
"""Remove agent user id from store."""
if agent_user_id in self._data[STORE_AGENT_USER_IDS]:
self._data[STORE_AGENT_USER_IDS].pop(agent_user_id, None)
self._store.async_delay_save(lambda: self._data, 1.0)
class GoogleAssistantView(HomeAssistantView): class GoogleAssistantView(HomeAssistantView):
"""Handle Google Assistant requests.""" """Handle Google Assistant requests."""
@ -280,3 +376,26 @@ class GoogleAssistantView(HomeAssistantView):
SOURCE_CLOUD, SOURCE_CLOUD,
) )
return self.json(result) return self.json(result)
async def async_get_users(hass: HomeAssistant) -> list[str]:
"""Return stored users.
This is called by the cloud integration to import from the previously shared store.
"""
# pylint: disable-next=protected-access
path = hass.config.path(STORAGE_DIR, GoogleConfigStore._STORAGE_KEY)
try:
store_data = await hass.async_add_executor_job(json_util.load_json, path)
except HomeAssistantError:
return []
if (
not isinstance(store_data, dict)
or not (data := store_data.get("data"))
or not isinstance(data, dict)
or not (agent_user_ids := data.get("agent_user_ids"))
or not isinstance(agent_user_ids, dict)
):
return []
return list(agent_user_ids)

View File

@ -42,7 +42,7 @@ def mock_conf(hass, cloud_prefs):
GACTIONS_SCHEMA({}), GACTIONS_SCHEMA({}),
"mock-user-id", "mock-user-id",
cloud_prefs, cloud_prefs,
Mock(claims={"cognito:username": "abcdefghjkl"}), Mock(username="abcdefghjkl"),
) )
@ -104,9 +104,11 @@ async def test_sync_entities(mock_conf, hass: HomeAssistant, cloud_prefs) -> Non
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
await mock_conf.async_initialize() await mock_conf.async_initialize()
assert len(mock_conf.async_get_agent_users()) == 0
await mock_conf.async_connect_agent_user("mock-user-id") await mock_conf.async_connect_agent_user("mock-user-id")
assert len(mock_conf._store.agent_user_ids) == 1 assert len(mock_conf.async_get_agent_users()) == 1
with patch( with patch(
"hass_nabucasa.cloud_api.async_google_actions_request_sync", "hass_nabucasa.cloud_api.async_google_actions_request_sync",
@ -115,7 +117,7 @@ async def test_sync_entities(mock_conf, hass: HomeAssistant, cloud_prefs) -> Non
assert ( assert (
await mock_conf.async_sync_entities("mock-user-id") == HTTPStatus.NOT_FOUND await mock_conf.async_sync_entities("mock-user-id") == HTTPStatus.NOT_FOUND
) )
assert len(mock_conf._store.agent_user_ids) == 0 assert len(mock_conf.async_get_agent_users()) == 0
assert len(mock_request_sync.mock_calls) == 1 assert len(mock_request_sync.mock_calls) == 1
@ -144,7 +146,7 @@ async def test_google_update_expose_trigger_sync(
GACTIONS_SCHEMA({}), GACTIONS_SCHEMA({}),
"mock-user-id", "mock-user-id",
cloud_prefs, cloud_prefs,
Mock(claims={"cognito:username": "abcdefghjkl"}), Mock(username="abcdefghjkl"),
) )
await config.async_initialize() await config.async_initialize()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
@ -271,6 +273,7 @@ async def test_google_device_registry_sync(
await config.async_initialize() await config.async_initialize()
await hass.async_block_till_done() await hass.async_block_till_done()
await config.async_connect_agent_user("mock-user-id") await config.async_connect_agent_user("mock-user-id")
await hass.async_block_till_done()
with patch.object(config, "async_schedule_google_sync_all") as mock_sync: with patch.object(config, "async_schedule_google_sync_all") as mock_sync:
# Device registry updated with non-relevant changes # Device registry updated with non-relevant changes
@ -326,7 +329,6 @@ async def test_sync_google_when_started(
) )
with patch.object(config, "async_sync_entities_all") as mock_sync: with patch.object(config, "async_sync_entities_all") as mock_sync:
await config.async_initialize() await config.async_initialize()
await config.async_connect_agent_user("mock-user-id")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_sync.mock_calls) == 1 assert len(mock_sync.mock_calls) == 1
@ -341,7 +343,6 @@ async def test_sync_google_on_home_assistant_start(
hass.set_state(CoreState.starting) hass.set_state(CoreState.starting)
with patch.object(config, "async_sync_entities_all") as mock_sync: with patch.object(config, "async_sync_entities_all") as mock_sync:
await config.async_initialize() await config.async_initialize()
await config.async_connect_agent_user("mock-user-id")
assert len(mock_sync.mock_calls) == 0 assert len(mock_sync.mock_calls) == 0
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)

View File

@ -2,6 +2,8 @@
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest
from homeassistant.auth.const import GROUP_ID_ADMIN from homeassistant.auth.const import GROUP_ID_ADMIN
from homeassistant.components.cloud.prefs import STORAGE_KEY, CloudPreferences from homeassistant.components.cloud.prefs import STORAGE_KEY, CloudPreferences
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -98,3 +100,25 @@ async def test_setup_remove_cloud_user(
assert cloud_user2 assert cloud_user2
assert cloud_user2.groups[0].id == GROUP_ID_ADMIN assert cloud_user2.groups[0].id == GROUP_ID_ADMIN
assert cloud_user2.id != cloud_user.id assert cloud_user2.id != cloud_user.id
@pytest.mark.parametrize(
("google_assistant_users", "google_connected"),
[([], False), (["cloud-user"], True), (["other-user"], False)],
)
async def test_import_google_assistant_settings(
hass: HomeAssistant,
hass_storage: dict[str, Any],
google_assistant_users: list[str],
google_connected: bool,
) -> None:
"""Test importing from the google assistant store."""
hass_storage[STORAGE_KEY] = {"version": 1, "data": {"username": "cloud-user"}}
with patch(
"homeassistant.components.cloud.prefs.async_get_google_assistant_users"
) as mock_get_users:
mock_get_users.return_value = google_assistant_users
prefs = CloudPreferences(hass)
await prefs.async_initialize()
assert prefs.google_connected == google_connected

View File

@ -6,7 +6,7 @@ from homeassistant.components.google_assistant import helpers, http
def mock_google_config_store(agent_user_ids=None): def mock_google_config_store(agent_user_ids=None):
"""Fake a storage for google assistant.""" """Fake a storage for google assistant."""
store = MagicMock(spec=helpers.GoogleConfigStore) store = MagicMock(spec=http.GoogleConfigStore)
if agent_user_ids is not None: if agent_user_ids is not None:
store.agent_user_ids = agent_user_ids store.agent_user_ids = agent_user_ids
else: else:

View File

@ -1,7 +1,6 @@
"""Test Google Assistant helpers.""" """Test Google Assistant helpers."""
from datetime import timedelta from datetime import timedelta
from http import HTTPStatus from http import HTTPStatus
from typing import Any
from unittest.mock import Mock, call, patch from unittest.mock import Mock, call, patch
import pytest import pytest
@ -23,12 +22,7 @@ from homeassistant.util import dt as dt_util
from . import MockConfig from . import MockConfig
from tests.common import ( from tests.common import MockConfigEntry, async_capture_events, async_mock_service
MockConfigEntry,
async_capture_events,
async_fire_time_changed,
async_mock_service,
)
from tests.typing import ClientSessionGenerator from tests.typing import ClientSessionGenerator
@ -274,72 +268,6 @@ async def test_config_local_sdk_if_ssl_enabled(
assert await resp.read() == b"" assert await resp.read() == b""
async def test_agent_user_id_storage(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test a disconnect message."""
hass_storage["google_assistant"] = {
"version": 1,
"minor_version": 1,
"key": "google_assistant",
"data": {
"agent_user_ids": {
"agent_1": {
"local_webhook_id": "test_webhook",
}
},
},
}
store = helpers.GoogleConfigStore(hass)
await store.async_initialize()
assert hass_storage["google_assistant"] == {
"version": 1,
"minor_version": 1,
"key": "google_assistant",
"data": {
"agent_user_ids": {
"agent_1": {
"local_webhook_id": "test_webhook",
}
},
},
}
async def _check_after_delay(data):
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=2))
await hass.async_block_till_done()
assert (
list(hass_storage["google_assistant"]["data"]["agent_user_ids"].keys())
== data
)
store.add_agent_user_id("agent_2")
await _check_after_delay(["agent_1", "agent_2"])
store.pop_agent_user_id("agent_1")
await _check_after_delay(["agent_2"])
hass_storage["google_assistant"] = {
"version": 1,
"minor_version": 1,
"key": "google_assistant",
"data": {
"agent_user_ids": {"agent_1": {}},
},
}
store = helpers.GoogleConfigStore(hass)
await store.async_initialize()
assert (
STORE_GOOGLE_LOCAL_WEBHOOK_ID
in hass_storage["google_assistant"]["data"]["agent_user_ids"]["agent_1"]
)
async def test_agent_user_id_connect() -> None: async def test_agent_user_id_connect() -> None:
"""Test the connection and disconnection of users.""" """Test the connection and disconnection of users."""
config = MockConfig() config = MockConfig()

View File

@ -1,10 +1,13 @@
"""Test Google http services.""" """Test Google http services."""
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from http import HTTPStatus from http import HTTPStatus
import json
import os
from typing import Any from typing import Any
from unittest.mock import ANY, patch from unittest.mock import ANY, patch
from uuid import uuid4 from uuid import uuid4
import py
import pytest import pytest
from homeassistant.components.google_assistant import GOOGLE_ASSISTANT_SCHEMA from homeassistant.components.google_assistant import GOOGLE_ASSISTANT_SCHEMA
@ -18,14 +21,22 @@ from homeassistant.components.google_assistant.const import (
) )
from homeassistant.components.google_assistant.http import ( from homeassistant.components.google_assistant.http import (
GoogleConfig, GoogleConfig,
GoogleConfigStore,
_get_homegraph_jwt, _get_homegraph_jwt,
_get_homegraph_token, _get_homegraph_token,
async_get_users,
) )
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
from homeassistant.core import HomeAssistant, State from homeassistant.core import HomeAssistant, State
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from tests.common import async_capture_events, async_mock_service from tests.common import (
async_capture_events,
async_fire_time_changed,
async_mock_service,
async_test_home_assistant,
)
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator from tests.typing import ClientSessionGenerator
@ -469,3 +480,177 @@ async def test_async_enable_local_sdk(
"Cannot process request for webhook **REDACTED** as no linked agent user is found:" "Cannot process request for webhook **REDACTED** as no linked agent user is found:"
in caplog.text in caplog.text
) )
async def test_agent_user_id_storage(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test a disconnect message."""
hass_storage["google_assistant"] = {
"version": 1,
"minor_version": 1,
"key": "google_assistant",
"data": {
"agent_user_ids": {
"agent_1": {
"local_webhook_id": "test_webhook",
}
},
},
}
store = GoogleConfigStore(hass)
await store.async_initialize()
assert hass_storage["google_assistant"] == {
"version": 1,
"minor_version": 2,
"key": "google_assistant",
"data": {
"agent_user_ids": {
"agent_1": {
"local_webhook_id": "test_webhook",
}
},
},
}
async def _check_after_delay(data):
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=2))
await hass.async_block_till_done()
assert (
list(hass_storage["google_assistant"]["data"]["agent_user_ids"].keys())
== data
)
store.add_agent_user_id("agent_2")
await _check_after_delay(["agent_1", "agent_2"])
store.pop_agent_user_id("agent_1")
await _check_after_delay(["agent_2"])
hass_storage["google_assistant"] = {
"version": 1,
"minor_version": 2,
"key": "google_assistant",
"data": {
"agent_user_ids": {"agent_1": {}},
},
}
store = GoogleConfigStore(hass)
await store.async_initialize()
assert (
STORE_GOOGLE_LOCAL_WEBHOOK_ID
in hass_storage["google_assistant"]["data"]["agent_user_ids"]["agent_1"]
)
async def test_async_get_users_no_store(hass: HomeAssistant) -> None:
"""Test async_get_users when there is no store."""
assert await async_get_users(hass) == []
async def test_async_get_users_from_store(tmpdir: py.path.local) -> None:
"""Test async_get_users from a store.
This test ensures we can load from data saved by GoogleConfigStore.
"""
async with async_test_home_assistant() as hass:
hass.config.config_dir = await hass.async_add_executor_job(
tmpdir.mkdir, "temp_storage"
)
store = GoogleConfigStore(hass)
await store.async_initialize()
store.add_agent_user_id("agent_1")
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=2))
await hass.async_block_till_done()
assert await async_get_users(hass) == ["agent_1"]
VALID_STORE_DATA = json.dumps(
{
"version": 1,
"minor_version": 2,
"key": "google_assistant",
"data": {
"agent_user_ids": {"agent_1": {}},
},
}
)
NO_DATA = json.dumps(
{
"version": 1,
"minor_version": 2,
"key": "google_assistant",
}
)
DATA_NOT_DICT = json.dumps(
{
"version": 1,
"minor_version": 2,
"key": "google_assistant",
"data": "hello",
}
)
NO_AGENT_USER_IDS = json.dumps(
{
"version": 1,
"minor_version": 2,
"key": "google_assistant",
"data": {},
}
)
AGENT_USER_IDS_NOT_DICT = json.dumps(
{
"version": 1,
"minor_version": 2,
"key": "google_assistant",
"data": {
"agent_user_ids": "hello",
},
}
)
@pytest.mark.parametrize(
("store_data", "expected_users"),
[
(VALID_STORE_DATA, ["agent_1"]),
("", []),
("not_a_dict", []),
(NO_DATA, []),
(DATA_NOT_DICT, []),
(NO_AGENT_USER_IDS, []),
(AGENT_USER_IDS_NOT_DICT, []),
],
)
async def test_async_get_users(
tmpdir: py.path.local, store_data: str, expected_users: list[str]
) -> None:
"""Test async_get_users from stored JSON data."""
async with async_test_home_assistant() as hass:
hass.config.config_dir = await hass.async_add_executor_job(
tmpdir.mkdir, "temp_storage"
)
path = hass.config.config_dir / ".storage" / GoogleConfigStore._STORAGE_KEY
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w") as f:
f.write(store_data)
assert await async_get_users(hass) == expected_users
await hass.async_stop()