From 4c73826baf25da5b23b544490f9cb0fd0a5ba000 Mon Sep 17 00:00:00 2001 From: Aaron Bach Date: Sun, 18 Dec 2022 11:00:08 -0700 Subject: [PATCH] Add re-auth flow to AirVisual Pro (#84012) * Add re-auth flow to AirVisual Pro * Code review --- .../components/airvisual_pro/__init__.py | 12 +- .../components/airvisual_pro/config_flow.py | 114 ++++++++++++----- .../components/airvisual_pro/strings.json | 10 +- .../airvisual_pro/translations/en.json | 11 +- tests/components/airvisual_pro/conftest.py | 33 ++++- .../airvisual_pro/test_config_flow.py | 120 ++++++++++++------ 6 files changed, 218 insertions(+), 82 deletions(-) diff --git a/homeassistant/components/airvisual_pro/__init__.py b/homeassistant/components/airvisual_pro/__init__.py index 8255019f14e..b745dea1d94 100644 --- a/homeassistant/components/airvisual_pro/__init__.py +++ b/homeassistant/components/airvisual_pro/__init__.py @@ -7,8 +7,12 @@ from dataclasses import dataclass from datetime import timedelta from typing import Any -from pyairvisual import NodeSamba -from pyairvisual.node import NodeConnectionError, NodeProError +from pyairvisual.node import ( + InvalidAuthenticationError, + NodeConnectionError, + NodeProError, + NodeSamba, +) from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( @@ -18,7 +22,7 @@ from homeassistant.const import ( Platform, ) from homeassistant.core import Event, HomeAssistant, callback -from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.helpers.entity import DeviceInfo, EntityDescription from homeassistant.helpers.update_coordinator import ( CoordinatorEntity, @@ -56,6 +60,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Get data from the device.""" try: data = await node.async_get_latest_measurements() + except InvalidAuthenticationError as err: + raise ConfigEntryAuthFailed("Invalid Samba password") from err except NodeConnectionError as err: nonlocal reload_task if not reload_task: diff --git a/homeassistant/components/airvisual_pro/config_flow.py b/homeassistant/components/airvisual_pro/config_flow.py index 85e03eec504..8fa588ec700 100644 --- a/homeassistant/components/airvisual_pro/config_flow.py +++ b/homeassistant/components/airvisual_pro/config_flow.py @@ -1,16 +1,30 @@ """Define a config flow manager for AirVisual Pro.""" from __future__ import annotations -from pyairvisual import NodeSamba -from pyairvisual.node import NodeProError +from collections.abc import Mapping +from typing import Any + +from pyairvisual.node import ( + InvalidAuthenticationError, + NodeConnectionError, + NodeProError, + NodeSamba, +) import voluptuous as vol from homeassistant import config_entries +from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_IP_ADDRESS, CONF_PASSWORD from homeassistant.data_entry_flow import FlowResult from .const import DOMAIN, LOGGER +STEP_REAUTH_SCHEMA = vol.Schema( + { + vol.Required(CONF_PASSWORD): str, + } +) + STEP_USER_SCHEMA = vol.Schema( { vol.Required(CONF_IP_ADDRESS): str, @@ -19,11 +33,73 @@ STEP_USER_SCHEMA = vol.Schema( ) +async def async_validate_credentials(ip_address: str, password: str) -> dict[str, Any]: + """Validate an IP address/password combo (and return any errors as appropriate).""" + node = NodeSamba(ip_address, password) + errors = {} + + try: + await node.async_connect() + except InvalidAuthenticationError as err: + LOGGER.error("Invalid password for Pro at IP address %s: %s", ip_address, err) + errors["base"] = "invalid_auth" + except NodeConnectionError as err: + LOGGER.error("Cannot connect to Pro at IP address %s: %s", ip_address, err) + errors["base"] = "cannot_connect" + except NodeProError as err: + LOGGER.error("Unknown Pro error while connecting to %s: %s", ip_address, err) + errors["base"] = "unknown" + except Exception as err: # pylint: disable=broad-except + LOGGER.exception("Unknown error while connecting to %s: %s", ip_address, err) + errors["base"] = "unknown" + finally: + await node.async_disconnect() + + return errors + + class AirVisualProFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Handle an AirVisual Pro config flow.""" VERSION = 1 + def __init__(self) -> None: + """Initialize.""" + self._reauth_entry: ConfigEntry | None = None + + async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: + """Handle configuration by re-auth.""" + self._reauth_entry = self.hass.config_entries.async_get_entry( + self.context["entry_id"] + ) + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Handle the re-auth step.""" + if user_input is None: + return self.async_show_form( + step_id="reauth_confirm", data_schema=STEP_REAUTH_SCHEMA + ) + + assert self._reauth_entry + + if errors := await async_validate_credentials( + self._reauth_entry.data[CONF_IP_ADDRESS], user_input[CONF_PASSWORD] + ): + return self.async_show_form( + step_id="reauth_confirm", data_schema=STEP_REAUTH_SCHEMA, errors=errors + ) + + self.hass.config_entries.async_update_entry( + self._reauth_entry, data=self._reauth_entry.data | user_input + ) + self.hass.async_create_task( + self.hass.config_entries.async_reload(self._reauth_entry.entry_id) + ) + return self.async_abort(reason="reauth_successful") + async def async_step_user( self, user_input: dict[str, str] | None = None ) -> FlowResult: @@ -31,36 +107,16 @@ class AirVisualProFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): if not user_input: return self.async_show_form(step_id="user", data_schema=STEP_USER_SCHEMA) - await self.async_set_unique_id(user_input[CONF_IP_ADDRESS]) + ip_address = user_input[CONF_IP_ADDRESS] + + await self.async_set_unique_id(ip_address) self._abort_if_unique_id_configured() - errors = {} - node = NodeSamba(user_input[CONF_IP_ADDRESS], user_input[CONF_PASSWORD]) - - try: - await node.async_connect() - except NodeProError as err: - LOGGER.error( - "Samba error while connecting to %s: %s", - user_input[CONF_IP_ADDRESS], - err, - ) - errors["base"] = "cannot_connect" - except Exception as err: # pylint: disable=broad-except - LOGGER.error( - "Unknown error while connecting to %s: %s", - user_input[CONF_IP_ADDRESS], - err, - ) - errors["base"] = "unknown" - finally: - await node.async_disconnect() - - if errors: + if errors := await async_validate_credentials( + ip_address, user_input[CONF_PASSWORD] + ): return self.async_show_form( step_id="user", data_schema=STEP_USER_SCHEMA, errors=errors ) - return self.async_create_entry( - title=user_input[CONF_IP_ADDRESS], data=user_input - ) + return self.async_create_entry(title=ip_address, data=user_input) diff --git a/homeassistant/components/airvisual_pro/strings.json b/homeassistant/components/airvisual_pro/strings.json index 2349b7cb69f..f06f120885e 100644 --- a/homeassistant/components/airvisual_pro/strings.json +++ b/homeassistant/components/airvisual_pro/strings.json @@ -1,6 +1,12 @@ { "config": { "step": { + "reauth_confirm": { + "description": "[%key:component::airvisual_pro::config::step::user::description%]", + "data": { + "password": "[%key:common::config_flow::data::password%]" + } + }, "user": { "description": "The password can be retrieved from the AirVisual Pro's UI.", "data": { @@ -10,11 +16,13 @@ } }, "error": { + "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "unknown": "[%key:common::config_flow::error::unknown%]" }, "abort": { - "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" + "already_configured": "[%key:common::config_flow::abort::already_configured_device%]", + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" } } } diff --git a/homeassistant/components/airvisual_pro/translations/en.json b/homeassistant/components/airvisual_pro/translations/en.json index ac54d4d2f09..d8f22590b41 100644 --- a/homeassistant/components/airvisual_pro/translations/en.json +++ b/homeassistant/components/airvisual_pro/translations/en.json @@ -1,13 +1,22 @@ { "config": { "abort": { - "already_configured": "Device is already configured" + "already_configured": "Device is already configured", + "reauth_successful": "Re-authentication was successful" }, "error": { "cannot_connect": "Failed to connect", + "invalid_auth": "Invalid authentication", "unknown": "Unexpected error" }, "step": { + "reauth_confirm": { + "data": { + "ip_address": "Host", + "password": "Password" + }, + "description": "The password can be retrieved from the AirVisual Pro's UI." + }, "user": { "data": { "ip_address": "Host", diff --git a/tests/components/airvisual_pro/conftest.py b/tests/components/airvisual_pro/conftest.py index 86fbdc89224..c5851e940de 100644 --- a/tests/components/airvisual_pro/conftest.py +++ b/tests/components/airvisual_pro/conftest.py @@ -1,6 +1,6 @@ """Define test fixtures for AirVisual Pro.""" import json -from unittest.mock import patch +from unittest.mock import AsyncMock, Mock, patch import pytest @@ -28,20 +28,41 @@ def config_fixture(hass): } +@pytest.fixture(name="connect") +def connect_fixture(): + """Define a mocked async_connect method.""" + return AsyncMock(return_value=True) + + +@pytest.fixture(name="disconnect") +def disconnect_fixture(): + """Define a mocked async_connect method.""" + return AsyncMock() + + @pytest.fixture(name="data", scope="session") def data_fixture(): """Define an update coordinator data example.""" return json.loads(load_fixture("data.json", "airvisual_pro")) +@pytest.fixture(name="pro") +def pro_fixture(connect, data, disconnect): + """Define a mocked NodeSamba object.""" + return Mock( + async_connect=connect, + async_disconnect=disconnect, + async_get_latest_measurements=AsyncMock(return_value=data), + ) + + @pytest.fixture(name="setup_airvisual_pro") -async def setup_airvisual_pro_fixture(hass, config, data): +async def setup_airvisual_pro_fixture(hass, config, pro): """Define a fixture to set up AirVisual Pro.""" - with patch("homeassistant.components.airvisual_pro.NodeSamba.async_connect"), patch( - "homeassistant.components.airvisual_pro.NodeSamba.async_get_latest_measurements", - return_value=data, + with patch( + "homeassistant.components.airvisual_pro.config_flow.NodeSamba", return_value=pro ), patch( - "homeassistant.components.airvisual_pro.NodeSamba.async_disconnect" + "homeassistant.components.airvisual_pro.NodeSamba", return_value=pro ), patch( "homeassistant.components.airvisual.PLATFORMS", [] ): diff --git a/tests/components/airvisual_pro/test_config_flow.py b/tests/components/airvisual_pro/test_config_flow.py index f9114c29868..7ae9bbe44ab 100644 --- a/tests/components/airvisual_pro/test_config_flow.py +++ b/tests/components/airvisual_pro/test_config_flow.py @@ -1,15 +1,57 @@ """Test the AirVisual Pro config flow.""" -from unittest.mock import patch +from unittest.mock import AsyncMock, patch -from pyairvisual.node import NodeProError +from pyairvisual.node import ( + InvalidAuthenticationError, + NodeConnectionError, + NodeProError, +) import pytest from homeassistant import data_entry_flow from homeassistant.components.airvisual_pro.const import DOMAIN -from homeassistant.config_entries import SOURCE_USER +from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER from homeassistant.const import CONF_IP_ADDRESS, CONF_PASSWORD +@pytest.mark.parametrize( + "connect_mock,connect_errors", + [ + (AsyncMock(side_effect=Exception), {"base": "unknown"}), + (AsyncMock(side_effect=InvalidAuthenticationError), {"base": "invalid_auth"}), + (AsyncMock(side_effect=NodeConnectionError), {"base": "cannot_connect"}), + (AsyncMock(side_effect=NodeProError), {"base": "unknown"}), + ], +) +async def test_create_entry( + hass, config, connect_errors, connect_mock, pro, setup_airvisual_pro +): + """Test creating an entry.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER} + ) + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["step_id"] == "user" + + # Test errors that can arise when connecting to a Pro: + with patch.object(pro, "async_connect", connect_mock): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input=config + ) + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["errors"] == connect_errors + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input=config + ) + assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY + assert result["title"] == "192.168.1.101" + assert result["data"] == { + CONF_IP_ADDRESS: "192.168.1.101", + CONF_PASSWORD: "password123", + } + + async def test_duplicate_error(hass, config, config_entry): """Test that errors are shown when duplicates are added.""" result = await hass.config_entries.flow.async_init( @@ -20,51 +62,45 @@ async def test_duplicate_error(hass, config, config_entry): @pytest.mark.parametrize( - "exc,errors", + "connect_mock,connect_errors", [ - (NodeProError, {"base": "cannot_connect"}), - (Exception, {"base": "unknown"}), + (AsyncMock(side_effect=Exception), {"base": "unknown"}), + (AsyncMock(side_effect=InvalidAuthenticationError), {"base": "invalid_auth"}), + (AsyncMock(side_effect=NodeConnectionError), {"base": "cannot_connect"}), + (AsyncMock(side_effect=NodeProError), {"base": "unknown"}), ], ) -async def test_errors(hass, config, exc, errors, setup_airvisual_pro): - """Test that an exceptions show an error.""" - with patch( - "homeassistant.components.airvisual_pro.config_flow.NodeSamba.async_connect", - side_effect=exc, - ): - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER}, data=config - ) - assert result["type"] == data_entry_flow.FlowResultType.FORM - assert result["errors"] == errors - - # Validate that we can still proceed after an error if the underlying condition - # resolves: - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input=config - ) - assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY - assert result["title"] == "192.168.1.101" - assert result["data"] == { - CONF_IP_ADDRESS: "192.168.1.101", - CONF_PASSWORD: "password123", - } - - -async def test_step_user(hass, config, setup_airvisual_pro): - """Test that the user step works.""" +async def test_reauth( + hass, config, config_entry, connect_errors, connect_mock, pro, setup_airvisual_pro +): + """Test re-auth (including errors).""" result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER} + DOMAIN, + context={ + "source": SOURCE_REAUTH, + "entry_id": config_entry.entry_id, + "unique_id": config_entry.unique_id, + }, + data=config, ) assert result["type"] == data_entry_flow.FlowResultType.FORM - assert result["step_id"] == "user" + assert result["step_id"] == "reauth_confirm" + + # Test errors that can arise when connecting to a Pro: + with patch.object(pro, "async_connect", connect_mock): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={CONF_PASSWORD: "new_password"} + ) + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["errors"] == connect_errors result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input=config + result["flow_id"], user_input={CONF_PASSWORD: "new_password"} ) - assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY - assert result["title"] == "192.168.1.101" - assert result["data"] == { - CONF_IP_ADDRESS: "192.168.1.101", - CONF_PASSWORD: "password123", - } + + # Allow reload to finish: + await hass.async_block_till_done() + + assert result["type"] == data_entry_flow.FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert len(hass.config_entries.async_entries()) == 1