From 4387bbfb949fba2c7dead76ab2d824dbe36da085 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Tue, 16 Nov 2021 13:30:38 +0100 Subject: [PATCH] Adjust async_step_mqtt signature for strict typing (#59761) * Add MqttServiceInfo * Adjust async_step_mqtt signature * Adjust async_step_mqtt signature * Adjust components Co-authored-by: epenet --- homeassistant/components/mqtt/discovery.py | 30 ++++++++++++++----- .../components/tasmota/config_flow.py | 5 ++-- homeassistant/config_entries.py | 5 ++-- homeassistant/helpers/config_entry_flow.py | 11 ++++++- 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index d490374ed53..7aec80b2e9c 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -1,11 +1,13 @@ """Support for MQTT discovery.""" import asyncio from collections import deque +import datetime as dt import functools import json import logging import re import time +from typing import TypedDict from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.core import HomeAssistant @@ -27,6 +29,7 @@ from .const import ( CONF_TOPIC, DOMAIN, ) +from .models import ReceivePayloadType _LOGGER = logging.getLogger(__name__) @@ -86,6 +89,17 @@ class MQTTConfig(dict): """Dummy class to allow adding attributes.""" +class MqttServiceInfo(TypedDict): + """Prepared info from mqtt entries.""" + + topic: str + payload: ReceivePayloadType + qos: int + retain: bool + subscribed_topic: str + timestamp: dt.datetime + + async def async_start( # noqa: C901 hass: HomeAssistant, discovery_topic, config_entry=None ) -> None: @@ -288,14 +302,14 @@ async def async_start( # noqa: C901 if key not in hass.data[INTEGRATION_UNSUBSCRIBE]: return - data = { - "topic": msg.topic, - "payload": msg.payload, - "qos": msg.qos, - "retain": msg.retain, - "subscribed_topic": msg.subscribed_topic, - "timestamp": msg.timestamp, - } + data = MqttServiceInfo( + topic=msg.topic, + payload=msg.payload, + qos=msg.qos, + retain=msg.retain, + subscribed_topic=msg.subscribed_topic, + timestamp=msg.timestamp, + ) result = await hass.config_entries.flow.async_init( integration, context={"source": DOMAIN}, data=data ) diff --git a/homeassistant/components/tasmota/config_flow.py b/homeassistant/components/tasmota/config_flow.py index 435604b4bdd..9c22934678e 100644 --- a/homeassistant/components/tasmota/config_flow.py +++ b/homeassistant/components/tasmota/config_flow.py @@ -6,9 +6,8 @@ from typing import Any import voluptuous as vol from homeassistant import config_entries -from homeassistant.components.mqtt import valid_subscribe_topic +from homeassistant.components.mqtt import discovery as mqtt, valid_subscribe_topic from homeassistant.data_entry_flow import FlowResult -from homeassistant.helpers.typing import DiscoveryInfoType from .const import CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX, DOMAIN @@ -22,7 +21,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Initialize flow.""" self._prefix = DEFAULT_PREFIX - async def async_step_mqtt(self, discovery_info: DiscoveryInfoType) -> FlowResult: + async def async_step_mqtt(self, discovery_info: mqtt.MqttServiceInfo) -> FlowResult: """Handle a flow initialized by MQTT discovery.""" if self._async_in_progress() or self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 85215a5962f..e1e4c103dc4 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -33,6 +33,7 @@ import homeassistant.util.uuid as uuid_util if TYPE_CHECKING: from homeassistant.components.dhcp import DhcpServiceInfo + from homeassistant.components.mqtt.discovery import MqttServiceInfo from homeassistant.components.zeroconf import ZeroconfServiceInfo _LOGGER = logging.getLogger(__name__) @@ -1361,10 +1362,10 @@ class ConfigFlow(data_entry_flow.FlowHandler): return await self.async_step_discovery(cast(dict, discovery_info)) async def async_step_mqtt( - self, discovery_info: DiscoveryInfoType + self, discovery_info: MqttServiceInfo ) -> data_entry_flow.FlowResult: """Handle a flow initialized by MQTT discovery.""" - return await self.async_step_discovery(discovery_info) + return await self.async_step_discovery(cast(dict, discovery_info)) async def async_step_ssdp( self, discovery_info: DiscoveryInfoType diff --git a/homeassistant/helpers/config_entry_flow.py b/homeassistant/helpers/config_entry_flow.py index 75e3d128435..939394a243a 100644 --- a/homeassistant/helpers/config_entry_flow.py +++ b/homeassistant/helpers/config_entry_flow.py @@ -6,6 +6,7 @@ from typing import Any, Awaitable, Callable, Union from homeassistant import config_entries from homeassistant.components import dhcp, zeroconf +from homeassistant.components.mqtt import discovery as mqtt from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers.typing import UNDEFINED, DiscoveryInfoType, UndefinedType @@ -102,6 +103,15 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow): return await self.async_step_confirm() + async def async_step_mqtt(self, discovery_info: mqtt.MqttServiceInfo) -> FlowResult: + """Handle a flow initialized by mqtt discovery.""" + if self._async_in_progress() or self._async_current_entries(): + return self.async_abort(reason="single_instance_allowed") + + await self.async_set_unique_id(self._domain) + + return await self.async_step_confirm() + async def async_step_zeroconf( self, discovery_info: zeroconf.ZeroconfServiceInfo ) -> FlowResult: @@ -114,7 +124,6 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow): return await self.async_step_confirm() async_step_ssdp = async_step_discovery - async_step_mqtt = async_step_discovery async def async_step_import(self, _: dict[str, Any] | None) -> FlowResult: """Handle a flow initialized by import."""