mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 17:27:10 +00:00
Small refactor to HomeWizard config flow (#86020)
* Small refactor to HomeWizard config flow * Update homeassistant/components/homewizard/config_flow.py Co-authored-by: Duco Sebel <74970928+DCSBL@users.noreply.github.com> * Process review comments Co-authored-by: Duco Sebel <74970928+DCSBL@users.noreply.github.com>
This commit is contained in:
parent
9205020fa4
commit
6a89b3a135
@ -3,15 +3,15 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from homewizard_energy import HomeWizardEnergy
|
||||
from homewizard_energy.errors import DisabledError, RequestError, UnsupportedError
|
||||
from homewizard_energy.models import Device
|
||||
from voluptuous import Required, Schema
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components import zeroconf
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.const import CONF_IP_ADDRESS
|
||||
from homeassistant.data_entry_flow import AbortFlow, FlowResult
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@ -28,72 +28,58 @@ from .const import (
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
class DiscoveryData(NamedTuple):
|
||||
"""User metadata."""
|
||||
|
||||
ip: str
|
||||
product_name: str
|
||||
product_type: str
|
||||
serial: str
|
||||
|
||||
|
||||
class HomeWizardConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for P1 meter."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the HomeWizard config flow."""
|
||||
self.config: dict[str, str | int] = {}
|
||||
self.entry: config_entries.ConfigEntry | None = None
|
||||
discovery: DiscoveryData
|
||||
entry: ConfigEntry | None
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle a flow initiated by the user."""
|
||||
|
||||
_LOGGER.debug("config_flow async_step_user")
|
||||
|
||||
data_schema = Schema(
|
||||
{
|
||||
Required(CONF_IP_ADDRESS): str,
|
||||
}
|
||||
)
|
||||
|
||||
if user_input is None:
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=data_schema,
|
||||
errors=None,
|
||||
)
|
||||
|
||||
# Fetch device information
|
||||
errors: dict[str, str] | None = None
|
||||
if user_input is not None:
|
||||
try:
|
||||
device_info = await self._async_try_connect(user_input[CONF_IP_ADDRESS])
|
||||
except RecoverableError as ex:
|
||||
_LOGGER.error(ex)
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=data_schema,
|
||||
errors={"base": ex.error_code},
|
||||
errors = {"base": ex.error_code}
|
||||
else:
|
||||
await self.async_set_unique_id(
|
||||
f"{device_info.product_type}_{device_info.serial}"
|
||||
)
|
||||
|
||||
# Sets unique ID and aborts if it is already exists
|
||||
await self._async_set_and_check_unique_id(
|
||||
{
|
||||
CONF_IP_ADDRESS: user_input[CONF_IP_ADDRESS],
|
||||
CONF_PRODUCT_TYPE: device_info.product_type,
|
||||
CONF_SERIAL: device_info.serial,
|
||||
}
|
||||
)
|
||||
|
||||
data: dict[str, str] = {CONF_IP_ADDRESS: user_input[CONF_IP_ADDRESS]}
|
||||
|
||||
# Add entry
|
||||
self._abort_if_unique_id_configured(updates=user_input)
|
||||
return self.async_create_entry(
|
||||
title=f"{device_info.product_name} ({device_info.serial})",
|
||||
data=data,
|
||||
data=user_input,
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=Schema(
|
||||
{
|
||||
Required(CONF_IP_ADDRESS): str,
|
||||
}
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_zeroconf(
|
||||
self, discovery_info: zeroconf.ZeroconfServiceInfo
|
||||
) -> FlowResult:
|
||||
"""Handle zeroconf discovery."""
|
||||
|
||||
_LOGGER.debug("config_flow async_step_zeroconf")
|
||||
|
||||
# Validate doscovery entry
|
||||
if (
|
||||
CONF_API_ENABLED not in discovery_info.properties
|
||||
or CONF_PATH not in discovery_info.properties
|
||||
@ -106,70 +92,56 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
if (discovery_info.properties[CONF_PATH]) != "/api/v1":
|
||||
return self.async_abort(reason="unsupported_api_version")
|
||||
|
||||
# Sets unique ID and aborts if it is already exists
|
||||
await self._async_set_and_check_unique_id(
|
||||
{
|
||||
CONF_IP_ADDRESS: discovery_info.host,
|
||||
CONF_PRODUCT_TYPE: discovery_info.properties[CONF_PRODUCT_TYPE],
|
||||
CONF_SERIAL: discovery_info.properties[CONF_SERIAL],
|
||||
}
|
||||
self.discovery = DiscoveryData(
|
||||
ip=discovery_info.host,
|
||||
product_type=discovery_info.properties[CONF_PRODUCT_TYPE],
|
||||
product_name=discovery_info.properties[CONF_PRODUCT_NAME],
|
||||
serial=discovery_info.properties[CONF_SERIAL],
|
||||
)
|
||||
|
||||
await self.async_set_unique_id(
|
||||
f"{self.discovery.product_type}_{self.discovery.serial}"
|
||||
)
|
||||
self._abort_if_unique_id_configured(
|
||||
updates={CONF_IP_ADDRESS: discovery_info.host}
|
||||
)
|
||||
|
||||
# Pass parameters
|
||||
self.config = {
|
||||
CONF_API_ENABLED: discovery_info.properties[CONF_API_ENABLED],
|
||||
CONF_IP_ADDRESS: discovery_info.host,
|
||||
CONF_PRODUCT_TYPE: discovery_info.properties[CONF_PRODUCT_TYPE],
|
||||
CONF_PRODUCT_NAME: discovery_info.properties[CONF_PRODUCT_NAME],
|
||||
CONF_SERIAL: discovery_info.properties[CONF_SERIAL],
|
||||
}
|
||||
return await self.async_step_discovery_confirm()
|
||||
|
||||
async def async_step_discovery_confirm(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Confirm discovery."""
|
||||
errors: dict[str, str] | None = None
|
||||
if user_input is not None:
|
||||
|
||||
try:
|
||||
await self._async_try_connect(str(self.config[CONF_IP_ADDRESS]))
|
||||
await self._async_try_connect(self.discovery.ip)
|
||||
except RecoverableError as ex:
|
||||
_LOGGER.error(ex)
|
||||
return self.async_show_form(
|
||||
step_id="discovery_confirm",
|
||||
errors={"base": ex.error_code},
|
||||
description_placeholders={
|
||||
CONF_PRODUCT_TYPE: cast(str, self.config[CONF_PRODUCT_TYPE]),
|
||||
CONF_SERIAL: cast(str, self.config[CONF_SERIAL]),
|
||||
CONF_IP_ADDRESS: cast(str, self.config[CONF_IP_ADDRESS]),
|
||||
},
|
||||
)
|
||||
|
||||
errors = {"base": ex.error_code}
|
||||
else:
|
||||
return self.async_create_entry(
|
||||
title=f"{self.config[CONF_PRODUCT_NAME]} ({self.config[CONF_SERIAL]})",
|
||||
data={
|
||||
CONF_IP_ADDRESS: self.config[CONF_IP_ADDRESS],
|
||||
},
|
||||
title=f"{self.discovery.product_name} ({self.discovery.serial})",
|
||||
data={CONF_IP_ADDRESS: self.discovery.ip},
|
||||
)
|
||||
|
||||
self._set_confirm_only()
|
||||
|
||||
self.context["title_placeholders"] = {
|
||||
"name": f"{self.config[CONF_PRODUCT_NAME]} ({self.config[CONF_SERIAL]})"
|
||||
"name": f"{self.discovery.product_name} ({self.discovery.serial})"
|
||||
}
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="discovery_confirm",
|
||||
description_placeholders={
|
||||
CONF_PRODUCT_TYPE: cast(str, self.config[CONF_PRODUCT_TYPE]),
|
||||
CONF_SERIAL: cast(str, self.config[CONF_SERIAL]),
|
||||
CONF_IP_ADDRESS: cast(str, self.config[CONF_IP_ADDRESS]),
|
||||
CONF_PRODUCT_TYPE: self.discovery.product_type,
|
||||
CONF_SERIAL: self.discovery.serial,
|
||||
CONF_IP_ADDRESS: self.discovery.ip,
|
||||
},
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
|
||||
"""Handle re-auth if API was disabled."""
|
||||
|
||||
self.entry = self.hass.config_entries.async_get_entry(self.context["entry_id"])
|
||||
return await self.async_step_reauth_confirm()
|
||||
|
||||
@ -177,36 +149,31 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Confirm reauth dialog."""
|
||||
|
||||
errors: dict[str, str] | None = None
|
||||
if user_input is not None:
|
||||
assert self.entry is not None
|
||||
|
||||
try:
|
||||
await self._async_try_connect(self.entry.data[CONF_IP_ADDRESS])
|
||||
except RecoverableError as ex:
|
||||
_LOGGER.error(ex)
|
||||
return self.async_show_form(
|
||||
step_id="reauth_confirm",
|
||||
errors={"base": ex.error_code},
|
||||
)
|
||||
|
||||
errors = {"base": ex.error_code}
|
||||
else:
|
||||
await self.hass.config_entries.async_reload(self.entry.entry_id)
|
||||
return self.async_abort(reason="reauth_successful")
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="reauth_confirm",
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _async_try_connect(ip_address: str) -> Device:
|
||||
"""Try to connect."""
|
||||
"""Try to connect.
|
||||
|
||||
_LOGGER.debug("config_flow _async_try_connect")
|
||||
|
||||
# Make connection with device
|
||||
# This is to test the connection and to get info for unique_id
|
||||
Make connection with device to test the connection
|
||||
and to get info for unique_id.
|
||||
"""
|
||||
energy_api = HomeWizardEnergy(ip_address)
|
||||
|
||||
try:
|
||||
return await energy_api.device()
|
||||
|
||||
@ -231,18 +198,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
finally:
|
||||
await energy_api.close() # type: ignore[no-untyped-call]
|
||||
|
||||
async def _async_set_and_check_unique_id(self, entry_info: dict[str, Any]) -> None:
|
||||
"""Validate if entry exists."""
|
||||
|
||||
_LOGGER.debug("config_flow _async_set_and_check_unique_id")
|
||||
|
||||
await self.async_set_unique_id(
|
||||
f"{entry_info[CONF_PRODUCT_TYPE]}_{entry_info[CONF_SERIAL]}"
|
||||
)
|
||||
self._abort_if_unique_id_configured(
|
||||
updates={CONF_IP_ADDRESS: entry_info[CONF_IP_ADDRESS]}
|
||||
)
|
||||
|
||||
|
||||
class RecoverableError(HomeAssistantError):
|
||||
"""Raised when a connection has been failed but can be retried."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user