Improve bosch_shc typing (#86535)

This commit is contained in:
Marc Mueller 2023-01-24 17:06:00 +01:00 committed by GitHub
parent 6cad0c7984
commit 310d7718a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,8 +1,10 @@
"""Config flow for Bosch Smart Home Controller integration.""" """Config flow for Bosch Smart Home Controller integration."""
from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import logging import logging
from os import makedirs from os import makedirs
from typing import Any from typing import Any, cast
from boschshcpy import SHCRegisterClient, SHCSession from boschshcpy import SHCRegisterClient, SHCSession
from boschshcpy.exceptions import ( from boschshcpy.exceptions import (
@ -13,9 +15,10 @@ from boschshcpy.exceptions import (
) )
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries, core from homeassistant import config_entries
from homeassistant.components import zeroconf from homeassistant.components import zeroconf
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_TOKEN from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_TOKEN
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from .const import ( from .const import (
@ -36,14 +39,19 @@ HOST_SCHEMA = vol.Schema(
) )
def write_tls_asset(hass: core.HomeAssistant, filename: str, asset: bytes) -> None: def write_tls_asset(hass: HomeAssistant, filename: str, asset: bytes) -> None:
"""Write the tls assets to disk.""" """Write the tls assets to disk."""
makedirs(hass.config.path(DOMAIN), exist_ok=True) makedirs(hass.config.path(DOMAIN), exist_ok=True)
with open(hass.config.path(DOMAIN, filename), "w", encoding="utf8") as file_handle: with open(hass.config.path(DOMAIN, filename), "w", encoding="utf8") as file_handle:
file_handle.write(asset.decode("utf-8")) file_handle.write(asset.decode("utf-8"))
def create_credentials_and_validate(hass, host, user_input, zeroconf_instance): def create_credentials_and_validate(
hass: HomeAssistant,
host: str,
user_input: dict[str, Any],
zeroconf_instance: zeroconf.HaZeroconf,
) -> dict[str, Any] | None:
"""Create and store credentials and validate session.""" """Create and store credentials and validate session."""
helper = SHCRegisterClient(host, user_input[CONF_PASSWORD]) helper = SHCRegisterClient(host, user_input[CONF_PASSWORD])
result = helper.register(host, "HomeAssistant") result = helper.register(host, "HomeAssistant")
@ -64,7 +72,9 @@ def create_credentials_and_validate(hass, host, user_input, zeroconf_instance):
return result return result
def get_info_from_host(hass, host, zeroconf_instance): def get_info_from_host(
hass: HomeAssistant, host: str, zeroconf_instance: zeroconf.HaZeroconf
) -> dict[str, str | None]:
"""Get information from host.""" """Get information from host."""
session = SHCSession( session = SHCSession(
host, host,
@ -81,15 +91,16 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Bosch SHC.""" """Handle a config flow for Bosch SHC."""
VERSION = 1 VERSION = 1
info = None info: dict[str, str | None]
host = None host: str | None = None
hostname = None
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Perform reauth upon an API authentication error.""" """Perform reauth upon an API authentication error."""
return await self.async_step_reauth_confirm() return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(self, user_input=None): async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Dialog that informs the user that reauth is required.""" """Dialog that informs the user that reauth is required."""
if user_input is None: if user_input is None:
return self.async_show_form( return self.async_show_form(
@ -100,9 +111,11 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self.info = await self._get_info(host) self.info = await self._get_info(host)
return await self.async_step_credentials() return await self.async_step_credentials()
async def async_step_user(self, user_input=None): async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the initial step.""" """Handle the initial step."""
errors = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:
host = user_input[CONF_HOST] host = user_input[CONF_HOST]
try: try:
@ -122,9 +135,11 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
step_id="user", data_schema=HOST_SCHEMA, errors=errors step_id="user", data_schema=HOST_SCHEMA, errors=errors
) )
async def async_step_credentials(self, user_input=None): async def async_step_credentials(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the credentials step.""" """Handle the credentials step."""
errors = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:
zeroconf_instance = await zeroconf.async_get_instance(self.hass) zeroconf_instance = await zeroconf.async_get_instance(self.hass)
try: try:
@ -149,6 +164,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
_LOGGER.exception("Unexpected exception") _LOGGER.exception("Unexpected exception")
errors["base"] = "unknown" errors["base"] = "unknown"
else: else:
assert result
entry_data = { entry_data = {
CONF_SSL_CERTIFICATE: self.hass.config.path(DOMAIN, CONF_SHC_CERT), CONF_SSL_CERTIFICATE: self.hass.config.path(DOMAIN, CONF_SHC_CERT),
CONF_SSL_KEY: self.hass.config.path(DOMAIN, CONF_SHC_KEY), CONF_SSL_KEY: self.hass.config.path(DOMAIN, CONF_SHC_KEY),
@ -166,7 +182,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return self.async_abort(reason="reauth_successful") return self.async_abort(reason="reauth_successful")
return self.async_create_entry( return self.async_create_entry(
title=self.info["title"], title=cast(str, self.info["title"]),
data=entry_data, data=entry_data,
) )
else: else:
@ -205,9 +221,11 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self.context["title_placeholders"] = {"name": node_name} self.context["title_placeholders"] = {"name": node_name}
return await self.async_step_confirm_discovery() return await self.async_step_confirm_discovery()
async def async_step_confirm_discovery(self, user_input=None): async def async_step_confirm_discovery(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle discovery confirm.""" """Handle discovery confirm."""
errors = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:
return await self.async_step_credentials() return await self.async_step_credentials()
@ -220,7 +238,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors=errors, errors=errors,
) )
async def _get_info(self, host): async def _get_info(self, host: str) -> dict[str, str | None]:
"""Get additional information.""" """Get additional information."""
zeroconf_instance = await zeroconf.async_get_instance(self.hass) zeroconf_instance = await zeroconf.async_get_instance(self.hass)