mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 19:27:45 +00:00
Simplify Whirlpool auth flows (#136856)
This commit is contained in:
parent
89e6791fee
commit
a61399f189
@ -15,7 +15,6 @@ from whirlpool.backendselector import BackendSelector
|
|||||||
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
|
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
|
||||||
from homeassistant.const import CONF_PASSWORD, CONF_REGION, CONF_USERNAME
|
from homeassistant.const import CONF_PASSWORD, CONF_REGION, CONF_USERNAME
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
|
||||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
|
|
||||||
from .const import CONF_BRAND, CONF_BRANDS_MAP, CONF_REGIONS_MAP, DOMAIN
|
from .const import CONF_BRAND, CONF_BRANDS_MAP, CONF_REGIONS_MAP, DOMAIN
|
||||||
@ -40,31 +39,39 @@ REAUTH_SCHEMA = vol.Schema(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def validate_input(hass: HomeAssistant, data: dict[str, str]) -> dict[str, str]:
|
async def authenticate(
|
||||||
"""Validate the user input allows us to connect.
|
hass: HomeAssistant, data: dict[str, str], check_appliances_exist: bool
|
||||||
|
) -> str | None:
|
||||||
|
"""Authenticate with the api.
|
||||||
|
|
||||||
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
|
data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
|
||||||
|
Returns the error translation key if authentication fails, or None on success.
|
||||||
"""
|
"""
|
||||||
session = async_get_clientsession(hass)
|
session = async_get_clientsession(hass)
|
||||||
region = CONF_REGIONS_MAP[data[CONF_REGION]]
|
region = CONF_REGIONS_MAP[data[CONF_REGION]]
|
||||||
brand = CONF_BRANDS_MAP[data[CONF_BRAND]]
|
brand = CONF_BRANDS_MAP[data[CONF_BRAND]]
|
||||||
backend_selector = BackendSelector(brand, region)
|
backend_selector = BackendSelector(brand, region)
|
||||||
auth = Auth(backend_selector, data[CONF_USERNAME], data[CONF_PASSWORD], session)
|
auth = Auth(backend_selector, data[CONF_USERNAME], data[CONF_PASSWORD], session)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await auth.do_auth()
|
await auth.do_auth()
|
||||||
except (TimeoutError, ClientError) as exc:
|
except (TimeoutError, ClientError):
|
||||||
raise CannotConnect from exc
|
return "cannot_connect"
|
||||||
|
except Exception:
|
||||||
|
_LOGGER.exception("Unexpected exception")
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
if not auth.is_access_token_valid():
|
if not auth.is_access_token_valid():
|
||||||
raise InvalidAuth
|
return "invalid_auth"
|
||||||
|
|
||||||
appliances_manager = AppliancesManager(backend_selector, auth, session)
|
if check_appliances_exist:
|
||||||
await appliances_manager.fetch_appliances()
|
appliances_manager = AppliancesManager(backend_selector, auth, session)
|
||||||
|
await appliances_manager.fetch_appliances()
|
||||||
|
|
||||||
if not appliances_manager.aircons and not appliances_manager.washer_dryers:
|
if not appliances_manager.aircons and not appliances_manager.washer_dryers:
|
||||||
raise NoAppliances
|
return "no_appliances"
|
||||||
|
|
||||||
return {"title": data[CONF_USERNAME]}
|
return None
|
||||||
|
|
||||||
|
|
||||||
class WhirlpoolConfigFlow(ConfigFlow, domain=DOMAIN):
|
class WhirlpoolConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
@ -90,14 +97,10 @@ class WhirlpoolConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
brand = user_input[CONF_BRAND]
|
brand = user_input[CONF_BRAND]
|
||||||
data = {**reauth_entry.data, CONF_PASSWORD: password, CONF_BRAND: brand}
|
data = {**reauth_entry.data, CONF_PASSWORD: password, CONF_BRAND: brand}
|
||||||
|
|
||||||
try:
|
error_key = await authenticate(self.hass, data, False)
|
||||||
await validate_input(self.hass, data)
|
if not error_key:
|
||||||
except InvalidAuth:
|
|
||||||
errors["base"] = "invalid_auth"
|
|
||||||
except (CannotConnect, TimeoutError):
|
|
||||||
errors["base"] = "cannot_connect"
|
|
||||||
else:
|
|
||||||
return self.async_update_reload_and_abort(reauth_entry, data=data)
|
return self.async_update_reload_and_abort(reauth_entry, data=data)
|
||||||
|
errors["base"] = error_key
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="reauth_confirm",
|
step_id="reauth_confirm",
|
||||||
@ -113,38 +116,17 @@ class WhirlpoolConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA
|
step_id="user", data_schema=STEP_USER_DATA_SCHEMA
|
||||||
)
|
)
|
||||||
|
|
||||||
errors = {}
|
error_key = await authenticate(self.hass, user_input, True)
|
||||||
|
if not error_key:
|
||||||
try:
|
|
||||||
info = await validate_input(self.hass, user_input)
|
|
||||||
except CannotConnect:
|
|
||||||
errors["base"] = "cannot_connect"
|
|
||||||
except InvalidAuth:
|
|
||||||
errors["base"] = "invalid_auth"
|
|
||||||
except NoAppliances:
|
|
||||||
errors["base"] = "no_appliances"
|
|
||||||
except Exception:
|
|
||||||
_LOGGER.exception("Unexpected exception")
|
|
||||||
errors["base"] = "unknown"
|
|
||||||
else:
|
|
||||||
await self.async_set_unique_id(
|
await self.async_set_unique_id(
|
||||||
user_input[CONF_USERNAME].lower(), raise_on_progress=False
|
user_input[CONF_USERNAME].lower(), raise_on_progress=False
|
||||||
)
|
)
|
||||||
self._abort_if_unique_id_configured()
|
self._abort_if_unique_id_configured()
|
||||||
return self.async_create_entry(title=info["title"], data=user_input)
|
return self.async_create_entry(
|
||||||
|
title=user_input[CONF_USERNAME], data=user_input
|
||||||
|
)
|
||||||
|
|
||||||
|
errors = {"base": error_key}
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
|
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CannotConnect(HomeAssistantError):
|
|
||||||
"""Error to indicate we cannot connect."""
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidAuth(HomeAssistantError):
|
|
||||||
"""Error to indicate there is invalid auth."""
|
|
||||||
|
|
||||||
|
|
||||||
class NoAppliances(HomeAssistantError):
|
|
||||||
"""Error to indicate no supported appliances in the user account."""
|
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp.client_exceptions import ClientConnectionError
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
@ -219,7 +218,7 @@ async def test_reauth_flow(hass: HomeAssistant, region, brand) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("mock_appliances_manager_api", "mock_whirlpool_setup_entry")
|
@pytest.mark.usefixtures("mock_appliances_manager_api", "mock_whirlpool_setup_entry")
|
||||||
async def test_reauth_flow_auth_error(
|
async def test_reauth_flow_invalid_auth(
|
||||||
hass: HomeAssistant, region, brand, mock_auth_api: MagicMock
|
hass: HomeAssistant, region, brand, mock_auth_api: MagicMock
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test an authorization error reauth flow."""
|
"""Test an authorization error reauth flow."""
|
||||||
@ -247,8 +246,21 @@ async def test_reauth_flow_auth_error(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("mock_appliances_manager_api", "mock_whirlpool_setup_entry")
|
@pytest.mark.usefixtures("mock_appliances_manager_api", "mock_whirlpool_setup_entry")
|
||||||
async def test_reauth_flow_connnection_error(
|
@pytest.mark.parametrize(
|
||||||
hass: HomeAssistant, region, brand, mock_auth_api: MagicMock
|
("exception", "expected_error"),
|
||||||
|
[
|
||||||
|
(aiohttp.ClientConnectionError, "cannot_connect"),
|
||||||
|
(TimeoutError, "cannot_connect"),
|
||||||
|
(Exception, "unknown"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_reauth_flow_auth_error(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
exception: Exception,
|
||||||
|
expected_error: str,
|
||||||
|
region,
|
||||||
|
brand,
|
||||||
|
mock_auth_api: MagicMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test a connection error reauth flow."""
|
"""Test a connection error reauth flow."""
|
||||||
|
|
||||||
@ -265,11 +277,10 @@ async def test_reauth_flow_connnection_error(
|
|||||||
assert result["type"] is FlowResultType.FORM
|
assert result["type"] is FlowResultType.FORM
|
||||||
assert result["errors"] == {}
|
assert result["errors"] == {}
|
||||||
|
|
||||||
mock_auth_api.return_value.do_auth.side_effect = ClientConnectionError
|
mock_auth_api.return_value.do_auth.side_effect = exception
|
||||||
result2 = await hass.config_entries.flow.async_configure(
|
result2 = await hass.config_entries.flow.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
{CONF_PASSWORD: "new-password", CONF_BRAND: brand[0]},
|
{CONF_PASSWORD: "new-password", CONF_BRAND: brand[0]},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result2["type"] is FlowResultType.FORM
|
assert result2["type"] is FlowResultType.FORM
|
||||||
assert result2["errors"] == {"base": "cannot_connect"}
|
assert result2["errors"] == {"base": expected_error}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user