From 6a89b3a135ebaff1f3c8f6dbc8e32974241b5e8f Mon Sep 17 00:00:00 2001 From: Franck Nijhof Date: Mon, 16 Jan 2023 17:03:00 +0100 Subject: [PATCH] Small refactor to HomeWizard config flow (#86020) * Small refactor to HomeWizard config flow * Update homeassistant/components/homewizard/config_flow.py Co-authored-by: Duco Sebel <74970928+DCSBL@users.noreply.github.com> * Process review comments Co-authored-by: Duco Sebel <74970928+DCSBL@users.noreply.github.com> --- .../components/homewizard/config_flow.py | 189 +++++++----------- 1 file changed, 72 insertions(+), 117 deletions(-) diff --git a/homeassistant/components/homewizard/config_flow.py b/homeassistant/components/homewizard/config_flow.py index dc314e051ce..dc9b6b61640 100644 --- a/homeassistant/components/homewizard/config_flow.py +++ b/homeassistant/components/homewizard/config_flow.py @@ -3,15 +3,15 @@ from __future__ import annotations from collections.abc import Mapping import logging -from typing import Any, cast +from typing import Any, NamedTuple from homewizard_energy import HomeWizardEnergy from homewizard_energy.errors import DisabledError, RequestError, UnsupportedError from homewizard_energy.models import Device from voluptuous import Required, Schema -from homeassistant import config_entries from homeassistant.components import zeroconf +from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.const import CONF_IP_ADDRESS from homeassistant.data_entry_flow import AbortFlow, FlowResult from homeassistant.exceptions import HomeAssistantError @@ -28,72 +28,58 @@ from .const import ( _LOGGER = logging.getLogger(__name__) -class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): +class DiscoveryData(NamedTuple): + """User metadata.""" + + ip: str + product_name: str + product_type: str + serial: str + + +class HomeWizardConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for P1 meter.""" VERSION = 1 - def __init__(self) -> None: - """Initialize the HomeWizard config flow.""" - self.config: dict[str, str | int] = {} - self.entry: config_entries.ConfigEntry | None = None + discovery: DiscoveryData + entry: ConfigEntry | None async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Handle a flow initiated by the user.""" + errors: dict[str, str] | None = None + if user_input is not None: + try: + device_info = await self._async_try_connect(user_input[CONF_IP_ADDRESS]) + except RecoverableError as ex: + _LOGGER.error(ex) + errors = {"base": ex.error_code} + else: + await self.async_set_unique_id( + f"{device_info.product_type}_{device_info.serial}" + ) + self._abort_if_unique_id_configured(updates=user_input) + return self.async_create_entry( + title=f"{device_info.product_name} ({device_info.serial})", + data=user_input, + ) - _LOGGER.debug("config_flow async_step_user") - - data_schema = Schema( - { - Required(CONF_IP_ADDRESS): str, - } - ) - - if user_input is None: - return self.async_show_form( - step_id="user", - data_schema=data_schema, - errors=None, - ) - - # Fetch device information - try: - device_info = await self._async_try_connect(user_input[CONF_IP_ADDRESS]) - except RecoverableError as ex: - _LOGGER.error(ex) - return self.async_show_form( - step_id="user", - data_schema=data_schema, - errors={"base": ex.error_code}, - ) - - # Sets unique ID and aborts if it is already exists - await self._async_set_and_check_unique_id( - { - CONF_IP_ADDRESS: user_input[CONF_IP_ADDRESS], - CONF_PRODUCT_TYPE: device_info.product_type, - CONF_SERIAL: device_info.serial, - } - ) - - data: dict[str, str] = {CONF_IP_ADDRESS: user_input[CONF_IP_ADDRESS]} - - # Add entry - return self.async_create_entry( - title=f"{device_info.product_name} ({device_info.serial})", - data=data, + return self.async_show_form( + step_id="user", + data_schema=Schema( + { + Required(CONF_IP_ADDRESS): str, + } + ), + errors=errors, ) async def async_step_zeroconf( self, discovery_info: zeroconf.ZeroconfServiceInfo ) -> FlowResult: """Handle zeroconf discovery.""" - - _LOGGER.debug("config_flow async_step_zeroconf") - - # Validate doscovery entry if ( CONF_API_ENABLED not in discovery_info.properties or CONF_PATH not in discovery_info.properties @@ -106,70 +92,56 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): if (discovery_info.properties[CONF_PATH]) != "/api/v1": return self.async_abort(reason="unsupported_api_version") - # Sets unique ID and aborts if it is already exists - await self._async_set_and_check_unique_id( - { - CONF_IP_ADDRESS: discovery_info.host, - CONF_PRODUCT_TYPE: discovery_info.properties[CONF_PRODUCT_TYPE], - CONF_SERIAL: discovery_info.properties[CONF_SERIAL], - } + self.discovery = DiscoveryData( + ip=discovery_info.host, + product_type=discovery_info.properties[CONF_PRODUCT_TYPE], + product_name=discovery_info.properties[CONF_PRODUCT_NAME], + serial=discovery_info.properties[CONF_SERIAL], + ) + + await self.async_set_unique_id( + f"{self.discovery.product_type}_{self.discovery.serial}" + ) + self._abort_if_unique_id_configured( + updates={CONF_IP_ADDRESS: discovery_info.host} ) - # Pass parameters - self.config = { - CONF_API_ENABLED: discovery_info.properties[CONF_API_ENABLED], - CONF_IP_ADDRESS: discovery_info.host, - CONF_PRODUCT_TYPE: discovery_info.properties[CONF_PRODUCT_TYPE], - CONF_PRODUCT_NAME: discovery_info.properties[CONF_PRODUCT_NAME], - CONF_SERIAL: discovery_info.properties[CONF_SERIAL], - } return await self.async_step_discovery_confirm() async def async_step_discovery_confirm( self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Confirm discovery.""" + errors: dict[str, str] | None = None if user_input is not None: - try: - await self._async_try_connect(str(self.config[CONF_IP_ADDRESS])) + await self._async_try_connect(self.discovery.ip) except RecoverableError as ex: _LOGGER.error(ex) - return self.async_show_form( - step_id="discovery_confirm", - errors={"base": ex.error_code}, - description_placeholders={ - CONF_PRODUCT_TYPE: cast(str, self.config[CONF_PRODUCT_TYPE]), - CONF_SERIAL: cast(str, self.config[CONF_SERIAL]), - CONF_IP_ADDRESS: cast(str, self.config[CONF_IP_ADDRESS]), - }, + errors = {"base": ex.error_code} + else: + return self.async_create_entry( + title=f"{self.discovery.product_name} ({self.discovery.serial})", + data={CONF_IP_ADDRESS: self.discovery.ip}, ) - return self.async_create_entry( - title=f"{self.config[CONF_PRODUCT_NAME]} ({self.config[CONF_SERIAL]})", - data={ - CONF_IP_ADDRESS: self.config[CONF_IP_ADDRESS], - }, - ) - self._set_confirm_only() - self.context["title_placeholders"] = { - "name": f"{self.config[CONF_PRODUCT_NAME]} ({self.config[CONF_SERIAL]})" + "name": f"{self.discovery.product_name} ({self.discovery.serial})" } return self.async_show_form( step_id="discovery_confirm", description_placeholders={ - CONF_PRODUCT_TYPE: cast(str, self.config[CONF_PRODUCT_TYPE]), - CONF_SERIAL: cast(str, self.config[CONF_SERIAL]), - CONF_IP_ADDRESS: cast(str, self.config[CONF_IP_ADDRESS]), + CONF_PRODUCT_TYPE: self.discovery.product_type, + CONF_SERIAL: self.discovery.serial, + CONF_IP_ADDRESS: self.discovery.ip, }, + errors=errors, ) async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: """Handle re-auth if API was disabled.""" - self.entry = self.hass.config_entries.async_get_entry(self.context["entry_id"]) return await self.async_step_reauth_confirm() @@ -177,36 +149,31 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Confirm reauth dialog.""" - + errors: dict[str, str] | None = None if user_input is not None: assert self.entry is not None - try: await self._async_try_connect(self.entry.data[CONF_IP_ADDRESS]) except RecoverableError as ex: _LOGGER.error(ex) - return self.async_show_form( - step_id="reauth_confirm", - errors={"base": ex.error_code}, - ) - - await self.hass.config_entries.async_reload(self.entry.entry_id) - return self.async_abort(reason="reauth_successful") + errors = {"base": ex.error_code} + else: + await self.hass.config_entries.async_reload(self.entry.entry_id) + return self.async_abort(reason="reauth_successful") return self.async_show_form( step_id="reauth_confirm", + errors=errors, ) @staticmethod async def _async_try_connect(ip_address: str) -> Device: - """Try to connect.""" + """Try to connect. - _LOGGER.debug("config_flow _async_try_connect") - - # Make connection with device - # This is to test the connection and to get info for unique_id + Make connection with device to test the connection + and to get info for unique_id. + """ energy_api = HomeWizardEnergy(ip_address) - try: return await energy_api.device() @@ -231,18 +198,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): finally: await energy_api.close() # type: ignore[no-untyped-call] - async def _async_set_and_check_unique_id(self, entry_info: dict[str, Any]) -> None: - """Validate if entry exists.""" - - _LOGGER.debug("config_flow _async_set_and_check_unique_id") - - await self.async_set_unique_id( - f"{entry_info[CONF_PRODUCT_TYPE]}_{entry_info[CONF_SERIAL]}" - ) - self._abort_if_unique_id_configured( - updates={CONF_IP_ADDRESS: entry_info[CONF_IP_ADDRESS]} - ) - class RecoverableError(HomeAssistantError): """Raised when a connection has been failed but can be retried."""