Improve / clean up Plugwise config_flow code (#127238)

This commit is contained in:
Bouwe Westerdijk 2024-10-01 21:52:16 +02:00 committed by GitHub
parent dd478fe681
commit 0616bc7fec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 23 deletions

View File

@ -16,8 +16,9 @@ from plugwise.exceptions import (
import voluptuous as vol import voluptuous as vol
from homeassistant.components.zeroconf import ZeroconfServiceInfo from homeassistant.components.zeroconf import ZeroconfServiceInfo
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult from homeassistant.config_entries import SOURCE_USER, ConfigFlow, ConfigFlowResult
from homeassistant.const import ( from homeassistant.const import (
ATTR_CONFIGURATION_URL,
CONF_BASE, CONF_BASE,
CONF_HOST, CONF_HOST,
CONF_NAME, CONF_NAME,
@ -29,13 +30,11 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import ( from .const import (
API,
DEFAULT_PORT, DEFAULT_PORT,
DEFAULT_USERNAME, DEFAULT_USERNAME,
DOMAIN, DOMAIN,
FLOW_SMILE, FLOW_SMILE,
FLOW_STRETCH, FLOW_STRETCH,
PW_TYPE,
SMILE, SMILE,
STRETCH, STRETCH,
STRETCH_USERNAME, STRETCH_USERNAME,
@ -43,12 +42,12 @@ from .const import (
) )
def _base_gw_schema(discovery_info: ZeroconfServiceInfo | None) -> vol.Schema: def base_schema(discovery_info: ZeroconfServiceInfo | None) -> vol.Schema:
"""Generate base schema for gateways.""" """Generate base schema for gateways."""
base_gw_schema = vol.Schema({vol.Required(CONF_PASSWORD): str}) schema = vol.Schema({vol.Required(CONF_PASSWORD): str})
if not discovery_info: if not discovery_info:
base_gw_schema = base_gw_schema.extend( schema = schema.extend(
{ {
vol.Required(CONF_HOST): str, vol.Required(CONF_HOST): str,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): int, vol.Optional(CONF_PORT, default=DEFAULT_PORT): int,
@ -58,13 +57,13 @@ def _base_gw_schema(discovery_info: ZeroconfServiceInfo | None) -> vol.Schema:
} }
) )
return base_gw_schema return schema
async def validate_gw_input(hass: HomeAssistant, data: dict[str, Any]) -> Smile: async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> Smile:
"""Validate whether the user input allows us to connect to the gateway. """Validate whether the user input allows us to connect to the gateway.
Data has the keys from _base_gw_schema() with values provided by the user. Data has the keys from base_schema() with values provided by the user.
""" """
websession = async_get_clientsession(hass, verify_ssl=False) websession = async_get_clientsession(hass, verify_ssl=False)
api = Smile( api = Smile(
@ -85,7 +84,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
discovery_info: ZeroconfServiceInfo | None = None discovery_info: ZeroconfServiceInfo | None = None
product: str | None = None product: str = "Unknown Smile"
_username: str = DEFAULT_USERNAME _username: str = DEFAULT_USERNAME
async def async_step_zeroconf( async def async_step_zeroconf(
@ -98,7 +97,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
unique_id = discovery_info.hostname.split(".")[0].split("-")[0] unique_id = discovery_info.hostname.split(".")[0].split("-")[0]
if config_entry := await self.async_set_unique_id(unique_id): if config_entry := await self.async_set_unique_id(unique_id):
try: try:
await validate_gw_input( await validate_input(
self.hass, self.hass,
{ {
CONF_HOST: discovery_info.host, CONF_HOST: discovery_info.host,
@ -119,7 +118,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
if DEFAULT_USERNAME not in unique_id: if DEFAULT_USERNAME not in unique_id:
self._username = STRETCH_USERNAME self._username = STRETCH_USERNAME
self.product = _product = _properties.get("product", None) self.product = _product = _properties.get("product", "Unknown Smile")
_version = _properties.get("version", "n/a") _version = _properties.get("version", "n/a")
_name = f"{ZEROCONF_MAP.get(_product, _product)} v{_version}" _name = f"{ZEROCONF_MAP.get(_product, _product)} v{_version}"
@ -137,7 +136,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
self.context.update( self.context.update(
{ {
"title_placeholders": {CONF_NAME: _name}, "title_placeholders": {CONF_NAME: _name},
"configuration_url": ( ATTR_CONFIGURATION_URL: (
f"http://{discovery_info.host}:{discovery_info.port}" f"http://{discovery_info.host}:{discovery_info.port}"
), ),
} }
@ -160,7 +159,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle the initial step when using network/gateway setups.""" """Handle the initial step when using network/gateway setups."""
errors = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:
if self.discovery_info: if self.discovery_info:
@ -169,7 +168,7 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
user_input[CONF_USERNAME] = self._username user_input[CONF_USERNAME] = self._username
try: try:
api = await validate_gw_input(self.hass, user_input) api = await validate_input(self.hass, user_input)
except ConnectionFailedError: except ConnectionFailedError:
errors[CONF_BASE] = "cannot_connect" errors[CONF_BASE] = "cannot_connect"
except InvalidAuthentication: except InvalidAuthentication:
@ -188,11 +187,10 @@ class PlugwiseConfigFlow(ConfigFlow, domain=DOMAIN):
) )
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()
user_input[PW_TYPE] = API
return self.async_create_entry(title=api.smile_name, data=user_input) return self.async_create_entry(title=api.smile_name, data=user_input)
return self.async_show_form( return self.async_show_form(
step_id="user", step_id=SOURCE_USER,
data_schema=_base_gw_schema(self.discovery_info), data_schema=base_schema(self.discovery_info),
errors=errors, errors=errors,
) )

View File

@ -12,7 +12,7 @@ from plugwise.exceptions import (
) )
import pytest import pytest
from homeassistant.components.plugwise.const import API, DEFAULT_PORT, DOMAIN, PW_TYPE from homeassistant.components.plugwise.const import DEFAULT_PORT, DOMAIN
from homeassistant.components.zeroconf import ZeroconfServiceInfo from homeassistant.components.zeroconf import ZeroconfServiceInfo
from homeassistant.config_entries import SOURCE_USER, SOURCE_ZEROCONF from homeassistant.config_entries import SOURCE_USER, SOURCE_ZEROCONF
from homeassistant.const import ( from homeassistant.const import (
@ -123,7 +123,6 @@ async def test_form(
CONF_PASSWORD: TEST_PASSWORD, CONF_PASSWORD: TEST_PASSWORD,
CONF_PORT: DEFAULT_PORT, CONF_PORT: DEFAULT_PORT,
CONF_USERNAME: TEST_USERNAME, CONF_USERNAME: TEST_USERNAME,
PW_TYPE: API,
} }
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
@ -168,7 +167,6 @@ async def test_zeroconf_flow(
CONF_PASSWORD: TEST_PASSWORD, CONF_PASSWORD: TEST_PASSWORD,
CONF_PORT: DEFAULT_PORT, CONF_PORT: DEFAULT_PORT,
CONF_USERNAME: TEST_USERNAME, CONF_USERNAME: TEST_USERNAME,
PW_TYPE: API,
} }
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
@ -204,7 +202,6 @@ async def test_zeroconf_flow_stretch(
CONF_PASSWORD: TEST_PASSWORD, CONF_PASSWORD: TEST_PASSWORD,
CONF_PORT: DEFAULT_PORT, CONF_PORT: DEFAULT_PORT,
CONF_USERNAME: TEST_USERNAME2, CONF_USERNAME: TEST_USERNAME2,
PW_TYPE: API,
} }
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
@ -308,7 +305,6 @@ async def test_flow_errors(
CONF_PASSWORD: TEST_PASSWORD, CONF_PASSWORD: TEST_PASSWORD,
CONF_PORT: DEFAULT_PORT, CONF_PORT: DEFAULT_PORT,
CONF_USERNAME: TEST_USERNAME, CONF_USERNAME: TEST_USERNAME,
PW_TYPE: API,
} }
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1