mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 09:47:13 +00:00
Refactor the Hydrawise config flow (#135886)
Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
parent
47efb68780
commit
7050dbb66d
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import ClientError
|
||||
@ -10,85 +10,104 @@ from pydrawise import auth as pydrawise_auth, client
|
||||
from pydrawise.exceptions import NotAuthorizedError
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlow, ConfigFlowResult
|
||||
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
|
||||
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME
|
||||
|
||||
from .const import APP_ID, DOMAIN, LOGGER
|
||||
|
||||
STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
{vol.Required(CONF_USERNAME): str, vol.Required(CONF_PASSWORD): str}
|
||||
)
|
||||
STEP_REAUTH_DATA_SCHEMA = vol.Schema({vol.Required(CONF_PASSWORD): str})
|
||||
|
||||
|
||||
class HydrawiseConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Hydrawise."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
async def _create_or_update_entry(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
*,
|
||||
on_failure: Callable[[str], ConfigFlowResult],
|
||||
) -> ConfigFlowResult:
|
||||
"""Create the config entry."""
|
||||
# Verify that the provided credentials work."""
|
||||
auth = pydrawise_auth.Auth(username, password)
|
||||
try:
|
||||
await auth.token()
|
||||
except NotAuthorizedError:
|
||||
return on_failure("invalid_auth")
|
||||
except TimeoutError:
|
||||
return on_failure("timeout_connect")
|
||||
|
||||
try:
|
||||
api = client.Hydrawise(auth, app_id=APP_ID)
|
||||
# Don't fetch zones because we don't need them yet.
|
||||
user = await api.get_user(fetch_zones=False)
|
||||
except TimeoutError:
|
||||
return on_failure("timeout_connect")
|
||||
except ClientError as ex:
|
||||
LOGGER.error("Unable to connect to Hydrawise cloud service: %s", ex)
|
||||
return on_failure("cannot_connect")
|
||||
|
||||
await self.async_set_unique_id(f"hydrawise-{user.customer_id}")
|
||||
|
||||
if self.source != SOURCE_REAUTH:
|
||||
self._abort_if_unique_id_configured()
|
||||
return self.async_create_entry(
|
||||
title="Hydrawise",
|
||||
data={CONF_USERNAME: username, CONF_PASSWORD: password},
|
||||
)
|
||||
|
||||
return self.async_update_reload_and_abort(
|
||||
self._get_reauth_entry(),
|
||||
data_updates={CONF_USERNAME: username, CONF_PASSWORD: password},
|
||||
)
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle the initial setup."""
|
||||
if user_input is not None:
|
||||
username = user_input[CONF_USERNAME]
|
||||
password = user_input[CONF_PASSWORD]
|
||||
if user_input is None:
|
||||
return self._show_user_form({})
|
||||
username = user_input[CONF_USERNAME]
|
||||
password = user_input[CONF_PASSWORD]
|
||||
unique_id, errors = await _authenticate(username, password)
|
||||
if errors:
|
||||
return self._show_user_form(errors)
|
||||
await self.async_set_unique_id(unique_id)
|
||||
self._abort_if_unique_id_configured()
|
||||
return self.async_create_entry(
|
||||
title=username,
|
||||
data={CONF_USERNAME: username, CONF_PASSWORD: password},
|
||||
)
|
||||
|
||||
return await self._create_or_update_entry(
|
||||
username=username, password=password, on_failure=self._show_form
|
||||
)
|
||||
return self._show_form()
|
||||
|
||||
def _show_form(self, error_type: str | None = None) -> ConfigFlowResult:
|
||||
errors = {}
|
||||
if error_type is not None:
|
||||
errors["base"] = error_type
|
||||
def _show_user_form(self, errors: dict[str, str]) -> ConfigFlowResult:
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=vol.Schema(
|
||||
{vol.Required(CONF_USERNAME): str, vol.Required(CONF_PASSWORD): str}
|
||||
),
|
||||
errors=errors,
|
||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
|
||||
)
|
||||
|
||||
async def async_step_reauth(
|
||||
self, entry_data: Mapping[str, Any]
|
||||
) -> ConfigFlowResult:
|
||||
"""Perform reauth after updating config to username/password."""
|
||||
return await self.async_step_user()
|
||||
"""Handle reauth upon an API authentication error."""
|
||||
return await self.async_step_reauth_confirm()
|
||||
|
||||
async def async_step_reauth_confirm(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Dialog that informs the user that reauth is required."""
|
||||
if user_input is None:
|
||||
return self._show_reauth_form({})
|
||||
|
||||
reauth_entry = self._get_reauth_entry()
|
||||
username = reauth_entry.data[CONF_USERNAME]
|
||||
password = user_input[CONF_PASSWORD]
|
||||
user_id, errors = await _authenticate(username, password)
|
||||
if user_id is None:
|
||||
return self._show_reauth_form(errors)
|
||||
|
||||
await self.async_set_unique_id(user_id)
|
||||
self._abort_if_unique_id_mismatch(reason="wrong_account")
|
||||
return self.async_update_reload_and_abort(
|
||||
reauth_entry, data={CONF_USERNAME: username, CONF_PASSWORD: password}
|
||||
)
|
||||
|
||||
def _show_reauth_form(self, errors: dict[str, str]) -> ConfigFlowResult:
|
||||
return self.async_show_form(
|
||||
step_id="reauth_confirm", data_schema=STEP_REAUTH_DATA_SCHEMA, errors=errors
|
||||
)
|
||||
|
||||
|
||||
async def _authenticate(
|
||||
username: str, password: str
|
||||
) -> tuple[str | None, dict[str, str]]:
|
||||
"""Authenticate with the Hydrawise API."""
|
||||
unique_id = None
|
||||
errors: dict[str, str] = {}
|
||||
auth = pydrawise_auth.Auth(username, password)
|
||||
try:
|
||||
await auth.token()
|
||||
except NotAuthorizedError:
|
||||
errors["base"] = "invalid_auth"
|
||||
except TimeoutError:
|
||||
errors["base"] = "timeout_connect"
|
||||
|
||||
if errors:
|
||||
return unique_id, errors
|
||||
|
||||
try:
|
||||
api = client.Hydrawise(auth, app_id=APP_ID)
|
||||
# Don't fetch zones because we don't need them yet.
|
||||
user = await api.get_user(fetch_zones=False)
|
||||
except TimeoutError:
|
||||
errors["base"] = "timeout_connect"
|
||||
except ClientError as ex:
|
||||
LOGGER.error("Unable to connect to Hydrawise cloud service: %s", ex)
|
||||
errors["base"] = "cannot_connect"
|
||||
else:
|
||||
unique_id = f"hydrawise-{user.customer_id}"
|
||||
|
||||
return unique_id, errors
|
||||
|
@ -8,6 +8,13 @@
|
||||
"username": "[%key:common::config_flow::data::username%]",
|
||||
"password": "[%key:common::config_flow::data::password%]"
|
||||
}
|
||||
},
|
||||
"reauth_confirm": {
|
||||
"title": "[%key:common::config_flow::title::reauth%]",
|
||||
"description": "The Hydrawise integration needs to re-authenticate your account",
|
||||
"data": {
|
||||
"password": "[%key:common::config_flow::data::password%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
|
@ -9,7 +9,7 @@ import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.hydrawise.const import DOMAIN
|
||||
from homeassistant.const import CONF_API_KEY, CONF_PASSWORD, CONF_USERNAME
|
||||
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
@ -33,16 +33,16 @@ async def test_form(
|
||||
assert result["step_id"] == "user"
|
||||
assert result["errors"] == {}
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_USERNAME: "asdf@asdf.com", CONF_PASSWORD: "__password__"},
|
||||
)
|
||||
mock_pydrawise.get_user.return_value = user
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Hydrawise"
|
||||
assert result2["data"] == {
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == "asdf@asdf.com"
|
||||
assert result["data"] == {
|
||||
CONF_USERNAME: "asdf@asdf.com",
|
||||
CONF_PASSWORD: "__password__",
|
||||
}
|
||||
@ -69,14 +69,14 @@ async def test_form_api_error(
|
||||
|
||||
mock_pydrawise.get_user.reset_mock(side_effect=True)
|
||||
mock_pydrawise.get_user.return_value = user
|
||||
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], data)
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"], data)
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
|
||||
|
||||
async def test_form_auth_connect_timeout(
|
||||
hass: HomeAssistant, mock_auth: AsyncMock, mock_pydrawise: AsyncMock
|
||||
) -> None:
|
||||
"""Test we handle API errors."""
|
||||
"""Test we handle connection timeout errors."""
|
||||
mock_auth.token.side_effect = TimeoutError
|
||||
init_result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
@ -90,8 +90,8 @@ async def test_form_auth_connect_timeout(
|
||||
assert result["errors"] == {"base": "timeout_connect"}
|
||||
|
||||
mock_auth.token.reset_mock(side_effect=True)
|
||||
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], data)
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"], data)
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
|
||||
|
||||
async def test_form_client_connect_timeout(
|
||||
@ -112,8 +112,8 @@ async def test_form_client_connect_timeout(
|
||||
|
||||
mock_pydrawise.get_user.reset_mock(side_effect=True)
|
||||
mock_pydrawise.get_user.return_value = user
|
||||
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], data)
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"], data)
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
|
||||
|
||||
async def test_form_not_authorized_error(
|
||||
@ -133,8 +133,8 @@ async def test_form_not_authorized_error(
|
||||
assert result["errors"] == {"base": "invalid_auth"}
|
||||
|
||||
mock_auth.token.reset_mock(side_effect=True)
|
||||
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], data)
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"], data)
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
|
||||
|
||||
async def test_reauth(
|
||||
@ -148,7 +148,8 @@ async def test_reauth(
|
||||
title="Hydrawise",
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_API_KEY: "__api_key__",
|
||||
CONF_USERNAME: "asdf@asdf.com",
|
||||
CONF_PASSWORD: "bad-password",
|
||||
},
|
||||
unique_id="hydrawise-12345",
|
||||
)
|
||||
@ -160,14 +161,49 @@ async def test_reauth(
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
[result] = flows
|
||||
assert result["step_id"] == "user"
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_USERNAME: "asdf@asdf.com", CONF_PASSWORD: "__password__"},
|
||||
)
|
||||
mock_pydrawise.get_user.return_value = user
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_PASSWORD: "__password__"}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] is FlowResultType.ABORT
|
||||
assert result2["reason"] == "reauth_successful"
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "reauth_successful"
|
||||
|
||||
|
||||
async def test_reauth_fails(
|
||||
hass: HomeAssistant, mock_auth: AsyncMock, mock_pydrawise: AsyncMock, user: User
|
||||
) -> None:
|
||||
"""Test that the reauth flow handles API errors."""
|
||||
mock_config_entry = MockConfigEntry(
|
||||
title="Hydrawise",
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_USERNAME: "asdf@asdf.com",
|
||||
CONF_PASSWORD: "bad-password",
|
||||
},
|
||||
unique_id="hydrawise-12345",
|
||||
)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
result = await mock_config_entry.start_reauth_flow(hass)
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
mock_auth.token.side_effect = NotAuthorizedError
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_PASSWORD: "__password__"}
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["errors"] == {"base": "invalid_auth"}
|
||||
|
||||
mock_auth.token.reset_mock(side_effect=True)
|
||||
mock_pydrawise.get_user.return_value = user
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_PASSWORD: "__password__"}
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "reauth_successful"
|
||||
|
Loading…
x
Reference in New Issue
Block a user