Improve config flow type hints in xiaomi_aqara (#125316)

This commit is contained in:
epenet 2024-09-06 15:16:32 +02:00 committed by GitHub
parent f5f8c44ca6
commit 66c6cd2a10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,7 @@
import logging import logging
from socket import gaierror from socket import gaierror
from typing import TYPE_CHECKING, Any from typing import Any
import voluptuous as vol import voluptuous as vol
from xiaomi_gateway import MULTICAST_PORT, XiaomiGateway, XiaomiGatewayDiscovery from xiaomi_gateway import MULTICAST_PORT, XiaomiGateway, XiaomiGatewayDiscovery
@ -50,13 +50,14 @@ class XiaomiAqaraFlowHandler(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
selected_gateway: XiaomiGateway
gateways: dict[str, XiaomiGateway]
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize.""" """Initialize."""
self.host: str | None = None self.host: str | None = None
self.interface = DEFAULT_INTERFACE self.interface = DEFAULT_INTERFACE
self.sid: str | None = None self.sid: str | None = None
self.gateways: dict[str, XiaomiGateway] | None = None
self.selected_gateway: XiaomiGateway | None = None
@callback @callback
def async_show_form_step_user(self, errors): def async_show_form_step_user(self, errors):
@ -99,8 +100,6 @@ class XiaomiAqaraFlowHandler(ConfigFlow, domain=DOMAIN):
None, None,
) )
if TYPE_CHECKING:
assert self.selected_gateway
if self.selected_gateway.connection_error: if self.selected_gateway.connection_error:
errors[CONF_HOST] = "invalid_host" errors[CONF_HOST] = "invalid_host"
if self.selected_gateway.mac_error: if self.selected_gateway.mac_error:
@ -120,8 +119,6 @@ class XiaomiAqaraFlowHandler(ConfigFlow, domain=DOMAIN):
self.gateways = xiaomi.gateways self.gateways = xiaomi.gateways
if TYPE_CHECKING:
assert self.gateways is not None
if len(self.gateways) == 1: if len(self.gateways) == 1:
self.selected_gateway = list(self.gateways.values())[0] self.selected_gateway = list(self.gateways.values())[0]
self.sid = self.selected_gateway.sid self.sid = self.selected_gateway.sid
@ -132,9 +129,11 @@ class XiaomiAqaraFlowHandler(ConfigFlow, domain=DOMAIN):
errors["base"] = "discovery_error" errors["base"] = "discovery_error"
return self.async_show_form_step_user(errors) return self.async_show_form_step_user(errors)
async def async_step_select(self, user_input=None): async def async_step_select(
self, user_input: dict[str, str] | None = None
) -> ConfigFlowResult:
"""Handle multiple aqara gateways found.""" """Handle multiple aqara gateways found."""
errors = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:
ip_adress = user_input["select_ip"] ip_adress = user_input["select_ip"]
self.selected_gateway = self.gateways[ip_adress] self.selected_gateway = self.gateways[ip_adress]
@ -192,7 +191,9 @@ class XiaomiAqaraFlowHandler(ConfigFlow, domain=DOMAIN):
return await self.async_step_user() return await self.async_step_user()
async def async_step_settings(self, user_input=None): async def async_step_settings(
self, user_input: dict[str, str] | None = None
) -> ConfigFlowResult:
"""Specify settings and connect aqara gateway.""" """Specify settings and connect aqara gateway."""
errors = {} errors = {}
if user_input is not None: if user_input is not None: