From 7050dbb66dbe952e94153655a068bc192c288a3b Mon Sep 17 00:00:00 2001 From: David Knowles Date: Fri, 24 Jan 2025 08:13:54 -0500 Subject: [PATCH] Refactor the Hydrawise config flow (#135886) Co-authored-by: Joost Lekkerkerker --- .../components/hydrawise/config_flow.py | 143 ++++++++++-------- .../components/hydrawise/strings.json | 7 + .../components/hydrawise/test_config_flow.py | 80 +++++++--- 3 files changed, 146 insertions(+), 84 deletions(-) diff --git a/homeassistant/components/hydrawise/config_flow.py b/homeassistant/components/hydrawise/config_flow.py index 5af32af3951..ed21e96cd0b 100644 --- a/homeassistant/components/hydrawise/config_flow.py +++ b/homeassistant/components/hydrawise/config_flow.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Callable, Mapping +from collections.abc import Mapping from typing import Any from aiohttp import ClientError @@ -10,85 +10,104 @@ from pydrawise import auth as pydrawise_auth, client from pydrawise.exceptions import NotAuthorizedError import voluptuous as vol -from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlow, ConfigFlowResult +from homeassistant.config_entries import ConfigFlow, ConfigFlowResult from homeassistant.const import CONF_PASSWORD, CONF_USERNAME from .const import APP_ID, DOMAIN, LOGGER +STEP_USER_DATA_SCHEMA = vol.Schema( + {vol.Required(CONF_USERNAME): str, vol.Required(CONF_PASSWORD): str} +) +STEP_REAUTH_DATA_SCHEMA = vol.Schema({vol.Required(CONF_PASSWORD): str}) + class HydrawiseConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for Hydrawise.""" VERSION = 1 - async def _create_or_update_entry( - self, - username: str, - password: str, - *, - on_failure: Callable[[str], ConfigFlowResult], - ) -> ConfigFlowResult: - """Create the config entry.""" - # Verify that the provided credentials work.""" - auth = pydrawise_auth.Auth(username, password) - try: - await auth.token() - except NotAuthorizedError: - return on_failure("invalid_auth") - except TimeoutError: - return on_failure("timeout_connect") - - try: - api = client.Hydrawise(auth, app_id=APP_ID) - # Don't fetch zones because we don't need them yet. - user = await api.get_user(fetch_zones=False) - except TimeoutError: - return on_failure("timeout_connect") - except ClientError as ex: - LOGGER.error("Unable to connect to Hydrawise cloud service: %s", ex) - return on_failure("cannot_connect") - - await self.async_set_unique_id(f"hydrawise-{user.customer_id}") - - if self.source != SOURCE_REAUTH: - self._abort_if_unique_id_configured() - return self.async_create_entry( - title="Hydrawise", - data={CONF_USERNAME: username, CONF_PASSWORD: password}, - ) - - return self.async_update_reload_and_abort( - self._get_reauth_entry(), - data_updates={CONF_USERNAME: username, CONF_PASSWORD: password}, - ) - async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Handle the initial setup.""" - if user_input is not None: - username = user_input[CONF_USERNAME] - password = user_input[CONF_PASSWORD] + if user_input is None: + return self._show_user_form({}) + username = user_input[CONF_USERNAME] + password = user_input[CONF_PASSWORD] + unique_id, errors = await _authenticate(username, password) + if errors: + return self._show_user_form(errors) + await self.async_set_unique_id(unique_id) + self._abort_if_unique_id_configured() + return self.async_create_entry( + title=username, + data={CONF_USERNAME: username, CONF_PASSWORD: password}, + ) - return await self._create_or_update_entry( - username=username, password=password, on_failure=self._show_form - ) - return self._show_form() - - def _show_form(self, error_type: str | None = None) -> ConfigFlowResult: - errors = {} - if error_type is not None: - errors["base"] = error_type + def _show_user_form(self, errors: dict[str, str]) -> ConfigFlowResult: return self.async_show_form( - step_id="user", - data_schema=vol.Schema( - {vol.Required(CONF_USERNAME): str, vol.Required(CONF_PASSWORD): str} - ), - errors=errors, + step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors ) async def async_step_reauth( self, entry_data: Mapping[str, Any] ) -> ConfigFlowResult: - """Perform reauth after updating config to username/password.""" - return await self.async_step_user() + """Handle reauth upon an API authentication error.""" + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Dialog that informs the user that reauth is required.""" + if user_input is None: + return self._show_reauth_form({}) + + reauth_entry = self._get_reauth_entry() + username = reauth_entry.data[CONF_USERNAME] + password = user_input[CONF_PASSWORD] + user_id, errors = await _authenticate(username, password) + if user_id is None: + return self._show_reauth_form(errors) + + await self.async_set_unique_id(user_id) + self._abort_if_unique_id_mismatch(reason="wrong_account") + return self.async_update_reload_and_abort( + reauth_entry, data={CONF_USERNAME: username, CONF_PASSWORD: password} + ) + + def _show_reauth_form(self, errors: dict[str, str]) -> ConfigFlowResult: + return self.async_show_form( + step_id="reauth_confirm", data_schema=STEP_REAUTH_DATA_SCHEMA, errors=errors + ) + + +async def _authenticate( + username: str, password: str +) -> tuple[str | None, dict[str, str]]: + """Authenticate with the Hydrawise API.""" + unique_id = None + errors: dict[str, str] = {} + auth = pydrawise_auth.Auth(username, password) + try: + await auth.token() + except NotAuthorizedError: + errors["base"] = "invalid_auth" + except TimeoutError: + errors["base"] = "timeout_connect" + + if errors: + return unique_id, errors + + try: + api = client.Hydrawise(auth, app_id=APP_ID) + # Don't fetch zones because we don't need them yet. + user = await api.get_user(fetch_zones=False) + except TimeoutError: + errors["base"] = "timeout_connect" + except ClientError as ex: + LOGGER.error("Unable to connect to Hydrawise cloud service: %s", ex) + errors["base"] = "cannot_connect" + else: + unique_id = f"hydrawise-{user.customer_id}" + + return unique_id, errors diff --git a/homeassistant/components/hydrawise/strings.json b/homeassistant/components/hydrawise/strings.json index 4d50f10bcb2..74c63cbe758 100644 --- a/homeassistant/components/hydrawise/strings.json +++ b/homeassistant/components/hydrawise/strings.json @@ -8,6 +8,13 @@ "username": "[%key:common::config_flow::data::username%]", "password": "[%key:common::config_flow::data::password%]" } + }, + "reauth_confirm": { + "title": "[%key:common::config_flow::title::reauth%]", + "description": "The Hydrawise integration needs to re-authenticate your account", + "data": { + "password": "[%key:common::config_flow::data::password%]" + } } }, "error": { diff --git a/tests/components/hydrawise/test_config_flow.py b/tests/components/hydrawise/test_config_flow.py index 4d25fd5840b..cf723d885e1 100644 --- a/tests/components/hydrawise/test_config_flow.py +++ b/tests/components/hydrawise/test_config_flow.py @@ -9,7 +9,7 @@ import pytest from homeassistant import config_entries from homeassistant.components.hydrawise.const import DOMAIN -from homeassistant.const import CONF_API_KEY, CONF_PASSWORD, CONF_USERNAME +from homeassistant.const import CONF_PASSWORD, CONF_USERNAME from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType @@ -33,16 +33,16 @@ async def test_form( assert result["step_id"] == "user" assert result["errors"] == {} - result2 = await hass.config_entries.flow.async_configure( + result = await hass.config_entries.flow.async_configure( result["flow_id"], {CONF_USERNAME: "asdf@asdf.com", CONF_PASSWORD: "__password__"}, ) mock_pydrawise.get_user.return_value = user await hass.async_block_till_done() - assert result2["type"] is FlowResultType.CREATE_ENTRY - assert result2["title"] == "Hydrawise" - assert result2["data"] == { + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["title"] == "asdf@asdf.com" + assert result["data"] == { CONF_USERNAME: "asdf@asdf.com", CONF_PASSWORD: "__password__", } @@ -69,14 +69,14 @@ async def test_form_api_error( mock_pydrawise.get_user.reset_mock(side_effect=True) mock_pydrawise.get_user.return_value = user - result2 = await hass.config_entries.flow.async_configure(result["flow_id"], data) - assert result2["type"] is FlowResultType.CREATE_ENTRY + result = await hass.config_entries.flow.async_configure(result["flow_id"], data) + assert result["type"] is FlowResultType.CREATE_ENTRY async def test_form_auth_connect_timeout( hass: HomeAssistant, mock_auth: AsyncMock, mock_pydrawise: AsyncMock ) -> None: - """Test we handle API errors.""" + """Test we handle connection timeout errors.""" mock_auth.token.side_effect = TimeoutError init_result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} @@ -90,8 +90,8 @@ async def test_form_auth_connect_timeout( assert result["errors"] == {"base": "timeout_connect"} mock_auth.token.reset_mock(side_effect=True) - result2 = await hass.config_entries.flow.async_configure(result["flow_id"], data) - assert result2["type"] is FlowResultType.CREATE_ENTRY + result = await hass.config_entries.flow.async_configure(result["flow_id"], data) + assert result["type"] is FlowResultType.CREATE_ENTRY async def test_form_client_connect_timeout( @@ -112,8 +112,8 @@ async def test_form_client_connect_timeout( mock_pydrawise.get_user.reset_mock(side_effect=True) mock_pydrawise.get_user.return_value = user - result2 = await hass.config_entries.flow.async_configure(result["flow_id"], data) - assert result2["type"] is FlowResultType.CREATE_ENTRY + result = await hass.config_entries.flow.async_configure(result["flow_id"], data) + assert result["type"] is FlowResultType.CREATE_ENTRY async def test_form_not_authorized_error( @@ -133,8 +133,8 @@ async def test_form_not_authorized_error( assert result["errors"] == {"base": "invalid_auth"} mock_auth.token.reset_mock(side_effect=True) - result2 = await hass.config_entries.flow.async_configure(result["flow_id"], data) - assert result2["type"] is FlowResultType.CREATE_ENTRY + result = await hass.config_entries.flow.async_configure(result["flow_id"], data) + assert result["type"] is FlowResultType.CREATE_ENTRY async def test_reauth( @@ -148,7 +148,8 @@ async def test_reauth( title="Hydrawise", domain=DOMAIN, data={ - CONF_API_KEY: "__api_key__", + CONF_USERNAME: "asdf@asdf.com", + CONF_PASSWORD: "bad-password", }, unique_id="hydrawise-12345", ) @@ -160,14 +161,49 @@ async def test_reauth( flows = hass.config_entries.flow.async_progress() assert len(flows) == 1 [result] = flows - assert result["step_id"] == "user" + assert result["step_id"] == "reauth_confirm" - result2 = await hass.config_entries.flow.async_configure( - result["flow_id"], - {CONF_USERNAME: "asdf@asdf.com", CONF_PASSWORD: "__password__"}, - ) mock_pydrawise.get_user.return_value = user + result = await hass.config_entries.flow.async_configure( + result["flow_id"], {CONF_PASSWORD: "__password__"} + ) await hass.async_block_till_done() - assert result2["type"] is FlowResultType.ABORT - assert result2["reason"] == "reauth_successful" + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + + +async def test_reauth_fails( + hass: HomeAssistant, mock_auth: AsyncMock, mock_pydrawise: AsyncMock, user: User +) -> None: + """Test that the reauth flow handles API errors.""" + mock_config_entry = MockConfigEntry( + title="Hydrawise", + domain=DOMAIN, + data={ + CONF_USERNAME: "asdf@asdf.com", + CONF_PASSWORD: "bad-password", + }, + unique_id="hydrawise-12345", + ) + mock_config_entry.add_to_hass(hass) + + result = await mock_config_entry.start_reauth_flow(hass) + assert result["step_id"] == "reauth_confirm" + + mock_auth.token.side_effect = NotAuthorizedError + result = await hass.config_entries.flow.async_configure( + result["flow_id"], {CONF_PASSWORD: "__password__"} + ) + + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {"base": "invalid_auth"} + + mock_auth.token.reset_mock(side_effect=True) + mock_pydrawise.get_user.return_value = user + result = await hass.config_entries.flow.async_configure( + result["flow_id"], {CONF_PASSWORD: "__password__"} + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "reauth_successful"