diff --git a/homeassistant/components/nest/__init__.py b/homeassistant/components/nest/__init__.py index 37901d060e1..0933f10e6ce 100644 --- a/homeassistant/components/nest/__init__.py +++ b/homeassistant/components/nest/__init__.py @@ -127,10 +127,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: if CONF_PROJECT_ID not in config[DOMAIN]: return await async_setup_legacy(hass, config) - if CONF_SUBSCRIBER_ID not in config[DOMAIN]: - _LOGGER.error("Configuration option 'subscriber_id' required") - return False - # For setup of ConfigEntry below hass.data[DOMAIN][DATA_NEST_CONFIG] = config[DOMAIN] project_id = config[DOMAIN][CONF_PROJECT_ID] @@ -195,9 +191,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return await async_setup_legacy_entry(hass, entry) subscriber = await api.new_subscriber(hass, entry) + if not subscriber: + return False + callback = SignalUpdateCallback(hass) subscriber.set_update_callback(callback.async_handle_event) - try: await subscriber.start_async() except AuthException as err: @@ -245,3 +243,24 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.data[DOMAIN].pop(DATA_NEST_UNAVAILABLE, None) return unload_ok + + +async def async_remove_entry(hass: HomeAssistant, entry: ConfigEntry) -> None: + """Handle removal of pubsub subscriptions created during config flow.""" + if DATA_SDM not in entry.data or CONF_SUBSCRIBER_ID not in entry.data: + return + + subscriber = await api.new_subscriber(hass, entry) + if not subscriber: + return + _LOGGER.debug("Deleting subscriber '%s'", subscriber.subscriber_id) + try: + await subscriber.delete_subscription() + except GoogleNestException as err: + _LOGGER.warning( + "Unable to delete subscription '%s'; Will be automatically cleaned up by cloud console: %s", + subscriber.subscriber_id, + err, + ) + finally: + subscriber.stop_async() diff --git a/homeassistant/components/nest/api.py b/homeassistant/components/nest/api.py index 17b473dbeaa..3934b0b3cf1 100644 --- a/homeassistant/components/nest/api.py +++ b/homeassistant/components/nest/api.py @@ -1,6 +1,9 @@ """API for Google Nest Device Access bound to Home Assistant OAuth.""" +from __future__ import annotations + import datetime +import logging from typing import cast from aiohttp import ClientSession @@ -23,7 +26,7 @@ from .const import ( SDM_SCOPES, ) -# See https://developers.google.com/nest/device-access/registration +_LOGGER = logging.getLogger(__name__) class AsyncConfigEntryAuth(AbstractAuth): @@ -71,14 +74,31 @@ class AsyncConfigEntryAuth(AbstractAuth): async def new_subscriber( hass: HomeAssistant, entry: ConfigEntry -) -> GoogleNestSubscriber: +) -> GoogleNestSubscriber | None: """Create a GoogleNestSubscriber.""" implementation = ( await config_entry_oauth2_flow.async_get_config_entry_implementation( hass, entry ) ) + config = hass.data[DOMAIN][DATA_NEST_CONFIG] + if not ( + subscriber_id := entry.data.get( + CONF_SUBSCRIBER_ID, config.get(CONF_SUBSCRIBER_ID) + ) + ): + _LOGGER.error("Configuration option 'subscriber_id' required") + return None + return await new_subscriber_with_impl(hass, entry, subscriber_id, implementation) + +async def new_subscriber_with_impl( + hass: HomeAssistant, + entry: ConfigEntry, + subscriber_id: str, + implementation: config_entry_oauth2_flow.AbstractOAuth2Implementation, +) -> GoogleNestSubscriber: + """Create a GoogleNestSubscriber, used during ConfigFlow.""" config = hass.data[DOMAIN][DATA_NEST_CONFIG] session = config_entry_oauth2_flow.OAuth2Session(hass, entry, implementation) auth = AsyncConfigEntryAuth( @@ -87,6 +107,4 @@ async def new_subscriber( config[CONF_CLIENT_ID], config[CONF_CLIENT_SECRET], ) - return GoogleNestSubscriber( - auth, config[CONF_PROJECT_ID], config[CONF_SUBSCRIBER_ID] - ) + return GoogleNestSubscriber(auth, config[CONF_PROJECT_ID], subscriber_id) diff --git a/homeassistant/components/nest/config_flow.py b/homeassistant/components/nest/config_flow.py index ec567aaa14e..31192b1a2b2 100644 --- a/homeassistant/components/nest/config_flow.py +++ b/homeassistant/components/nest/config_flow.py @@ -6,7 +6,22 @@ This configuration flow supports the following: - Legacy Nest API auth flow with where user enters an auth code manually NestFlowHandler is an implementation of AbstractOAuth2FlowHandler with -some overrides to support installed app and old APIs auth flow. +some overrides to support installed app and old APIs auth flow, reauth, +and other custom steps inserted in the middle of the flow. + +The notable config flow steps are: +- user: To dispatch between API versions +- auth: Inserted to add a hook for the installed app flow to accept a token +- async_oauth_create_entry: Overridden to handle when OAuth is complete. This + does not actually create the entry, but holds on to the OAuth token data + for later +- pubsub: Configure the pubsub subscription. Note that subscriptions created + by the config flow are deleted when removed. +- finish: Handles creating a new configuration entry or updating the existing + configuration entry for reauth. + +The SDM API config flow supports a hybrid of configuration.yaml (used as defaults) +and config flow. """ from __future__ import annotations @@ -17,20 +32,46 @@ import os from typing import Any import async_timeout +from google_nest_sdm.exceptions import ( + AuthException, + ConfigurationException, + GoogleNestException, +) import voluptuous as vol +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.data_entry_flow import FlowResult from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_entry_oauth2_flow +from homeassistant.util import get_random_string from homeassistant.util.json import load_json -from .const import DATA_SDM, DOMAIN, OOB_REDIRECT_URI, SDM_SCOPES +from . import api +from .const import ( + CONF_CLOUD_PROJECT_ID, + CONF_PROJECT_ID, + CONF_SUBSCRIBER_ID, + DATA_NEST_CONFIG, + DATA_SDM, + DOMAIN, + OOB_REDIRECT_URI, + SDM_SCOPES, +) DATA_FLOW_IMPL = "nest_flow_implementation" +SUBSCRIPTION_FORMAT = "projects/{cloud_project_id}/subscriptions/home-assistant-{rnd}" +SUBSCRIPTION_RAND_LENGTH = 10 +CLOUD_CONSOLE_URL = "https://console.cloud.google.com/home/dashboard" _LOGGER = logging.getLogger(__name__) +def _generate_subscription_id(cloud_project_id: str) -> str: + """Create a new subscription id.""" + rnd = get_random_string(SUBSCRIPTION_RAND_LENGTH) + return SUBSCRIPTION_FORMAT.format(cloud_project_id=cloud_project_id, rnd=rnd) + + @callback def register_flow_implementation( hass: HomeAssistant, @@ -80,8 +121,10 @@ class NestFlowHandler( def __init__(self) -> None: """Initialize NestFlowHandler.""" super().__init__() - # When invoked for reauth, allows updating an existing config entry - self._reauth = False + # Allows updating an existing config entry + self._reauth_data: dict[str, Any] | None = None + # ConfigEntry data for SDM API + self._data: dict[str, Any] = {DATA_SDM: {}} @classmethod def register_sdm_api(cls, hass: HomeAssistant) -> None: @@ -110,35 +153,24 @@ class NestFlowHandler( } async def async_oauth_create_entry(self, data: dict[str, Any]) -> FlowResult: - """Create an entry for the SDM flow.""" + """Complete OAuth setup and finish pubsub or finish.""" assert self.is_sdm_api(), "Step only supported for SDM API" - data[DATA_SDM] = {} - await self.async_set_unique_id(DOMAIN) - # Update existing config entry when in the reauth flow. This - # integration only supports one config entry so remove any prior entries - # added before the "single_instance_allowed" check was added - existing_entries = self._async_current_entries() - if existing_entries: - updated = False - for entry in existing_entries: - if updated: - await self.hass.config_entries.async_remove(entry.entry_id) - continue - updated = True - self.hass.config_entries.async_update_entry( - entry, data=data, unique_id=DOMAIN - ) - await self.hass.config_entries.async_reload(entry.entry_id) - return self.async_abort(reason="reauth_successful") - - return await super().async_oauth_create_entry(data) + self._data.update(data) + if not self._configure_pubsub(): + _LOGGER.debug("Skipping Pub/Sub configuration") + return await self.async_step_finish() + return await self.async_step_pubsub() async def async_step_reauth( self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Perform reauth upon an API authentication error.""" assert self.is_sdm_api(), "Step only supported for SDM API" - self._reauth = True # Forces update of existing config entry + if user_input is None: + _LOGGER.error("Reauth invoked with empty config entry data") + return self.async_abort(reason="missing_configuration") + self._reauth_data = user_input + self._data.update(user_input) return await self.async_step_reauth_confirm() async def async_step_reauth_confirm( @@ -167,7 +199,7 @@ class NestFlowHandler( """Handle a flow initialized by the user.""" if self.is_sdm_api(): # Reauth will update an existing entry - if self._async_current_entries() and not self._reauth: + if self._async_current_entries() and not self._reauth_data: return self.async_abort(reason="single_instance_allowed") return await super().async_step_user(user_input) return await self.async_step_init(user_input) @@ -199,6 +231,106 @@ class NestFlowHandler( ) return await super().async_step_auth(user_input) + def _configure_pubsub(self) -> bool: + """Return True if the config flow should configure Pub/Sub.""" + if self._reauth_data is not None and CONF_SUBSCRIBER_ID in self._reauth_data: + # Existing entry needs to be reconfigured + return True + if CONF_SUBSCRIBER_ID in self.hass.data[DOMAIN][DATA_NEST_CONFIG]: + # Hard coded configuration.yaml skips pubsub in config flow + return False + # No existing subscription configured, so create in config flow + return True + + async def async_step_pubsub( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Configure and create Pub/Sub subscriber.""" + # Populate data from the previous config entry during reauth, then + # overwrite with the user entered values. + data = {} + if self._reauth_data: + data.update(self._reauth_data) + if user_input: + data.update(user_input) + cloud_project_id = data.get(CONF_CLOUD_PROJECT_ID, "") + + errors = {} + config = self.hass.data[DOMAIN][DATA_NEST_CONFIG] + if cloud_project_id == config[CONF_PROJECT_ID]: + _LOGGER.error( + "Wrong Project ID. Device Access Project ID used, but expected Cloud Project ID" + ) + errors[CONF_CLOUD_PROJECT_ID] = "wrong_project_id" + + if user_input is not None and not errors: + # Create the subscriber id and/or verify it already exists. Note that + # the existing id is used, and create call below is idempotent + subscriber_id = data.get(CONF_SUBSCRIBER_ID, "") + if not subscriber_id: + subscriber_id = _generate_subscription_id(cloud_project_id) + _LOGGER.debug("Creating subscriber id '%s'", subscriber_id) + # Create a placeholder ConfigEntry to use since with the auth we've already created. + entry = ConfigEntry( + version=1, domain=DOMAIN, title="", data=self._data, source="" + ) + subscriber = await api.new_subscriber_with_impl( + self.hass, entry, subscriber_id, self.flow_impl + ) + try: + await subscriber.create_subscription() + except AuthException as err: + _LOGGER.error("Subscriber authentication error: %s", err) + return self.async_abort(reason="invalid_access_token") + except ConfigurationException as err: + _LOGGER.error("Configuration error creating subscription: %s", err) + errors[CONF_CLOUD_PROJECT_ID] = "bad_project_id" + except GoogleNestException as err: + _LOGGER.error("Error creating subscription: %s", err) + errors[CONF_CLOUD_PROJECT_ID] = "subscriber_error" + + if not errors: + self._data.update( + { + CONF_SUBSCRIBER_ID: subscriber_id, + CONF_CLOUD_PROJECT_ID: cloud_project_id, + } + ) + return await self.async_step_finish() + + return self.async_show_form( + step_id="pubsub", + data_schema=vol.Schema( + { + vol.Required(CONF_CLOUD_PROJECT_ID, default=cloud_project_id): str, + } + ), + description_placeholders={"url": CLOUD_CONSOLE_URL}, + errors=errors, + ) + + async def async_step_finish(self, data: dict[str, Any] | None = None) -> FlowResult: + """Create an entry for the SDM flow.""" + assert self.is_sdm_api(), "Step only supported for SDM API" + await self.async_set_unique_id(DOMAIN) + # Update existing config entry when in the reauth flow. This + # integration only supports one config entry so remove any prior entries + # added before the "single_instance_allowed" check was added + existing_entries = self._async_current_entries() + if existing_entries: + updated = False + for entry in existing_entries: + if updated: + await self.hass.config_entries.async_remove(entry.entry_id) + continue + updated = True + self.hass.config_entries.async_update_entry( + entry, data=self._data, unique_id=DOMAIN + ) + await self.hass.config_entries.async_reload(entry.entry_id) + return self.async_abort(reason="reauth_successful") + return await super().async_oauth_create_entry(self._data) + async def async_step_init( self, user_input: dict[str, Any] | None = None ) -> FlowResult: diff --git a/homeassistant/components/nest/const.py b/homeassistant/components/nest/const.py index 6fcd74299ba..a92a48bfd6c 100644 --- a/homeassistant/components/nest/const.py +++ b/homeassistant/components/nest/const.py @@ -7,6 +7,7 @@ DATA_NEST_CONFIG = "nest_config" CONF_PROJECT_ID = "project_id" CONF_SUBSCRIBER_ID = "subscriber_id" +CONF_CLOUD_PROJECT_ID = "cloud_project_id" SIGNAL_NEST_UPDATE = "nest_update" diff --git a/homeassistant/components/nest/strings.json b/homeassistant/components/nest/strings.json index 84cfc3435a6..4dd3e5419b0 100644 --- a/homeassistant/components/nest/strings.json +++ b/homeassistant/components/nest/strings.json @@ -11,6 +11,13 @@ "code": "[%key:common::config_flow::data::access_token%]" } }, + "pubsub": { + "title": "Configure Google Cloud", + "description": "Visit the [Cloud Console]({url}) to find your Google Cloud Project ID.", + "data": { + "cloud_project_id": "Google Cloud Project ID" + } + }, "reauth_confirm": { "title": "[%key:common::config_flow::title::reauth%]", "description": "The Nest integration needs to re-authenticate your account" @@ -34,7 +41,10 @@ "timeout": "Timeout validating code", "invalid_pin": "Invalid [%key:common::config_flow::data::pin%]", "unknown": "[%key:common::config_flow::error::unknown%]", - "internal_error": "Internal error validating code" + "internal_error": "Internal error validating code", + "bad_project_id": "Please enter a valid Cloud Project ID (check Cloud Console)", + "wrong_project_id": "Please enter a valid Cloud Project ID (found Device Access Project ID)", + "subscriber_error": "Unknown subscriber error, see logs" }, "abort": { "single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]", @@ -42,7 +52,8 @@ "authorize_url_timeout": "[%key:common::config_flow::abort::oauth2_authorize_url_timeout%]", "unknown_authorize_url_generation": "[%key:common::config_flow::abort::unknown_authorize_url_generation%]", "no_url_available": "[%key:common::config_flow::abort::oauth2_no_url_available%]", - "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]", + "invalid_access_token": "[%key:common::config_flow::error::invalid_access_token]" }, "create_entry": { "default": "[%key:common::config_flow::create_entry::authenticated%]" diff --git a/homeassistant/components/nest/translations/en.json b/homeassistant/components/nest/translations/en.json index be35cf1b54e..e03de2b9bea 100644 --- a/homeassistant/components/nest/translations/en.json +++ b/homeassistant/components/nest/translations/en.json @@ -56,4 +56,4 @@ "doorbell_chime": "Doorbell pressed" } } -} \ No newline at end of file +} diff --git a/tests/components/nest/common.py b/tests/components/nest/common.py index c9572c528bb..eb44b19d540 100644 --- a/tests/components/nest/common.py +++ b/tests/components/nest/common.py @@ -32,7 +32,7 @@ FAKE_TOKEN = "some-token" FAKE_REFRESH_TOKEN = "some-refresh-token" -def create_config_entry(hass, token_expiration_time=None): +def create_config_entry(hass, token_expiration_time=None) -> MockConfigEntry: """Create a ConfigEntry and add it to Home Assistant.""" if token_expiration_time is None: token_expiration_time = time.time() + 86400 @@ -47,7 +47,9 @@ def create_config_entry(hass, token_expiration_time=None): "expires_at": token_expiration_time, }, } - MockConfigEntry(domain=DOMAIN, data=config_entry_data).add_to_hass(hass) + config_entry = MockConfigEntry(domain=DOMAIN, data=config_entry_data) + config_entry.add_to_hass(hass) + return config_entry class FakeDeviceManager(DeviceManager): @@ -80,6 +82,14 @@ class FakeSubscriber(GoogleNestSubscriber): """Capture the callback set by Home Assistant.""" self._callback = callback + async def create_subscription(self): + """Create the subscription.""" + return + + async def delete_subscription(self): + """Delete the subscription.""" + return + async def start_async(self): """Return the fake device manager.""" return self._device_manager @@ -99,9 +109,12 @@ class FakeSubscriber(GoogleNestSubscriber): await self._callback(event_message) -async def async_setup_sdm_platform(hass, platform, devices={}, structures={}): +async def async_setup_sdm_platform( + hass, platform, devices={}, structures={}, with_config=True +): """Set up the platform and prerequisites.""" - create_config_entry(hass) + if with_config: + create_config_entry(hass) device_manager = FakeDeviceManager(devices=devices, structures=structures) subscriber = FakeSubscriber(device_manager) with patch( diff --git a/tests/components/nest/test_config_flow_sdm.py b/tests/components/nest/test_config_flow_sdm.py index 75ce8bcf939..5d6987f94f7 100644 --- a/tests/components/nest/test_config_flow_sdm.py +++ b/tests/components/nest/test_config_flow_sdm.py @@ -1,7 +1,13 @@ """Test the Google Nest Device Access config flow.""" +import copy from unittest.mock import patch +from google_nest_sdm.exceptions import ( + AuthException, + ConfigurationException, + GoogleNestException, +) import pytest from homeassistant import config_entries, setup @@ -11,12 +17,13 @@ from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET from homeassistant.core import HomeAssistant from homeassistant.helpers import config_entry_oauth2_flow -from .common import MockConfigEntry +from .common import FakeDeviceManager, FakeSubscriber, MockConfigEntry CLIENT_ID = "1234" CLIENT_SECRET = "5678" PROJECT_ID = "project-id-4321" -SUBSCRIBER_ID = "projects/example/subscriptions/subscriber-id-9876" +SUBSCRIBER_ID = "projects/cloud-id-9876/subscriptions/subscriber-id-9876" +CLOUD_PROJECT_ID = "cloud-id-9876" CONFIG = { DOMAIN: { @@ -35,7 +42,19 @@ WEB_REDIRECT_URL = "https://example.com/auth/external/callback" APP_REDIRECT_URL = "urn:ietf:wg:oauth:2.0:oob" -def get_config_entry(hass: HomeAssistant) -> ConfigEntry: +@pytest.fixture +def device_manager() -> FakeDeviceManager: + """Create FakeDeviceManager.""" + return FakeDeviceManager(devices={}, structures={}) + + +@pytest.fixture +def subscriber(device_manager: FakeDeviceManager) -> FakeSubscriber: + """Create FakeSubscriber.""" + return FakeSubscriber(device_manager) + + +def get_config_entry(hass): """Return a single config entry.""" entries = hass.config_entries.async_entries(DOMAIN) assert len(entries) == 1 @@ -71,7 +90,7 @@ class OAuthFixture: result["flow_id"], {"implementation": auth_domain} ) - async def async_oauth_web_flow(self, result: dict) -> ConfigEntry: + async def async_oauth_web_flow(self, result: dict) -> None: """Invoke the oauth flow for Web Auth with fake responses.""" state = self.create_state(result, WEB_REDIRECT_URL) assert result["url"] == self.authorize_url(state, WEB_REDIRECT_URL) @@ -82,9 +101,9 @@ class OAuthFixture: assert resp.status == 200 assert resp.headers["content-type"] == "text/html; charset=utf-8" - return await self.async_finish_flow(result) + await self.async_mock_refresh(result) - async def async_oauth_app_flow(self, result: dict) -> ConfigEntry: + async def async_oauth_app_flow(self, result: dict) -> None: """Invoke the oauth flow for Installed Auth with fake responses.""" # Render form with a link to get an auth token assert result["type"] == "form" @@ -96,7 +115,25 @@ class OAuthFixture: state, APP_REDIRECT_URL ) # Simulate user entering auth token in form - return await self.async_finish_flow(result, {"code": "abcd"}) + await self.async_mock_refresh(result, {"code": "abcd"}) + + async def async_reauth(self, old_data: dict) -> dict: + """Initiate a reuath flow.""" + result = await self.hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_REAUTH}, data=old_data + ) + assert result["type"] == "form" + assert result["step_id"] == "reauth_confirm" + + # Advance through the reauth flow + flows = self.hass.config_entries.flow.async_progress() + assert len(flows) == 1 + assert flows[0]["step_id"] == "reauth_confirm" + + # Advance to the oauth flow + return await self.hass.config_entries.flow.async_configure( + flows[0]["flow_id"], {} + ) def create_state(self, result: dict, redirect_url: str) -> str: """Create state object based on redirect url.""" @@ -119,7 +156,7 @@ class OAuthFixture: "&access_type=offline&prompt=consent" ) - async def async_finish_flow(self, result, user_input: dict = None) -> ConfigEntry: + async def async_mock_refresh(self, result, user_input: dict = None) -> None: """Finish the OAuth flow exchanging auth token for refresh token.""" self.aioclient_mock.post( OAUTH2_TOKEN, @@ -131,6 +168,10 @@ class OAuthFixture: }, ) + async def async_finish_setup( + self, result: dict, user_input: dict = None + ) -> ConfigEntry: + """Finish the OAuth flow exchanging auth token for refresh token.""" with patch( "homeassistant.components.nest.async_setup_entry", return_value=True ) as mock_setup: @@ -139,7 +180,25 @@ class OAuthFixture: ) assert len(mock_setup.mock_calls) == 1 await self.hass.async_block_till_done() + return self.get_config_entry() + async def async_configure(self, result: dict, user_input: dict) -> dict: + """Advance to the next step in the config flow.""" + return await self.hass.config_entries.flow.async_configure( + result["flow_id"], user_input + ) + + async def async_pubsub_flow(self, result: dict, cloud_project_id="") -> ConfigEntry: + """Verify the pubsub creation step.""" + # Render form with a link to get an auth token + assert result["type"] == "form" + assert result["step_id"] == "pubsub" + assert "description_placeholders" in result + assert "url" in result["description_placeholders"] + assert result["data_schema"]({}) == {"cloud_project_id": cloud_project_id} + + def get_config_entry(self) -> ConfigEntry: + """Get the config entry.""" return get_config_entry(self.hass) @@ -149,6 +208,13 @@ async def oauth(hass, hass_client_no_auth, aioclient_mock, current_request_with_ return OAuthFixture(hass, hass_client_no_auth, aioclient_mock) +async def async_setup_configflow(hass): + """Set up component so the pubsub subscriber is managed by config flow.""" + config = copy.deepcopy(CONFIG) + del config[DOMAIN]["subscriber_id"] # Create in config flow instead + return await setup.async_setup_component(hass, DOMAIN, config) + + async def test_web_full_flow(hass, oauth): """Check full flow.""" assert await setup.async_setup_component(hass, DOMAIN, CONFIG) @@ -159,7 +225,8 @@ async def test_web_full_flow(hass, oauth): result = await oauth.async_pick_flow(result, WEB_AUTH_DOMAIN) - entry = await oauth.async_oauth_web_flow(result) + await oauth.async_oauth_web_flow(result) + entry = await oauth.async_finish_setup(result) assert entry.title == "OAuth for Web" assert "token" in entry.data entry.data["token"].pop("expires_at") @@ -170,6 +237,8 @@ async def test_web_full_flow(hass, oauth): "type": "Bearer", "expires_in": 60, } + # Subscriber from configuration.yaml + assert "subscriber_id" not in entry.data async def test_web_reauth(hass, oauth): @@ -194,19 +263,10 @@ async def test_web_reauth(hass, oauth): "access_token": "some-revoked-token", } - await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_REAUTH}, data=old_entry.data - ) + result = await oauth.async_reauth(old_entry.data) - # Advance through the reauth flow - flows = hass.config_entries.flow.async_progress() - assert len(flows) == 1 - assert flows[0]["step_id"] == "reauth_confirm" - - # Run the oauth flow - result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {}) - - entry = await oauth.async_oauth_web_flow(result) + await oauth.async_oauth_web_flow(result) + entry = await oauth.async_finish_setup(result) # Verify existing tokens are replaced entry.data["token"].pop("expires_at") assert entry.unique_id == DOMAIN @@ -217,6 +277,7 @@ async def test_web_reauth(hass, oauth): "expires_in": 60, } assert entry.data["auth_implementation"] == WEB_AUTH_DOMAIN + assert "subscriber_id" not in entry.data # not updated async def test_single_config_entry(hass): @@ -254,17 +315,12 @@ async def test_unexpected_existing_config_entries(hass, oauth): assert len(entries) == 2 # Invoke the reauth flow - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_REAUTH}, data=old_entry.data - ) - assert result["type"] == "form" - assert result["step_id"] == "reauth_confirm" + result = await oauth.async_reauth(old_entry.data) - flows = hass.config_entries.flow.async_progress() - - result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {}) await oauth.async_oauth_web_flow(result) + await oauth.async_finish_setup(result) + # Only a single entry now exists, and the other was cleaned up entries = hass.config_entries.async_entries(DOMAIN) assert len(entries) == 1 @@ -277,9 +333,22 @@ async def test_unexpected_existing_config_entries(hass, oauth): "type": "Bearer", "expires_in": 60, } + assert "subscriber_id" not in entry.data # not updated -async def test_app_full_flow(hass, oauth, aioclient_mock): +async def test_reauth_missing_config_entry(hass): + """Test the reauth flow invoked missing existing data.""" + assert await setup.async_setup_component(hass, DOMAIN, CONFIG) + + # Invoke the reauth flow with no existing data + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_REAUTH}, data=None + ) + assert result["type"] == "abort" + assert result["reason"] == "missing_configuration" + + +async def test_app_full_flow(hass, oauth): """Check full flow.""" assert await setup.async_setup_component(hass, DOMAIN, CONFIG) @@ -288,7 +357,8 @@ async def test_app_full_flow(hass, oauth, aioclient_mock): ) result = await oauth.async_pick_flow(result, APP_AUTH_DOMAIN) - entry = await oauth.async_oauth_app_flow(result) + await oauth.async_oauth_app_flow(result) + entry = await oauth.async_finish_setup(result, {"code": "1234"}) assert entry.title == "OAuth for Apps" assert "token" in entry.data entry.data["token"].pop("expires_at") @@ -299,6 +369,8 @@ async def test_app_full_flow(hass, oauth, aioclient_mock): "type": "Bearer", "expires_in": 60, } + # Subscriber from configuration.yaml + assert "subscriber_id" not in entry.data async def test_app_reauth(hass, oauth): @@ -318,26 +390,11 @@ async def test_app_reauth(hass, oauth): }, ) - entry = get_config_entry(hass) - assert entry.data["token"] == { - "access_token": "some-revoked-token", - } - - await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_REAUTH}, data=old_entry.data - ) - - # Advance through the reauth flow - flows = hass.config_entries.flow.async_progress() - assert len(flows) == 1 - assert flows[0]["step_id"] == "reauth_confirm" - - # Run the oauth flow - result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {}) + result = await oauth.async_reauth(old_entry.data) await oauth.async_oauth_app_flow(result) # Verify existing tokens are replaced - entry = get_config_entry(hass) + entry = await oauth.async_finish_setup(result, {"code": "1234"}) entry.data["token"].pop("expires_at") assert entry.unique_id == DOMAIN assert entry.data["token"] == { @@ -347,3 +404,186 @@ async def test_app_reauth(hass, oauth): "expires_in": 60, } assert entry.data["auth_implementation"] == APP_AUTH_DOMAIN + assert "subscriber_id" not in entry.data # not updated + + +async def test_pubsub_subscription(hass, oauth, subscriber): + """Check flow that creates a pub/sub subscription.""" + assert await async_setup_configflow(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + result = await oauth.async_pick_flow(result, APP_AUTH_DOMAIN) + await oauth.async_oauth_app_flow(result) + + with patch( + "homeassistant.components.nest.api.GoogleNestSubscriber", + return_value=subscriber, + ): + result = await oauth.async_configure(result, {"code": "1234"}) + await oauth.async_pubsub_flow(result) + entry = await oauth.async_finish_setup( + result, {"cloud_project_id": CLOUD_PROJECT_ID} + ) + await hass.async_block_till_done() + + assert entry.title == "OAuth for Apps" + assert "token" in entry.data + entry.data["token"].pop("expires_at") + assert entry.unique_id == DOMAIN + assert entry.data["token"] == { + "refresh_token": "mock-refresh-token", + "access_token": "mock-access-token", + "type": "Bearer", + "expires_in": 60, + } + assert "subscriber_id" in entry.data + assert entry.data["cloud_project_id"] == CLOUD_PROJECT_ID + + +async def test_pubsub_subscription_auth_failure(hass, oauth): + """Check flow that creates a pub/sub subscription.""" + assert await async_setup_configflow(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + result = await oauth.async_pick_flow(result, APP_AUTH_DOMAIN) + await oauth.async_oauth_app_flow(result) + result = await oauth.async_configure(result, {"code": "1234"}) + with patch( + "homeassistant.components.nest.api.GoogleNestSubscriber.create_subscription", + side_effect=AuthException(), + ): + await oauth.async_pubsub_flow(result) + result = await oauth.async_configure( + result, {"cloud_project_id": CLOUD_PROJECT_ID} + ) + await hass.async_block_till_done() + + assert result["type"] == "abort" + assert result["reason"] == "invalid_access_token" + + +async def test_pubsub_subscription_failure(hass, oauth): + """Check flow that creates a pub/sub subscription.""" + assert await async_setup_configflow(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + result = await oauth.async_pick_flow(result, APP_AUTH_DOMAIN) + await oauth.async_oauth_app_flow(result) + result = await oauth.async_configure(result, {"code": "1234"}) + await oauth.async_pubsub_flow(result) + with patch( + "homeassistant.components.nest.api.GoogleNestSubscriber.create_subscription", + side_effect=GoogleNestException(), + ): + result = await oauth.async_configure( + result, {"cloud_project_id": CLOUD_PROJECT_ID} + ) + await hass.async_block_till_done() + + assert result["type"] == "form" + assert "errors" in result + assert "cloud_project_id" in result["errors"] + assert result["errors"]["cloud_project_id"] == "subscriber_error" + + +async def test_pubsub_subscription_configuration_failure(hass, oauth): + """Check flow that creates a pub/sub subscription.""" + assert await async_setup_configflow(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + result = await oauth.async_pick_flow(result, APP_AUTH_DOMAIN) + await oauth.async_oauth_app_flow(result) + result = await oauth.async_configure(result, {"code": "1234"}) + await oauth.async_pubsub_flow(result) + with patch( + "homeassistant.components.nest.api.GoogleNestSubscriber.create_subscription", + side_effect=ConfigurationException(), + ): + result = await oauth.async_configure( + result, {"cloud_project_id": CLOUD_PROJECT_ID} + ) + await hass.async_block_till_done() + + assert result["type"] == "form" + assert "errors" in result + assert "cloud_project_id" in result["errors"] + assert result["errors"]["cloud_project_id"] == "bad_project_id" + + +async def test_pubsub_with_wrong_project_id(hass, oauth): + """Test a possible common misconfiguration mixing up project ids.""" + assert await async_setup_configflow(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + result = await oauth.async_pick_flow(result, APP_AUTH_DOMAIN) + await oauth.async_oauth_app_flow(result) + result = await oauth.async_configure(result, {"code": "1234"}) + await oauth.async_pubsub_flow(result) + result = await oauth.async_configure( + result, {"cloud_project_id": PROJECT_ID} # SDM project id + ) + await hass.async_block_till_done() + + assert result["type"] == "form" + assert "errors" in result + assert "cloud_project_id" in result["errors"] + assert result["errors"]["cloud_project_id"] == "wrong_project_id" + + +async def test_pubsub_subscriber_config_entry_reauth(hass, oauth, subscriber): + """Test the pubsub subscriber id is preserved during reauth.""" + assert await async_setup_configflow(hass) + + old_entry = create_config_entry( + hass, + { + "auth_implementation": APP_AUTH_DOMAIN, + "subscription_id": SUBSCRIBER_ID, + "cloud_project_id": CLOUD_PROJECT_ID, + "token": { + "access_token": "some-revoked-token", + }, + "sdm": {}, + }, + ) + result = await oauth.async_reauth(old_entry.data) + await oauth.async_oauth_app_flow(result) + result = await oauth.async_configure(result, {"code": "1234"}) + + # Configure Pub/Sub + await oauth.async_pubsub_flow(result, cloud_project_id=CLOUD_PROJECT_ID) + + # Verify existing tokens are replaced + with patch( + "homeassistant.components.nest.api.GoogleNestSubscriber", + return_value=subscriber, + ): + entry = await oauth.async_finish_setup( + result, {"cloud_project_id": "other-cloud-project-id"} + ) + await hass.async_block_till_done() + + entry = oauth.get_config_entry() + entry.data["token"].pop("expires_at") + assert entry.unique_id == DOMAIN + assert entry.data["token"] == { + "refresh_token": "mock-refresh-token", + "access_token": "mock-access-token", + "type": "Bearer", + "expires_in": 60, + } + assert entry.data["auth_implementation"] == APP_AUTH_DOMAIN + assert ( + "projects/other-cloud-project-id/subscriptions" in entry.data["subscriber_id"] + ) + assert entry.data["cloud_project_id"] == "other-cloud-project-id" diff --git a/tests/components/nest/test_init_sdm.py b/tests/components/nest/test_init_sdm.py index 59d1cbd0d69..fbfd6305487 100644 --- a/tests/components/nest/test_init_sdm.py +++ b/tests/components/nest/test_init_sdm.py @@ -9,7 +9,11 @@ import copy import logging from unittest.mock import patch -from google_nest_sdm.exceptions import AuthException, GoogleNestException +from google_nest_sdm.exceptions import ( + AuthException, + ConfigurationException, + GoogleNestException, +) from homeassistant.components.nest import DOMAIN from homeassistant.config_entries import ConfigEntryState @@ -31,9 +35,10 @@ async def test_setup_success(hass, caplog): assert entries[0].state is ConfigEntryState.LOADED -async def async_setup_sdm(hass, config=CONFIG): +async def async_setup_sdm(hass, config=CONFIG, with_config=True): """Prepare test setup.""" - create_config_entry(hass) + if with_config: + create_config_entry(hass) with patch( "homeassistant.helpers.config_entry_oauth2_flow.async_get_config_entry_implementation" ): @@ -111,17 +116,53 @@ async def test_subscriber_auth_failure(hass, caplog): async def test_setup_missing_subscriber_id(hass, caplog): - """Test successful setup.""" + """Test missing susbcriber id from config and config entry.""" config = copy.deepcopy(CONFIG) del config[DOMAIN]["subscriber_id"] - with caplog.at_level(logging.ERROR, logger="homeassistant.components.nest"): + + with caplog.at_level(logging.WARNING, logger="homeassistant.components.nest"): result = await async_setup_sdm(hass, config) - assert not result + assert result assert "Configuration option" in caplog.text entries = hass.config_entries.async_entries(DOMAIN) assert len(entries) == 1 - assert entries[0].state is ConfigEntryState.NOT_LOADED + assert entries[0].state is ConfigEntryState.SETUP_ERROR + + +async def test_setup_subscriber_id_config_entry(hass, caplog): + """Test successful setup with subscriber id in ConfigEntry.""" + config = copy.deepcopy(CONFIG) + subscriber_id = config[DOMAIN]["subscriber_id"] + del config[DOMAIN]["subscriber_id"] + + config_entry = create_config_entry(hass) + data = {**config_entry.data} + data["subscriber_id"] = subscriber_id + hass.config_entries.async_update_entry(config_entry, data=data) + + with caplog.at_level(logging.ERROR, logger="homeassistant.components.nest"): + await async_setup_sdm_platform(hass, PLATFORM, with_config=False) + assert not caplog.records + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + assert entries[0].state is ConfigEntryState.LOADED + + +async def test_subscriber_configuration_failure(hass, caplog): + """Test configuration error.""" + with patch( + "homeassistant.components.nest.api.GoogleNestSubscriber.start_async", + side_effect=ConfigurationException(), + ), caplog.at_level(logging.ERROR, logger="homeassistant.components.nest"): + result = await async_setup_sdm(hass, CONFIG) + assert result + assert "Configuration error: " in caplog.text + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + assert entries[0].state is ConfigEntryState.SETUP_ERROR async def test_empty_config(hass, caplog): @@ -133,3 +174,87 @@ async def test_empty_config(hass, caplog): entries = hass.config_entries.async_entries(DOMAIN) assert len(entries) == 0 + + +async def test_unload_entry(hass, caplog): + """Test successful unload of a ConfigEntry.""" + await async_setup_sdm_platform(hass, PLATFORM) + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + entry = entries[0] + assert entry.state is ConfigEntryState.LOADED + + assert await hass.config_entries.async_unload(entry.entry_id) + assert entry.state == ConfigEntryState.NOT_LOADED + + +async def test_remove_entry(hass, caplog): + """Test successful unload of a ConfigEntry.""" + await async_setup_sdm_platform(hass, PLATFORM) + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + entry = entries[0] + assert entry.state is ConfigEntryState.LOADED + + assert await hass.config_entries.async_remove(entry.entry_id) + + entries = hass.config_entries.async_entries(DOMAIN) + assert not entries + + +async def test_remove_entry_deletes_subscriber(hass, caplog): + """Test ConfigEntry unload deletes a subscription.""" + config = copy.deepcopy(CONFIG) + subscriber_id = config[DOMAIN]["subscriber_id"] + del config[DOMAIN]["subscriber_id"] + + config_entry = create_config_entry(hass) + data = {**config_entry.data} + data["subscriber_id"] = subscriber_id + hass.config_entries.async_update_entry(config_entry, data=data) + + await async_setup_sdm_platform(hass, PLATFORM, with_config=False) + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + entry = entries[0] + assert entry.state is ConfigEntryState.LOADED + + with patch( + "homeassistant.components.nest.api.GoogleNestSubscriber.delete_subscription", + ) as delete: + assert await hass.config_entries.async_remove(entry.entry_id) + assert delete.called + + entries = hass.config_entries.async_entries(DOMAIN) + assert not entries + + +async def test_remove_entry_delete_subscriber_failure(hass, caplog): + """Test a failure when deleting the subscription.""" + config = copy.deepcopy(CONFIG) + subscriber_id = config[DOMAIN]["subscriber_id"] + del config[DOMAIN]["subscriber_id"] + + config_entry = create_config_entry(hass) + data = {**config_entry.data} + data["subscriber_id"] = subscriber_id + hass.config_entries.async_update_entry(config_entry, data=data) + + await async_setup_sdm_platform(hass, PLATFORM, with_config=False) + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + entry = entries[0] + assert entry.state is ConfigEntryState.LOADED + + with patch( + "homeassistant.components.nest.api.GoogleNestSubscriber.delete_subscription", + side_effect=GoogleNestException(), + ): + assert await hass.config_entries.async_remove(entry.entry_id) + + entries = hass.config_entries.async_entries(DOMAIN) + assert not entries