diff --git a/homeassistant/components/nest/__init__.py b/homeassistant/components/nest/__init__.py index 97c9da5794b..151b1dac000 100644 --- a/homeassistant/components/nest/__init__.py +++ b/homeassistant/components/nest/__init__.py @@ -6,14 +6,14 @@ import logging import threading from google_nest_sdm.event import AsyncEventCallback, EventMessage -from google_nest_sdm.exceptions import GoogleNestException +from google_nest_sdm.exceptions import AuthException, GoogleNestException from google_nest_sdm.google_nest_subscriber import GoogleNestSubscriber from nest import Nest from nest.nest import APIError, AuthorizationError import voluptuous as vol from homeassistant import config_entries -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry from homeassistant.const import ( CONF_BINARY_SENSORS, CONF_CLIENT_ID, @@ -231,6 +231,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): try: await subscriber.start_async() + except AuthException as err: + _LOGGER.debug("Subscriber authentication error: %s", err) + hass.async_create_task( + hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_REAUTH}, + data=entry.data, + ) + ) + return False except GoogleNestException as err: _LOGGER.error("Subscriber error: %s", err) subscriber.stop_async() diff --git a/homeassistant/components/nest/config_flow.py b/homeassistant/components/nest/config_flow.py index 6aaa5bcc489..36b0da239a9 100644 --- a/homeassistant/components/nest/config_flow.py +++ b/homeassistant/components/nest/config_flow.py @@ -75,6 +75,12 @@ class NestFlowHandler( VERSION = 1 CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_PUSH + def __init__(self): + """Initialize NestFlowHandler.""" + super().__init__() + # When invoked for reauth, allows updating an existing config entry + self._reauth = False + @classmethod def register_sdm_api(cls, hass): """Configure the flow handler to use the SDM API.""" @@ -103,19 +109,56 @@ class NestFlowHandler( async def async_oauth_create_entry(self, data: dict) -> dict: """Create an entry for the SDM flow.""" + assert self.is_sdm_api(), "Step only supported for SDM API" data[DATA_SDM] = {} + await self.async_set_unique_id(DOMAIN) + # Update existing config entry when in the reauth flow. This + # integration only supports one config entry so remove any prior entries + # added before the "single_instance_allowed" check was added + existing_entries = self.hass.config_entries.async_entries(DOMAIN) + if existing_entries: + updated = False + for entry in existing_entries: + if updated: + await self.hass.config_entries.async_remove(entry.entry_id) + continue + updated = True + self.hass.config_entries.async_update_entry( + entry, data=data, unique_id=DOMAIN + ) + await self.hass.config_entries.async_reload(entry.entry_id) + return self.async_abort(reason="reauth_successful") + return await super().async_oauth_create_entry(data) + async def async_step_reauth(self, user_input=None): + """Perform reauth upon an API authentication error.""" + assert self.is_sdm_api(), "Step only supported for SDM API" + self._reauth = True # Forces update of existing config entry + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm(self, user_input=None): + """Confirm reauth dialog.""" + assert self.is_sdm_api(), "Step only supported for SDM API" + if user_input is None: + return self.async_show_form( + step_id="reauth_confirm", + data_schema=vol.Schema({}), + ) + return await self.async_step_user() + async def async_step_user(self, user_input=None): """Handle a flow initialized by the user.""" if self.is_sdm_api(): + # Reauth will update an existing entry + if self.hass.config_entries.async_entries(DOMAIN) and not self._reauth: + return self.async_abort(reason="single_instance_allowed") return await super().async_step_user(user_input) return await self.async_step_init(user_input) async def async_step_init(self, user_input=None): """Handle a flow start.""" - if self.is_sdm_api(): - raise UnexpectedStateError("Step only supported for legacy API") + assert not self.is_sdm_api(), "Step only supported for legacy API" flows = self.hass.data.get(DATA_FLOW_IMPL, {}) @@ -145,8 +188,7 @@ class NestFlowHandler( implementation type we expect a pin or an external component to deliver the authentication code. """ - if self.is_sdm_api(): - raise UnexpectedStateError("Step only supported for legacy API") + assert not self.is_sdm_api(), "Step only supported for legacy API" flow = self.hass.data[DATA_FLOW_IMPL][self.flow_impl] @@ -188,8 +230,7 @@ class NestFlowHandler( async def async_step_import(self, info): """Import existing auth from Nest.""" - if self.is_sdm_api(): - raise UnexpectedStateError("Step only supported for legacy API") + assert not self.is_sdm_api(), "Step only supported for legacy API" if self.hass.config_entries.async_entries(DOMAIN): return self.async_abort(reason="single_instance_allowed") diff --git a/homeassistant/components/nest/strings.json b/homeassistant/components/nest/strings.json index f945469e26f..6ce529621aa 100644 --- a/homeassistant/components/nest/strings.json +++ b/homeassistant/components/nest/strings.json @@ -4,6 +4,10 @@ "pick_implementation": { "title": "[%key:common::config_flow::title::oauth2_pick_implementation%]" }, + "reauth_confirm": { + "title": "[%key:common::config_flow::title::reauth%]", + "description": "The Nest integration needs to re-authenticate your account" + }, "init": { "title": "Authentication Provider", "description": "[%key:common::config_flow::title::oauth2_pick_implementation%]", @@ -30,7 +34,8 @@ "missing_configuration": "[%key:common::config_flow::abort::oauth2_missing_configuration%]", "authorize_url_timeout": "[%key:common::config_flow::abort::oauth2_authorize_url_timeout%]", "unknown_authorize_url_generation": "[%key:common::config_flow::abort::unknown_authorize_url_generation%]", - "no_url_available": "[%key:common::config_flow::abort::oauth2_no_url_available%]" + "no_url_available": "[%key:common::config_flow::abort::oauth2_no_url_available%]", + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" }, "create_entry": { "default": "[%key:common::config_flow::create_entry::authenticated%]" diff --git a/tests/components/nest/test_config_flow_sdm.py b/tests/components/nest/test_config_flow_sdm.py index 6573b17980e..e506f269d66 100644 --- a/tests/components/nest/test_config_flow_sdm.py +++ b/tests/components/nest/test_config_flow_sdm.py @@ -1,9 +1,14 @@ """Test the Google Nest Device Access config flow.""" + +import pytest + from homeassistant import config_entries, setup from homeassistant.components.nest.const import DOMAIN, OAUTH2_AUTHORIZE, OAUTH2_TOKEN from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET from homeassistant.helpers import config_entry_oauth2_flow +from .common import MockConfigEntry + from tests.async_mock import patch CLIENT_ID = "1234" @@ -11,64 +16,210 @@ CLIENT_SECRET = "5678" PROJECT_ID = "project-id-4321" SUBSCRIBER_ID = "subscriber-id-9876" +CONFIG = { + DOMAIN: { + "project_id": PROJECT_ID, + "subscriber_id": SUBSCRIBER_ID, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + }, + "http": {"base_url": "https://example.com"}, +} -async def test_full_flow( - hass, aiohttp_client, aioclient_mock, current_request_with_host -): - """Check full flow.""" - assert await setup.async_setup_component( - hass, - DOMAIN, - { - DOMAIN: { - "project_id": PROJECT_ID, - "subscriber_id": SUBSCRIBER_ID, - CONF_CLIENT_ID: CLIENT_ID, - CONF_CLIENT_SECRET: CLIENT_SECRET, + +def get_config_entry(hass): + """Return a single config entry.""" + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + return entries[0] + + +class OAuthFixture: + """Simulate the oauth flow used by the config flow.""" + + def __init__(self, hass, aiohttp_client, aioclient_mock): + """Initialize OAuthFixture.""" + self.hass = hass + self.aiohttp_client = aiohttp_client + self.aioclient_mock = aioclient_mock + + async def async_oauth_flow(self, result): + """Invoke the oauth flow with fake responses.""" + state = config_entry_oauth2_flow._encode_jwt( + self.hass, + { + "flow_id": result["flow_id"], + "redirect_uri": "https://example.com/auth/external/callback", }, - "http": {"base_url": "https://example.com"}, - }, - ) + ) + + oauth_authorize = OAUTH2_AUTHORIZE.format(project_id=PROJECT_ID) + assert result["type"] == "external" + assert result["url"] == ( + f"{oauth_authorize}?response_type=code&client_id={CLIENT_ID}" + "&redirect_uri=https://example.com/auth/external/callback" + f"&state={state}&scope=https://www.googleapis.com/auth/sdm.service" + "+https://www.googleapis.com/auth/pubsub" + "&access_type=offline&prompt=consent" + ) + + client = await self.aiohttp_client(self.hass.http.app) + resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") + assert resp.status == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + + self.aioclient_mock.post( + OAUTH2_TOKEN, + json={ + "refresh_token": "mock-refresh-token", + "access_token": "mock-access-token", + "type": "Bearer", + "expires_in": 60, + }, + ) + + with patch( + "homeassistant.components.nest.async_setup_entry", return_value=True + ) as mock_setup: + await self.hass.config_entries.flow.async_configure(result["flow_id"]) + assert len(mock_setup.mock_calls) == 1 + + +@pytest.fixture +async def oauth(hass, aiohttp_client, aioclient_mock, current_request_with_host): + """Create the simulated oauth flow.""" + return OAuthFixture(hass, aiohttp_client, aioclient_mock) + + +async def test_full_flow(hass, oauth): + """Check full flow.""" + assert await setup.async_setup_component(hass, DOMAIN, CONFIG) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - state = config_entry_oauth2_flow._encode_jwt( - hass, - { - "flow_id": result["flow_id"], - "redirect_uri": "https://example.com/auth/external/callback", + await oauth.async_oauth_flow(result) + + entry = get_config_entry(hass) + assert entry.title == "Configuration.yaml" + assert "token" in entry.data + entry.data["token"].pop("expires_at") + assert entry.unique_id == DOMAIN + assert entry.data["token"] == { + "refresh_token": "mock-refresh-token", + "access_token": "mock-access-token", + "type": "Bearer", + "expires_in": 60, + } + + +async def test_reauth(hass, oauth): + """Test Nest reauthentication.""" + + assert await setup.async_setup_component(hass, DOMAIN, CONFIG) + + old_entry = MockConfigEntry( + domain=DOMAIN, + data={ + "auth_implementation": DOMAIN, + "token": { + # Verify this is replaced at end of the test + "access_token": "some-revoked-token", + }, + "sdm": {}, }, + unique_id=DOMAIN, + ) + old_entry.add_to_hass(hass) + + entry = get_config_entry(hass) + assert entry.data["token"] == { + "access_token": "some-revoked-token", + } + + await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_REAUTH}, data=old_entry.data ) - oauth_authorize = OAUTH2_AUTHORIZE.format(project_id=PROJECT_ID) - assert result["url"] == ( - f"{oauth_authorize}?response_type=code&client_id={CLIENT_ID}" - "&redirect_uri=https://example.com/auth/external/callback" - f"&state={state}&scope=https://www.googleapis.com/auth/sdm.service" - "+https://www.googleapis.com/auth/pubsub" - "&access_type=offline&prompt=consent" + # Advance through the reauth flow + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + assert flows[0]["step_id"] == "reauth_confirm" + + # Run the oauth flow + result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {}) + await oauth.async_oauth_flow(result) + + # Verify existing tokens are replaced + entry = get_config_entry(hass) + entry.data["token"].pop("expires_at") + assert entry.unique_id == DOMAIN + assert entry.data["token"] == { + "refresh_token": "mock-refresh-token", + "access_token": "mock-access-token", + "type": "Bearer", + "expires_in": 60, + } + + +async def test_single_config_entry(hass): + """Test that only a single config entry is allowed.""" + old_entry = MockConfigEntry( + domain=DOMAIN, data={"auth_implementation": DOMAIN, "sdm": {}} ) + old_entry.add_to_hass(hass) - client = await aiohttp_client(hass.http.app) - resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") - assert resp.status == 200 - assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert await setup.async_setup_component(hass, DOMAIN, CONFIG) - aioclient_mock.post( - OAUTH2_TOKEN, - json={ - "refresh_token": "mock-refresh-token", - "access_token": "mock-access-token", - "type": "Bearer", - "expires_in": 60, - }, + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} ) + assert result["type"] == "abort" + assert result["reason"] == "single_instance_allowed" - with patch( - "homeassistant.components.nest.async_setup_entry", return_value=True - ) as mock_setup: - await hass.config_entries.flow.async_configure(result["flow_id"]) - assert len(hass.config_entries.async_entries(DOMAIN)) == 1 - assert len(mock_setup.mock_calls) == 1 +async def test_unexpected_existing_config_entries(hass, oauth): + """Test Nest reauthentication with multiple existing config entries.""" + # Note that this case will not happen in the future since only a single + # instance is now allowed, but this may have been allowed in the past. + # On reauth, only one entry is kept and the others are deleted. + + assert await setup.async_setup_component(hass, DOMAIN, CONFIG) + + old_entry = MockConfigEntry( + domain=DOMAIN, data={"auth_implementation": DOMAIN, "sdm": {}} + ) + old_entry.add_to_hass(hass) + + old_entry = MockConfigEntry( + domain=DOMAIN, data={"auth_implementation": DOMAIN, "sdm": {}} + ) + old_entry.add_to_hass(hass) + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 2 + + # Invoke the reauth flow + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_REAUTH}, data=old_entry.data + ) + assert result["type"] == "form" + assert result["step_id"] == "reauth_confirm" + + flows = hass.config_entries.flow.async_progress() + + result = await hass.config_entries.flow.async_configure(flows[0]["flow_id"], {}) + await oauth.async_oauth_flow(result) + + # Only a single entry now exists, and the other was cleaned up + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + entry = entries[0] + assert entry.unique_id == DOMAIN + entry.data["token"].pop("expires_at") + assert entry.data["token"] == { + "refresh_token": "mock-refresh-token", + "access_token": "mock-access-token", + "type": "Bearer", + "expires_in": 60, + }