diff --git a/homeassistant/components/weheat/__init__.py b/homeassistant/components/weheat/__init__.py index 4800046926d..d924d6ceaab 100644 --- a/homeassistant/components/weheat/__init__.py +++ b/homeassistant/components/weheat/__init__.py @@ -3,10 +3,12 @@ from __future__ import annotations from weheat.abstractions.discovery import HeatPumpDiscovery +from weheat.exceptions import UnauthorizedException from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_ACCESS_TOKEN, Platform from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.helpers.config_entry_oauth2_flow import ( OAuth2Session, async_get_config_entry_implementation, @@ -30,7 +32,12 @@ async def async_setup_entry(hass: HomeAssistant, entry: WeheatConfigEntry) -> bo entry.runtime_data = [] # fetch a list of the heat pumps the entry can access - for pump_info in await HeatPumpDiscovery.discover_active(API_URL, token): + try: + discovered_heat_pumps = await HeatPumpDiscovery.discover_active(API_URL, token) + except UnauthorizedException as error: + raise ConfigEntryAuthFailed from error + + for pump_info in discovered_heat_pumps: LOGGER.debug("Adding %s", pump_info) # for each pump, add a coordinator new_coordinator = WeheatDataUpdateCoordinator(hass, session, pump_info) diff --git a/homeassistant/components/weheat/config_flow.py b/homeassistant/components/weheat/config_flow.py index 707c2f6bc97..c1eccaf6ba7 100644 --- a/homeassistant/components/weheat/config_flow.py +++ b/homeassistant/components/weheat/config_flow.py @@ -1,10 +1,12 @@ """Config flow for Weheat.""" +from collections.abc import Mapping import logging +from typing import Any from weheat.abstractions.user import get_user_id_from_token -from homeassistant.config_entries import ConfigFlowResult +from homeassistant.config_entries import ConfigEntry, ConfigFlowResult from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN from homeassistant.helpers.config_entry_oauth2_flow import AbstractOAuth2FlowHandler @@ -16,6 +18,8 @@ class OAuth2FlowHandler(AbstractOAuth2FlowHandler, domain=DOMAIN): DOMAIN = DOMAIN + reauth_entry: ConfigEntry | None = None + @property def logger(self) -> logging.Logger: """Return logger.""" @@ -34,7 +38,34 @@ class OAuth2FlowHandler(AbstractOAuth2FlowHandler, domain=DOMAIN): user_id = await get_user_id_from_token( API_URL, data[CONF_TOKEN][CONF_ACCESS_TOKEN] ) - await self.async_set_unique_id(user_id) - self._abort_if_unique_id_configured() + if not self.reauth_entry: + await self.async_set_unique_id(user_id) + self._abort_if_unique_id_configured() - return self.async_create_entry(title=ENTRY_TITLE, data=data) + return self.async_create_entry(title=ENTRY_TITLE, data=data) + + if self.reauth_entry.unique_id == user_id: + return self.async_update_reload_and_abort( + self.reauth_entry, + unique_id=user_id, + data={**self.reauth_entry.data, **data}, + ) + + return self.async_abort(reason="wrong_account") + + async def async_step_reauth( + self, entry_data: Mapping[str, Any] + ) -> ConfigFlowResult: + """Perform reauth upon an API authentication error.""" + self.reauth_entry = self.hass.config_entries.async_get_entry( + self.context["entry_id"] + ) + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Confirm reauth dialog.""" + if user_input is None: + return self.async_show_form(step_id="reauth_confirm") + return await self.async_step_user() diff --git a/homeassistant/components/weheat/coordinator.py b/homeassistant/components/weheat/coordinator.py index 69d1319ed52..a50e9daec18 100644 --- a/homeassistant/components/weheat/coordinator.py +++ b/homeassistant/components/weheat/coordinator.py @@ -15,6 +15,7 @@ from weheat.exceptions import ( from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed @@ -24,7 +25,6 @@ EXCEPTIONS = ( ServiceException, NotFoundException, ForbiddenException, - UnauthorizedException, BadRequestException, ApiException, ) @@ -72,6 +72,8 @@ class WeheatDataUpdateCoordinator(DataUpdateCoordinator[HeatPump]): """Get the data from the API.""" try: self._heat_pump_data.get_status(self.session.token[CONF_ACCESS_TOKEN]) + except UnauthorizedException as error: + raise ConfigEntryAuthFailed from error except EXCEPTIONS as error: raise UpdateFailed(error) from error diff --git a/homeassistant/components/weheat/strings.json b/homeassistant/components/weheat/strings.json index b77af4ed306..3982bfd23b3 100644 --- a/homeassistant/components/weheat/strings.json +++ b/homeassistant/components/weheat/strings.json @@ -24,7 +24,8 @@ "no_url_available": "[%key:common::config_flow::abort::oauth2_no_url_available%]", "user_rejected_authorize": "[%key:common::config_flow::abort::oauth2_user_rejected_authorize%]", "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]", - "no_devices_found": "Could not find any heat pumps on this account" + "no_devices_found": "Could not find any heat pumps on this account", + "wrong_account": "You can only reauthenticate this account with the same user." }, "create_entry": { "default": "[%key:common::config_flow::create_entry::authenticated%]" diff --git a/tests/components/weheat/conftest.py b/tests/components/weheat/conftest.py index 1b4bf26c35f..622882d6e8d 100644 --- a/tests/components/weheat/conftest.py +++ b/tests/components/weheat/conftest.py @@ -17,7 +17,14 @@ from homeassistant.components.weheat.const import DOMAIN from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component -from .const import CLIENT_ID, CLIENT_SECRET, TEST_HP_UUID, TEST_MODEL, TEST_SN +from .const import ( + CLIENT_ID, + CLIENT_SECRET, + TEST_HP_UUID, + TEST_MODEL, + TEST_SN, + USER_UUID_1, +) from tests.common import MockConfigEntry @@ -69,6 +76,18 @@ def mock_config_entry() -> MockConfigEntry: ) +@pytest.fixture +def mock_user_id() -> Generator[AsyncMock]: + """Mock the user API call.""" + with ( + patch( + "homeassistant.components.weheat.config_flow.get_user_id_from_token", + return_value=USER_UUID_1, + ) as user_mock, + ): + yield user_mock + + @pytest.fixture def mock_weheat_discover(mock_heat_pump_info) -> Generator[AsyncMock]: """Mock an Weheat discovery.""" diff --git a/tests/components/weheat/const.py b/tests/components/weheat/const.py index bae74dc70a1..61203259c58 100644 --- a/tests/components/weheat/const.py +++ b/tests/components/weheat/const.py @@ -4,6 +4,7 @@ CLIENT_ID = "1234" CLIENT_SECRET = "5678" USER_UUID_1 = "0000-1111-2222-3333" +USER_UUID_2 = "0000-1111-2222-4444" CONF_REFRESH_TOKEN = "refresh_token" CONF_AUTH_IMPLEMENTATION = "auth_implementation" diff --git a/tests/components/weheat/test_config_flow.py b/tests/components/weheat/test_config_flow.py index c065d011e42..b33dd0a8db8 100644 --- a/tests/components/weheat/test_config_flow.py +++ b/tests/components/weheat/test_config_flow.py @@ -1,6 +1,6 @@ """Test the Weheat config flow.""" -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest @@ -23,6 +23,7 @@ from .const import ( MOCK_ACCESS_TOKEN, MOCK_REFRESH_TOKEN, USER_UUID_1, + USER_UUID_2, ) from tests.common import MockConfigEntry @@ -99,6 +100,51 @@ async def test_duplicate_unique_id( assert result["reason"] == "already_configured" +@pytest.mark.usefixtures("current_request_with_host") +@pytest.mark.parametrize( + ("logged_in_user", "expected_reason"), + [(USER_UUID_1, "reauth_successful"), (USER_UUID_2, "wrong_account")], +) +async def test_reauth( + hass: HomeAssistant, + hass_client_no_auth: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, + mock_user_id: AsyncMock, + mock_weheat_discover: AsyncMock, + setup_credentials, + logged_in_user: str, + expected_reason: str, +) -> None: + """Check reauth flow both with and without the correct logged in user.""" + mock_user_id.return_value = logged_in_user + entry = MockConfigEntry( + domain=DOMAIN, + data={}, + unique_id=USER_UUID_1, + ) + + entry.add_to_hass(hass) + + result = await entry.start_reauth_flow(hass) + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + + result = await hass.config_entries.flow.async_configure( + flow_id=result["flow_id"], + user_input={}, + ) + + await handle_oauth(hass, hass_client_no_auth, aioclient_mock, result) + + assert result["type"] is FlowResultType.EXTERNAL_STEP + + result = await hass.config_entries.flow.async_configure(result["flow_id"]) + + assert result.get("type") is FlowResultType.ABORT + assert result.get("reason") == expected_reason + assert entry.unique_id == USER_UUID_1 + + async def handle_oauth( hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator,