Add re-auth flow to AirVisual Pro (#84012)

* Add re-auth flow to AirVisual Pro

* Code review
This commit is contained in:
Aaron Bach 2022-12-18 11:00:08 -07:00 committed by GitHub
parent 47522546e6
commit 4c73826baf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 218 additions and 82 deletions

View File

@ -7,8 +7,12 @@ from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from pyairvisual import NodeSamba from pyairvisual.node import (
from pyairvisual.node import NodeConnectionError, NodeProError InvalidAuthenticationError,
NodeConnectionError,
NodeProError,
NodeSamba,
)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
@ -18,7 +22,7 @@ from homeassistant.const import (
Platform, Platform,
) )
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers.entity import DeviceInfo, EntityDescription from homeassistant.helpers.entity import DeviceInfo, EntityDescription
from homeassistant.helpers.update_coordinator import ( from homeassistant.helpers.update_coordinator import (
CoordinatorEntity, CoordinatorEntity,
@ -56,6 +60,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Get data from the device.""" """Get data from the device."""
try: try:
data = await node.async_get_latest_measurements() data = await node.async_get_latest_measurements()
except InvalidAuthenticationError as err:
raise ConfigEntryAuthFailed("Invalid Samba password") from err
except NodeConnectionError as err: except NodeConnectionError as err:
nonlocal reload_task nonlocal reload_task
if not reload_task: if not reload_task:

View File

@ -1,16 +1,30 @@
"""Define a config flow manager for AirVisual Pro.""" """Define a config flow manager for AirVisual Pro."""
from __future__ import annotations from __future__ import annotations
from pyairvisual import NodeSamba from collections.abc import Mapping
from pyairvisual.node import NodeProError from typing import Any
from pyairvisual.node import (
InvalidAuthenticationError,
NodeConnectionError,
NodeProError,
NodeSamba,
)
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_IP_ADDRESS, CONF_PASSWORD from homeassistant.const import CONF_IP_ADDRESS, CONF_PASSWORD
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from .const import DOMAIN, LOGGER from .const import DOMAIN, LOGGER
STEP_REAUTH_SCHEMA = vol.Schema(
{
vol.Required(CONF_PASSWORD): str,
}
)
STEP_USER_SCHEMA = vol.Schema( STEP_USER_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_IP_ADDRESS): str, vol.Required(CONF_IP_ADDRESS): str,
@ -19,11 +33,73 @@ STEP_USER_SCHEMA = vol.Schema(
) )
async def async_validate_credentials(ip_address: str, password: str) -> dict[str, Any]:
"""Validate an IP address/password combo (and return any errors as appropriate)."""
node = NodeSamba(ip_address, password)
errors = {}
try:
await node.async_connect()
except InvalidAuthenticationError as err:
LOGGER.error("Invalid password for Pro at IP address %s: %s", ip_address, err)
errors["base"] = "invalid_auth"
except NodeConnectionError as err:
LOGGER.error("Cannot connect to Pro at IP address %s: %s", ip_address, err)
errors["base"] = "cannot_connect"
except NodeProError as err:
LOGGER.error("Unknown Pro error while connecting to %s: %s", ip_address, err)
errors["base"] = "unknown"
except Exception as err: # pylint: disable=broad-except
LOGGER.exception("Unknown error while connecting to %s: %s", ip_address, err)
errors["base"] = "unknown"
finally:
await node.async_disconnect()
return errors
class AirVisualProFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): class AirVisualProFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle an AirVisual Pro config flow.""" """Handle an AirVisual Pro config flow."""
VERSION = 1 VERSION = 1
def __init__(self) -> None:
"""Initialize."""
self._reauth_entry: ConfigEntry | None = None
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Handle configuration by re-auth."""
self._reauth_entry = self.hass.config_entries.async_get_entry(
self.context["entry_id"]
)
return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the re-auth step."""
if user_input is None:
return self.async_show_form(
step_id="reauth_confirm", data_schema=STEP_REAUTH_SCHEMA
)
assert self._reauth_entry
if errors := await async_validate_credentials(
self._reauth_entry.data[CONF_IP_ADDRESS], user_input[CONF_PASSWORD]
):
return self.async_show_form(
step_id="reauth_confirm", data_schema=STEP_REAUTH_SCHEMA, errors=errors
)
self.hass.config_entries.async_update_entry(
self._reauth_entry, data=self._reauth_entry.data | user_input
)
self.hass.async_create_task(
self.hass.config_entries.async_reload(self._reauth_entry.entry_id)
)
return self.async_abort(reason="reauth_successful")
async def async_step_user( async def async_step_user(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
) -> FlowResult: ) -> FlowResult:
@ -31,36 +107,16 @@ class AirVisualProFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
if not user_input: if not user_input:
return self.async_show_form(step_id="user", data_schema=STEP_USER_SCHEMA) return self.async_show_form(step_id="user", data_schema=STEP_USER_SCHEMA)
await self.async_set_unique_id(user_input[CONF_IP_ADDRESS]) ip_address = user_input[CONF_IP_ADDRESS]
await self.async_set_unique_id(ip_address)
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()
errors = {} if errors := await async_validate_credentials(
node = NodeSamba(user_input[CONF_IP_ADDRESS], user_input[CONF_PASSWORD]) ip_address, user_input[CONF_PASSWORD]
):
try:
await node.async_connect()
except NodeProError as err:
LOGGER.error(
"Samba error while connecting to %s: %s",
user_input[CONF_IP_ADDRESS],
err,
)
errors["base"] = "cannot_connect"
except Exception as err: # pylint: disable=broad-except
LOGGER.error(
"Unknown error while connecting to %s: %s",
user_input[CONF_IP_ADDRESS],
err,
)
errors["base"] = "unknown"
finally:
await node.async_disconnect()
if errors:
return self.async_show_form( return self.async_show_form(
step_id="user", data_schema=STEP_USER_SCHEMA, errors=errors step_id="user", data_schema=STEP_USER_SCHEMA, errors=errors
) )
return self.async_create_entry( return self.async_create_entry(title=ip_address, data=user_input)
title=user_input[CONF_IP_ADDRESS], data=user_input
)

View File

@ -1,6 +1,12 @@
{ {
"config": { "config": {
"step": { "step": {
"reauth_confirm": {
"description": "[%key:component::airvisual_pro::config::step::user::description%]",
"data": {
"password": "[%key:common::config_flow::data::password%]"
}
},
"user": { "user": {
"description": "The password can be retrieved from the AirVisual Pro's UI.", "description": "The password can be retrieved from the AirVisual Pro's UI.",
"data": { "data": {
@ -10,11 +16,13 @@
} }
}, },
"error": { "error": {
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"unknown": "[%key:common::config_flow::error::unknown%]" "unknown": "[%key:common::config_flow::error::unknown%]"
}, },
"abort": { "abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]" "already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
} }
} }
} }

View File

@ -1,13 +1,22 @@
{ {
"config": { "config": {
"abort": { "abort": {
"already_configured": "Device is already configured" "already_configured": "Device is already configured",
"reauth_successful": "Re-authentication was successful"
}, },
"error": { "error": {
"cannot_connect": "Failed to connect", "cannot_connect": "Failed to connect",
"invalid_auth": "Invalid authentication",
"unknown": "Unexpected error" "unknown": "Unexpected error"
}, },
"step": { "step": {
"reauth_confirm": {
"data": {
"ip_address": "Host",
"password": "Password"
},
"description": "The password can be retrieved from the AirVisual Pro's UI."
},
"user": { "user": {
"data": { "data": {
"ip_address": "Host", "ip_address": "Host",

View File

@ -1,6 +1,6 @@
"""Define test fixtures for AirVisual Pro.""" """Define test fixtures for AirVisual Pro."""
import json import json
from unittest.mock import patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
@ -28,20 +28,41 @@ def config_fixture(hass):
} }
@pytest.fixture(name="connect")
def connect_fixture():
"""Define a mocked async_connect method."""
return AsyncMock(return_value=True)
@pytest.fixture(name="disconnect")
def disconnect_fixture():
"""Define a mocked async_connect method."""
return AsyncMock()
@pytest.fixture(name="data", scope="session") @pytest.fixture(name="data", scope="session")
def data_fixture(): def data_fixture():
"""Define an update coordinator data example.""" """Define an update coordinator data example."""
return json.loads(load_fixture("data.json", "airvisual_pro")) return json.loads(load_fixture("data.json", "airvisual_pro"))
@pytest.fixture(name="pro")
def pro_fixture(connect, data, disconnect):
"""Define a mocked NodeSamba object."""
return Mock(
async_connect=connect,
async_disconnect=disconnect,
async_get_latest_measurements=AsyncMock(return_value=data),
)
@pytest.fixture(name="setup_airvisual_pro") @pytest.fixture(name="setup_airvisual_pro")
async def setup_airvisual_pro_fixture(hass, config, data): async def setup_airvisual_pro_fixture(hass, config, pro):
"""Define a fixture to set up AirVisual Pro.""" """Define a fixture to set up AirVisual Pro."""
with patch("homeassistant.components.airvisual_pro.NodeSamba.async_connect"), patch( with patch(
"homeassistant.components.airvisual_pro.NodeSamba.async_get_latest_measurements", "homeassistant.components.airvisual_pro.config_flow.NodeSamba", return_value=pro
return_value=data,
), patch( ), patch(
"homeassistant.components.airvisual_pro.NodeSamba.async_disconnect" "homeassistant.components.airvisual_pro.NodeSamba", return_value=pro
), patch( ), patch(
"homeassistant.components.airvisual.PLATFORMS", [] "homeassistant.components.airvisual.PLATFORMS", []
): ):

View File

@ -1,15 +1,57 @@
"""Test the AirVisual Pro config flow.""" """Test the AirVisual Pro config flow."""
from unittest.mock import patch from unittest.mock import AsyncMock, patch
from pyairvisual.node import NodeProError from pyairvisual.node import (
InvalidAuthenticationError,
NodeConnectionError,
NodeProError,
)
import pytest import pytest
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.components.airvisual_pro.const import DOMAIN from homeassistant.components.airvisual_pro.const import DOMAIN
from homeassistant.config_entries import SOURCE_USER from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER
from homeassistant.const import CONF_IP_ADDRESS, CONF_PASSWORD from homeassistant.const import CONF_IP_ADDRESS, CONF_PASSWORD
@pytest.mark.parametrize(
"connect_mock,connect_errors",
[
(AsyncMock(side_effect=Exception), {"base": "unknown"}),
(AsyncMock(side_effect=InvalidAuthenticationError), {"base": "invalid_auth"}),
(AsyncMock(side_effect=NodeConnectionError), {"base": "cannot_connect"}),
(AsyncMock(side_effect=NodeProError), {"base": "unknown"}),
],
)
async def test_create_entry(
hass, config, connect_errors, connect_mock, pro, setup_airvisual_pro
):
"""Test creating an entry."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "user"
# Test errors that can arise when connecting to a Pro:
with patch.object(pro, "async_connect", connect_mock):
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=config
)
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["errors"] == connect_errors
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=config
)
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
assert result["title"] == "192.168.1.101"
assert result["data"] == {
CONF_IP_ADDRESS: "192.168.1.101",
CONF_PASSWORD: "password123",
}
async def test_duplicate_error(hass, config, config_entry): async def test_duplicate_error(hass, config, config_entry):
"""Test that errors are shown when duplicates are added.""" """Test that errors are shown when duplicates are added."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -20,51 +62,45 @@ async def test_duplicate_error(hass, config, config_entry):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"exc,errors", "connect_mock,connect_errors",
[ [
(NodeProError, {"base": "cannot_connect"}), (AsyncMock(side_effect=Exception), {"base": "unknown"}),
(Exception, {"base": "unknown"}), (AsyncMock(side_effect=InvalidAuthenticationError), {"base": "invalid_auth"}),
(AsyncMock(side_effect=NodeConnectionError), {"base": "cannot_connect"}),
(AsyncMock(side_effect=NodeProError), {"base": "unknown"}),
], ],
) )
async def test_errors(hass, config, exc, errors, setup_airvisual_pro): async def test_reauth(
"""Test that an exceptions show an error.""" hass, config, config_entry, connect_errors, connect_mock, pro, setup_airvisual_pro
with patch( ):
"homeassistant.components.airvisual_pro.config_flow.NodeSamba.async_connect", """Test re-auth (including errors)."""
side_effect=exc,
):
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}, data=config
)
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["errors"] == errors
# Validate that we can still proceed after an error if the underlying condition
# resolves:
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=config
)
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
assert result["title"] == "192.168.1.101"
assert result["data"] == {
CONF_IP_ADDRESS: "192.168.1.101",
CONF_PASSWORD: "password123",
}
async def test_step_user(hass, config, setup_airvisual_pro):
"""Test that the user step works."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER} DOMAIN,
context={
"source": SOURCE_REAUTH,
"entry_id": config_entry.entry_id,
"unique_id": config_entry.unique_id,
},
data=config,
) )
assert result["type"] == data_entry_flow.FlowResultType.FORM assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "reauth_confirm"
# Test errors that can arise when connecting to a Pro:
with patch.object(pro, "async_connect", connect_mock):
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={CONF_PASSWORD: "new_password"}
)
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["errors"] == connect_errors
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=config result["flow_id"], user_input={CONF_PASSWORD: "new_password"}
) )
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
assert result["title"] == "192.168.1.101" # Allow reload to finish:
assert result["data"] == { await hass.async_block_till_done()
CONF_IP_ADDRESS: "192.168.1.101",
CONF_PASSWORD: "password123", assert result["type"] == data_entry_flow.FlowResultType.ABORT
} assert result["reason"] == "reauth_successful"
assert len(hass.config_entries.async_entries()) == 1