Adjust async_step_mqtt signature for strict typing (#59761)

* Add MqttServiceInfo

* Adjust async_step_mqtt signature

* Adjust async_step_mqtt signature

* Adjust components

Co-authored-by: epenet <epenet@users.noreply.github.com>
This commit is contained in:
epenet 2021-11-16 13:30:38 +01:00 committed by GitHub
parent f1d75f0dd7
commit 4387bbfb94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 37 additions and 14 deletions

View File

@ -1,11 +1,13 @@
"""Support for MQTT discovery.""" """Support for MQTT discovery."""
import asyncio import asyncio
from collections import deque from collections import deque
import datetime as dt
import functools import functools
import json import json
import logging import logging
import re import re
import time import time
from typing import TypedDict
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -27,6 +29,7 @@ from .const import (
CONF_TOPIC, CONF_TOPIC,
DOMAIN, DOMAIN,
) )
from .models import ReceivePayloadType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -86,6 +89,17 @@ class MQTTConfig(dict):
"""Dummy class to allow adding attributes.""" """Dummy class to allow adding attributes."""
class MqttServiceInfo(TypedDict):
"""Prepared info from mqtt entries."""
topic: str
payload: ReceivePayloadType
qos: int
retain: bool
subscribed_topic: str
timestamp: dt.datetime
async def async_start( # noqa: C901 async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic, config_entry=None hass: HomeAssistant, discovery_topic, config_entry=None
) -> None: ) -> None:
@ -288,14 +302,14 @@ async def async_start( # noqa: C901
if key not in hass.data[INTEGRATION_UNSUBSCRIBE]: if key not in hass.data[INTEGRATION_UNSUBSCRIBE]:
return return
data = { data = MqttServiceInfo(
"topic": msg.topic, topic=msg.topic,
"payload": msg.payload, payload=msg.payload,
"qos": msg.qos, qos=msg.qos,
"retain": msg.retain, retain=msg.retain,
"subscribed_topic": msg.subscribed_topic, subscribed_topic=msg.subscribed_topic,
"timestamp": msg.timestamp, timestamp=msg.timestamp,
} )
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
integration, context={"source": DOMAIN}, data=data integration, context={"source": DOMAIN}, data=data
) )

View File

@ -6,9 +6,8 @@ from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.mqtt import valid_subscribe_topic from homeassistant.components.mqtt import discovery as mqtt, valid_subscribe_topic
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.typing import DiscoveryInfoType
from .const import CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX, DOMAIN from .const import CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX, DOMAIN
@ -22,7 +21,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Initialize flow.""" """Initialize flow."""
self._prefix = DEFAULT_PREFIX self._prefix = DEFAULT_PREFIX
async def async_step_mqtt(self, discovery_info: DiscoveryInfoType) -> FlowResult: async def async_step_mqtt(self, discovery_info: mqtt.MqttServiceInfo) -> FlowResult:
"""Handle a flow initialized by MQTT discovery.""" """Handle a flow initialized by MQTT discovery."""
if self._async_in_progress() or self._async_current_entries(): if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")

View File

@ -33,6 +33,7 @@ import homeassistant.util.uuid as uuid_util
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.components.dhcp import DhcpServiceInfo from homeassistant.components.dhcp import DhcpServiceInfo
from homeassistant.components.mqtt.discovery import MqttServiceInfo
from homeassistant.components.zeroconf import ZeroconfServiceInfo from homeassistant.components.zeroconf import ZeroconfServiceInfo
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -1361,10 +1362,10 @@ class ConfigFlow(data_entry_flow.FlowHandler):
return await self.async_step_discovery(cast(dict, discovery_info)) return await self.async_step_discovery(cast(dict, discovery_info))
async def async_step_mqtt( async def async_step_mqtt(
self, discovery_info: DiscoveryInfoType self, discovery_info: MqttServiceInfo
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by MQTT discovery.""" """Handle a flow initialized by MQTT discovery."""
return await self.async_step_discovery(discovery_info) return await self.async_step_discovery(cast(dict, discovery_info))
async def async_step_ssdp( async def async_step_ssdp(
self, discovery_info: DiscoveryInfoType self, discovery_info: DiscoveryInfoType

View File

@ -6,6 +6,7 @@ from typing import Any, Awaitable, Callable, Union
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import dhcp, zeroconf from homeassistant.components import dhcp, zeroconf
from homeassistant.components.mqtt import discovery as mqtt
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.typing import UNDEFINED, DiscoveryInfoType, UndefinedType from homeassistant.helpers.typing import UNDEFINED, DiscoveryInfoType, UndefinedType
@ -102,6 +103,15 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
return await self.async_step_confirm() return await self.async_step_confirm()
async def async_step_mqtt(self, discovery_info: mqtt.MqttServiceInfo) -> FlowResult:
"""Handle a flow initialized by mqtt discovery."""
if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
await self.async_set_unique_id(self._domain)
return await self.async_step_confirm()
async def async_step_zeroconf( async def async_step_zeroconf(
self, discovery_info: zeroconf.ZeroconfServiceInfo self, discovery_info: zeroconf.ZeroconfServiceInfo
) -> FlowResult: ) -> FlowResult:
@ -114,7 +124,6 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow):
return await self.async_step_confirm() return await self.async_step_confirm()
async_step_ssdp = async_step_discovery async_step_ssdp = async_step_discovery
async_step_mqtt = async_step_discovery
async def async_step_import(self, _: dict[str, Any] | None) -> FlowResult: async def async_step_import(self, _: dict[str, Any] | None) -> FlowResult:
"""Handle a flow initialized by import.""" """Handle a flow initialized by import."""