Enable basic type checking for mqtt (#52463)

* Enable basic type checking for mqtt

* Tweak
This commit is contained in:
Erik Montnemery 2021-07-05 10:33:12 +02:00 committed by GitHub
parent 5321151799
commit 79ee112490
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 68 additions and 51 deletions

View File

@ -9,7 +9,7 @@ import logging
from operator import attrgetter from operator import attrgetter
import ssl import ssl
import time import time
from typing import Any, Callable, Union from typing import Any, Awaitable, Callable, Union, cast
import uuid import uuid
import attr import attr
@ -73,7 +73,12 @@ from .const import (
PROTOCOL_311, PROTOCOL_311,
) )
from .discovery import LAST_DISCOVERY from .discovery import LAST_DISCOVERY
from .models import Message, MessageCallbackType, PublishPayloadType from .models import (
AsyncMessageCallbackType,
Message,
MessageCallbackType,
PublishPayloadType,
)
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -284,26 +289,36 @@ def async_publish_template(
hass.async_create_task(hass.services.async_call(DOMAIN, SERVICE_PUBLISH, data)) hass.async_create_task(hass.services.async_call(DOMAIN, SERVICE_PUBLISH, data))
def wrap_msg_callback(msg_callback: MessageCallbackType) -> MessageCallbackType: AsyncDeprecatedMessageCallbackType = Callable[
[str, PublishPayloadType, int], Awaitable[None]
]
DeprecatedMessageCallbackType = Callable[[str, PublishPayloadType, int], None]
def wrap_msg_callback(
msg_callback: AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType,
) -> AsyncMessageCallbackType | MessageCallbackType:
"""Wrap an MQTT message callback to support deprecated signature.""" """Wrap an MQTT message callback to support deprecated signature."""
# Check for partials to properly determine if coroutine function # Check for partials to properly determine if coroutine function
check_func = msg_callback check_func = msg_callback
while isinstance(check_func, partial): while isinstance(check_func, partial):
check_func = check_func.func check_func = check_func.func
wrapper_func = None wrapper_func: AsyncMessageCallbackType | MessageCallbackType
if asyncio.iscoroutinefunction(check_func): if asyncio.iscoroutinefunction(check_func):
@wraps(msg_callback) @wraps(msg_callback)
async def async_wrapper(msg: Any) -> None: async def async_wrapper(msg: Message) -> None:
"""Call with deprecated signature.""" """Call with deprecated signature."""
await msg_callback(msg.topic, msg.payload, msg.qos) await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
msg.topic, msg.payload, msg.qos
)
wrapper_func = async_wrapper wrapper_func = async_wrapper
else: else:
@wraps(msg_callback) @wraps(msg_callback)
def wrapper(msg: Any) -> None: def wrapper(msg: Message) -> None:
"""Call with deprecated signature.""" """Call with deprecated signature."""
msg_callback(msg.topic, msg.payload, msg.qos) msg_callback(msg.topic, msg.payload, msg.qos)
@ -315,7 +330,10 @@ def wrap_msg_callback(msg_callback: MessageCallbackType) -> MessageCallbackType:
async def async_subscribe( async def async_subscribe(
hass: HomeAssistant, hass: HomeAssistant,
topic: str, topic: str,
msg_callback: MessageCallbackType, msg_callback: AsyncMessageCallbackType
| MessageCallbackType
| DeprecatedMessageCallbackType
| AsyncDeprecatedMessageCallbackType,
qos: int = DEFAULT_QOS, qos: int = DEFAULT_QOS,
encoding: str | None = "utf-8", encoding: str | None = "utf-8",
): ):
@ -334,12 +352,15 @@ async def async_subscribe(
wrapped_msg_callback = msg_callback wrapped_msg_callback = msg_callback
# If we have 3 parameters with no default value, wrap the callback # If we have 3 parameters with no default value, wrap the callback
if non_default == 3: if non_default == 3:
module = inspect.getmodule(msg_callback)
_LOGGER.warning( _LOGGER.warning(
"Signature of MQTT msg_callback '%s.%s' is deprecated", "Signature of MQTT msg_callback '%s.%s' is deprecated",
inspect.getmodule(msg_callback).__name__, module.__name__ if module else "<unknown>",
msg_callback.__name__, msg_callback.__name__,
) )
wrapped_msg_callback = wrap_msg_callback(msg_callback) wrapped_msg_callback = wrap_msg_callback(
cast(DeprecatedMessageCallbackType, msg_callback)
)
async_remove = await hass.data[DATA_MQTT].async_subscribe( async_remove = await hass.data[DATA_MQTT].async_subscribe(
topic, topic,
@ -378,16 +399,12 @@ def subscribe(
async def _async_setup_discovery( async def _async_setup_discovery(
hass: HomeAssistant, conf: ConfigType, config_entry hass: HomeAssistant, conf: ConfigType, config_entry
) -> bool: ) -> None:
"""Try to start the discovery of MQTT devices. """Try to start the discovery of MQTT devices.
This method is a coroutine. This method is a coroutine.
""" """
success: bool = await discovery.async_start( await discovery.async_start(hass, conf[CONF_DISCOVERY_PREFIX], config_entry)
hass, conf[CONF_DISCOVERY_PREFIX], config_entry
)
return success
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
@ -539,7 +556,7 @@ class Subscription:
matcher: Any = attr.ib() matcher: Any = attr.ib()
job: HassJob = attr.ib() job: HassJob = attr.ib()
qos: int = attr.ib(default=0) qos: int = attr.ib(default=0)
encoding: str = attr.ib(default="utf-8") encoding: str | None = attr.ib(default="utf-8")
class MQTT: class MQTT:
@ -566,7 +583,7 @@ class MQTT:
self._mqttc: mqtt.Client = None self._mqttc: mqtt.Client = None
self._paho_lock = asyncio.Lock() self._paho_lock = asyncio.Lock()
self._pending_operations = {} self._pending_operations: dict[str, asyncio.Event] = {}
if self.hass.state == CoreState.running: if self.hass.state == CoreState.running:
self._ha_started.set() self._ha_started.set()
@ -688,12 +705,12 @@ class MQTT:
_raise_on_error(msg_info.rc) _raise_on_error(msg_info.rc)
await self._wait_for_mid(msg_info.mid) await self._wait_for_mid(msg_info.mid)
async def async_connect(self) -> str: async def async_connect(self) -> None:
"""Connect to the host. Does not process messages yet.""" """Connect to the host. Does not process messages yet."""
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
result: int = None result: int | None = None
try: try:
result = await self.hass.async_add_executor_job( result = await self.hass.async_add_executor_job(
self._mqttc.connect, self._mqttc.connect,
@ -770,7 +787,7 @@ class MQTT:
This method is a coroutine. This method is a coroutine.
""" """
async with self._paho_lock: async with self._paho_lock:
result: int = None result: int | None = None
result, mid = await self.hass.async_add_executor_job( result, mid = await self.hass.async_add_executor_job(
self._mqttc.unsubscribe, topic self._mqttc.unsubscribe, topic
) )
@ -781,7 +798,7 @@ class MQTT:
async def _async_perform_subscription(self, topic: str, qos: int) -> None: async def _async_perform_subscription(self, topic: str, qos: int) -> None:
"""Perform a paho-mqtt subscription.""" """Perform a paho-mqtt subscription."""
async with self._paho_lock: async with self._paho_lock:
result: int = None result: int | None = None
result, mid = await self.hass.async_add_executor_job( result, mid = await self.hass.async_add_executor_job(
self._mqttc.subscribe, topic, qos self._mqttc.subscribe, topic, qos
) )
@ -952,12 +969,12 @@ class MQTT:
) )
def _raise_on_error(result_code: int) -> None: def _raise_on_error(result_code: int | None) -> None:
"""Raise error if error result.""" """Raise error if error result."""
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
if result_code != 0: if result_code is not None and result_code != 0:
raise HomeAssistantError( raise HomeAssistantError(
f"Error talking to MQTT: {mqtt.error_string(result_code)}" f"Error talking to MQTT: {mqtt.error_string(result_code)}"
) )
@ -1014,13 +1031,13 @@ async def websocket_remove_device(hass, connection, msg):
) )
@websocket_api.async_response
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "mqtt/subscribe", vol.Required("type"): "mqtt/subscribe",
vol.Required("topic"): valid_subscribe_topic, vol.Required("topic"): valid_subscribe_topic,
} }
) )
@websocket_api.async_response
async def websocket_subscribe(hass, connection, msg): async def websocket_subscribe(hass, connection, msg):
"""Subscribe to a MQTT topic.""" """Subscribe to a MQTT topic."""
if not connection.user.is_admin: if not connection.user.is_admin:

View File

@ -226,6 +226,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity):
def available(self) -> bool: def available(self) -> bool:
"""Return true if the device is available and value has not expired.""" """Return true if the device is available and value has not expired."""
expire_after = self._config.get(CONF_EXPIRE_AFTER) expire_after = self._config.get(CONF_EXPIRE_AFTER)
return MqttAvailability.available.fget(self) and ( # mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined]
expire_after is None or not self._expired expire_after is None or not self._expired
) )

View File

@ -1,7 +1,7 @@
"""Helper to handle a set of topics to subscribe to.""" """Helper to handle a set of topics to subscribe to."""
from collections import deque from collections import deque
from functools import wraps from functools import wraps
from typing import Any from typing import Any, Callable
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -12,7 +12,9 @@ DATA_MQTT_DEBUG_INFO = "mqtt_debug_info"
STORED_MESSAGES = 10 STORED_MESSAGES = 10
def log_messages(hass: HomeAssistant, entity_id: str) -> MessageCallbackType: def log_messages(
hass: HomeAssistant, entity_id: str
) -> Callable[[MessageCallbackType], MessageCallbackType]:
"""Wrap an MQTT message callback to support message logging.""" """Wrap an MQTT message callback to support message logging."""
def _log_message(msg): def _log_message(msg):
@ -24,7 +26,7 @@ def log_messages(hass: HomeAssistant, entity_id: str) -> MessageCallbackType:
if msg not in messages: if msg not in messages:
messages.append(msg) messages.append(msg)
def _decorator(msg_callback: MessageCallbackType): def _decorator(msg_callback: MessageCallbackType) -> MessageCallbackType:
@wraps(msg_callback) @wraps(msg_callback)
def wrapper(msg: Any) -> None: def wrapper(msg: Any) -> None:
"""Log message.""" """Log message."""

View File

@ -119,15 +119,15 @@ class Trigger:
"""Device trigger settings.""" """Device trigger settings."""
device_id: str = attr.ib() device_id: str = attr.ib()
discovery_data: dict = attr.ib() discovery_data: dict | None = attr.ib()
hass: HomeAssistant = attr.ib() hass: HomeAssistant = attr.ib()
payload: str = attr.ib() payload: str | None = attr.ib()
qos: int = attr.ib() qos: int | None = attr.ib()
remove_signal: Callable[[], None] = attr.ib() remove_signal: Callable[[], None] | None = attr.ib()
subtype: str = attr.ib() subtype: str = attr.ib()
topic: str = attr.ib() topic: str | None = attr.ib()
type: str = attr.ib() type: str = attr.ib()
value_template: str = attr.ib() value_template: str | None = attr.ib()
trigger_instances: list[TriggerInstance] = attr.ib(factory=list) trigger_instances: list[TriggerInstance] = attr.ib(factory=list)
async def add_trigger(self, action, automation_info): async def add_trigger(self, action, automation_info):
@ -289,7 +289,7 @@ async def async_device_removed(hass: HomeAssistant, device_id: str):
async def async_get_triggers(hass: HomeAssistant, device_id: str) -> list[dict]: async def async_get_triggers(hass: HomeAssistant, device_id: str) -> list[dict]:
"""List device triggers for MQTT devices.""" """List device triggers for MQTT devices."""
triggers = [] triggers: list[dict] = []
if DEVICE_TRIGGERS not in hass.data: if DEVICE_TRIGGERS not in hass.data:
return triggers return triggers

View File

@ -83,7 +83,7 @@ class MQTTConfig(dict):
async def async_start( # noqa: C901 async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic, config_entry=None hass: HomeAssistant, discovery_topic, config_entry=None
) -> bool: ) -> None:
"""Start MQTT Discovery.""" """Start MQTT Discovery."""
mqtt_integrations = {} mqtt_integrations = {}
@ -298,10 +298,8 @@ async def async_start( # noqa: C901
0, 0,
) )
return True
async def async_stop(hass: HomeAssistant) -> None:
async def async_stop(hass: HomeAssistant) -> bool:
"""Stop MQTT Discovery.""" """Stop MQTT Discovery."""
if DISCOVERY_UNSUBSCRIBE in hass.data: if DISCOVERY_UNSUBSCRIBE in hass.data:
for unsub in hass.data[DISCOVERY_UNSUBSCRIBE]: for unsub in hass.data[DISCOVERY_UNSUBSCRIBE]:

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import json import json
import logging import logging
from typing import Callable
import voluptuous as vol import voluptuous as vol
@ -194,11 +195,11 @@ async def async_setup_entry_helper(hass, domain, async_setup, schema):
class MqttAttributes(Entity): class MqttAttributes(Entity):
"""Mixin used for platforms that support JSON attributes.""" """Mixin used for platforms that support JSON attributes."""
_attributes_extra_blocked = frozenset() _attributes_extra_blocked: frozenset[str] = frozenset()
def __init__(self, config: dict) -> None: def __init__(self, config: dict) -> None:
"""Initialize the JSON attributes mixin.""" """Initialize the JSON attributes mixin."""
self._attributes = None self._attributes: dict | None = None
self._attributes_sub_state = None self._attributes_sub_state = None
self._attributes_config = config self._attributes_config = config
@ -225,7 +226,7 @@ class MqttAttributes(Entity):
payload = msg.payload payload = msg.payload
if attr_tpl is not None: if attr_tpl is not None:
payload = attr_tpl.async_render_with_possible_json_value(payload) payload = attr_tpl.async_render_with_possible_json_value(payload)
json_dict = json.loads(payload) json_dict = json.loads(payload) if isinstance(payload, str) else None
if isinstance(json_dict, dict): if isinstance(json_dict, dict):
filtered_dict = { filtered_dict = {
k: v k: v
@ -272,7 +273,7 @@ class MqttAvailability(Entity):
def __init__(self, config: dict) -> None: def __init__(self, config: dict) -> None:
"""Initialize the availability mixin.""" """Initialize the availability mixin."""
self._availability_sub_state = None self._availability_sub_state = None
self._available = {} self._available: dict = {}
self._available_latest = False self._available_latest = False
self._availability_setup_from_config(config) self._availability_setup_from_config(config)
@ -397,7 +398,7 @@ class MqttDiscoveryUpdate(Entity):
"""Initialize the discovery update mixin.""" """Initialize the discovery update mixin."""
self._discovery_data = discovery_data self._discovery_data = discovery_data
self._discovery_update = discovery_update self._discovery_update = discovery_update
self._remove_signal = None self._remove_signal: Callable | None = None
self._removed_from_hass = False self._removed_from_hass = False
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import datetime as dt import datetime as dt
from typing import Callable, Union from typing import Awaitable, Callable, Union
import attr import attr
@ -21,4 +21,5 @@ class Message:
timestamp: dt.datetime | None = attr.ib(default=None) timestamp: dt.datetime | None = attr.ib(default=None)
AsyncMessageCallbackType = Callable[[Message], Awaitable[None]]
MessageCallbackType = Callable[[Message], None] MessageCallbackType = Callable[[Message], None]

View File

@ -242,6 +242,7 @@ class MqttSensor(MqttEntity, SensorEntity):
def available(self) -> bool: def available(self) -> bool:
"""Return true if the device is available and value has not expired.""" """Return true if the device is available and value has not expired."""
expire_after = self._config.get(CONF_EXPIRE_AFTER) expire_after = self._config.get(CONF_EXPIRE_AFTER)
return MqttAvailability.available.fget(self) and ( # mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined]
expire_after is None or not self._expired expire_after is None or not self._expired
) )

View File

@ -1345,9 +1345,6 @@ ignore_errors = true
[mypy-homeassistant.components.motion_blinds.*] [mypy-homeassistant.components.motion_blinds.*]
ignore_errors = true ignore_errors = true
[mypy-homeassistant.components.mqtt.*]
ignore_errors = true
[mypy-homeassistant.components.mullvad.*] [mypy-homeassistant.components.mullvad.*]
ignore_errors = true ignore_errors = true

View File

@ -125,7 +125,6 @@ IGNORED_MODULES: Final[list[str]] = [
"homeassistant.components.minecraft_server.*", "homeassistant.components.minecraft_server.*",
"homeassistant.components.mobile_app.*", "homeassistant.components.mobile_app.*",
"homeassistant.components.motion_blinds.*", "homeassistant.components.motion_blinds.*",
"homeassistant.components.mqtt.*",
"homeassistant.components.mullvad.*", "homeassistant.components.mullvad.*",
"homeassistant.components.neato.*", "homeassistant.components.neato.*",
"homeassistant.components.ness_alarm.*", "homeassistant.components.ness_alarm.*",