Add typing to Roomba config flow (#114624)

This commit is contained in:
Joost Lekkerkerker 2024-04-02 20:21:55 +02:00 committed by GitHub
parent f85511255c
commit 7cb01f75ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,8 +4,9 @@ from __future__ import annotations
import asyncio import asyncio
from functools import partial from functools import partial
from typing import Any
from roombapy import RoombaFactory from roombapy import RoombaFactory, RoombaInfo
from roombapy.discovery import RoombaDiscovery from roombapy.discovery import RoombaDiscovery
from roombapy.getpassword import RoombaPassword from roombapy.getpassword import RoombaPassword
import voluptuous as vol import voluptuous as vol
@ -15,7 +16,7 @@ from homeassistant.config_entries import (
ConfigEntry, ConfigEntry,
ConfigFlow, ConfigFlow,
ConfigFlowResult, ConfigFlowResult,
OptionsFlow, OptionsFlowWithConfigEntry,
) )
from homeassistant.const import CONF_DELAY, CONF_HOST, CONF_NAME, CONF_PASSWORD from homeassistant.const import CONF_DELAY, CONF_HOST, CONF_NAME, CONF_PASSWORD
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
@ -43,7 +44,7 @@ AUTH_HELP_URL_KEY = "auth_help_url"
AUTH_HELP_URL_VALUE = "https://www.home-assistant.io/integrations/roomba/#manually-retrieving-your-credentials" AUTH_HELP_URL_VALUE = "https://www.home-assistant.io/integrations/roomba/#manually-retrieving-your-credentials"
async def validate_input(hass: HomeAssistant, data): async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str, Any]:
"""Validate the user input allows us to connect. """Validate the user input allows us to connect.
Data has the keys from DATA_SCHEMA with values provided by the user. Data has the keys from DATA_SCHEMA with values provided by the user.
@ -75,20 +76,21 @@ class RoombaConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
def __init__(self): name: str | None = None
blid: str | None = None
host: str | None = None
def __init__(self) -> None:
"""Initialize the roomba flow.""" """Initialize the roomba flow."""
self.discovered_robots = {} self.discovered_robots: dict[str, RoombaInfo] = {}
self.name = None
self.blid = None
self.host = None
@staticmethod @staticmethod
@callback @callback
def async_get_options_flow( def async_get_options_flow(
config_entry: ConfigEntry, config_entry: ConfigEntry,
) -> OptionsFlowHandler: ) -> RoombaOptionsFlowHandler:
"""Get the options flow for this handler.""" """Get the options flow for this handler."""
return OptionsFlowHandler(config_entry) return RoombaOptionsFlowHandler(config_entry)
async def async_step_zeroconf( async def async_step_zeroconf(
self, discovery_info: zeroconf.ZeroconfServiceInfo self, discovery_info: zeroconf.ZeroconfServiceInfo
@ -135,8 +137,9 @@ class RoombaConfigFlow(ConfigFlow, domain=DOMAIN):
self.context["title_placeholders"] = {"host": self.host, "name": self.blid} self.context["title_placeholders"] = {"host": self.host, "name": self.blid}
return await self.async_step_user() return await self.async_step_user()
async def _async_start_link(self): async def _async_start_link(self) -> ConfigFlowResult:
"""Start linking.""" """Start linking."""
assert self.host
device = self.discovered_robots[self.host] device = self.discovered_robots[self.host]
self.blid = device.blid self.blid = device.blid
self.name = device.robot_name self.name = device.robot_name
@ -144,7 +147,9 @@ class RoombaConfigFlow(ConfigFlow, domain=DOMAIN):
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()
return await self.async_step_link() return await self.async_step_link()
async def async_step_user(self, user_input=None): async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle a flow start.""" """Handle a flow start."""
# Check if user chooses manual entry # Check if user chooses manual entry
if user_input is not None and not user_input.get(CONF_HOST): if user_input is not None and not user_input.get(CONF_HOST):
@ -181,12 +186,7 @@ class RoombaConfigFlow(ConfigFlow, domain=DOMAIN):
if not self.discovered_robots: if not self.discovered_robots:
return await self.async_step_manual() return await self.async_step_manual()
return self.async_show_form( hosts: dict[str | None, str] = {
step_id="user",
data_schema=vol.Schema(
{
vol.Optional("host"): vol.In(
{
**{ **{
device.ip: f"{device.robot_name} ({device.ip})" device.ip: f"{device.robot_name} ({device.ip})"
for device in devices for device in devices
@ -194,12 +194,15 @@ class RoombaConfigFlow(ConfigFlow, domain=DOMAIN):
}, },
None: "Manually add a Roomba or Braava", None: "Manually add a Roomba or Braava",
} }
)
} return self.async_show_form(
), step_id="user",
data_schema=vol.Schema({vol.Optional("host"): vol.In(hosts)}),
) )
async def async_step_manual(self, user_input=None): async def async_step_manual(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle manual device setup.""" """Handle manual device setup."""
if user_input is None: if user_input is None:
return self.async_show_form( return self.async_show_form(
@ -224,7 +227,9 @@ class RoombaConfigFlow(ConfigFlow, domain=DOMAIN):
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()
return await self.async_step_link() return await self.async_step_link()
async def async_step_link(self, user_input=None): async def async_step_link(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Attempt to link with the Roomba. """Attempt to link with the Roomba.
Given a configured host, will ask the user to press the home and target buttons Given a configured host, will ask the user to press the home and target buttons
@ -235,7 +240,7 @@ class RoombaConfigFlow(ConfigFlow, domain=DOMAIN):
step_id="link", step_id="link",
description_placeholders={CONF_NAME: self.name or self.blid}, description_placeholders={CONF_NAME: self.name or self.blid},
) )
assert self.host
roomba_pw = RoombaPassword(self.host) roomba_pw = RoombaPassword(self.host)
try: try:
@ -260,10 +265,12 @@ class RoombaConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_abort(reason="cannot_connect") return self.async_abort(reason="cannot_connect")
self.name = info[CONF_NAME] self.name = info[CONF_NAME]
assert self.name
return self.async_create_entry(title=self.name, data=config) return self.async_create_entry(title=self.name, data=config)
async def async_step_link_manual(self, user_input=None): async def async_step_link_manual(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle manual linking.""" """Handle manual linking."""
errors = {} errors = {}
@ -278,8 +285,7 @@ class RoombaConfigFlow(ConfigFlow, domain=DOMAIN):
info = await validate_input(self.hass, config) info = await validate_input(self.hass, config)
except CannotConnect: except CannotConnect:
errors = {"base": "cannot_connect"} errors = {"base": "cannot_connect"}
else:
if not errors:
return self.async_create_entry(title=info[CONF_NAME], data=config) return self.async_create_entry(title=info[CONF_NAME], data=config)
return self.async_show_form( return self.async_show_form(
@ -290,14 +296,12 @@ class RoombaConfigFlow(ConfigFlow, domain=DOMAIN):
) )
class OptionsFlowHandler(OptionsFlow): class RoombaOptionsFlowHandler(OptionsFlowWithConfigEntry):
"""Handle options.""" """Handle options."""
def __init__(self, config_entry: ConfigEntry) -> None: async def async_step_init(
"""Initialize options flow.""" self, user_input: dict[str, Any] | None = None
self.config_entry = config_entry ) -> ConfigFlowResult:
async def async_step_init(self, user_input=None):
"""Manage the options.""" """Manage the options."""
if user_input is not None: if user_input is not None:
return self.async_create_entry(title="", data=user_input) return self.async_create_entry(title="", data=user_input)
@ -308,15 +312,11 @@ class OptionsFlowHandler(OptionsFlow):
{ {
vol.Optional( vol.Optional(
CONF_CONTINUOUS, CONF_CONTINUOUS,
default=self.config_entry.options.get( default=self.options.get(CONF_CONTINUOUS, DEFAULT_CONTINUOUS),
CONF_CONTINUOUS, DEFAULT_CONTINUOUS
),
): bool, ): bool,
vol.Optional( vol.Optional(
CONF_DELAY, CONF_DELAY,
default=self.config_entry.options.get( default=self.options.get(CONF_DELAY, DEFAULT_DELAY),
CONF_DELAY, DEFAULT_DELAY
),
): int, ): int,
} }
), ),
@ -324,7 +324,7 @@ class OptionsFlowHandler(OptionsFlow):
@callback @callback
def _async_get_roomba_discovery(): def _async_get_roomba_discovery() -> RoombaDiscovery:
"""Create a discovery object.""" """Create a discovery object."""
discovery = RoombaDiscovery() discovery = RoombaDiscovery()
discovery.amount_of_broadcasted_messages = MAX_NUM_DEVICES_TO_DISCOVER discovery.amount_of_broadcasted_messages = MAX_NUM_DEVICES_TO_DISCOVER
@ -332,24 +332,28 @@ def _async_get_roomba_discovery():
@callback @callback
def _async_blid_from_hostname(hostname): def _async_blid_from_hostname(hostname: str) -> str:
"""Extract the blid from the hostname.""" """Extract the blid from the hostname."""
return hostname.split("-")[1].split(".")[0].upper() return hostname.split("-")[1].split(".")[0].upper()
async def _async_discover_roombas(hass, host): async def _async_discover_roombas(
discovered_hosts = set() hass: HomeAssistant, host: str | None = None
devices = [] ) -> list[RoombaInfo]:
discovered_hosts: set[str] = set()
devices: list[RoombaInfo] = []
discover_lock = hass.data.setdefault(ROOMBA_DISCOVERY_LOCK, asyncio.Lock()) discover_lock = hass.data.setdefault(ROOMBA_DISCOVERY_LOCK, asyncio.Lock())
discover_attempts = HOST_ATTEMPTS if host else ALL_ATTEMPTS discover_attempts = HOST_ATTEMPTS if host else ALL_ATTEMPTS
for attempt in range(discover_attempts + 1): for attempt in range(discover_attempts + 1):
async with discover_lock: async with discover_lock:
discovery = _async_get_roomba_discovery() discovery = _async_get_roomba_discovery()
discovered: set[RoombaInfo] = set()
try: try:
if host: if host:
device = await hass.async_add_executor_job(discovery.get, host) device = await hass.async_add_executor_job(discovery.get, host)
discovered = [device] if device else [] if device:
discovered.add(device)
else: else:
discovered = await hass.async_add_executor_job(discovery.get_all) discovered = await hass.async_add_executor_job(discovery.get_all)
except OSError: except OSError: