diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 3a6fd068975..9883e7b6ec8 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -9,7 +9,7 @@ import logging from operator import attrgetter import ssl import time -from typing import Any, Callable, Union +from typing import Any, Awaitable, Callable, Union, cast import uuid import attr @@ -73,7 +73,12 @@ from .const import ( PROTOCOL_311, ) from .discovery import LAST_DISCOVERY -from .models import Message, MessageCallbackType, PublishPayloadType +from .models import ( + AsyncMessageCallbackType, + Message, + MessageCallbackType, + PublishPayloadType, +) from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -284,26 +289,36 @@ def async_publish_template( hass.async_create_task(hass.services.async_call(DOMAIN, SERVICE_PUBLISH, data)) -def wrap_msg_callback(msg_callback: MessageCallbackType) -> MessageCallbackType: +AsyncDeprecatedMessageCallbackType = Callable[ + [str, PublishPayloadType, int], Awaitable[None] +] +DeprecatedMessageCallbackType = Callable[[str, PublishPayloadType, int], None] + + +def wrap_msg_callback( + msg_callback: AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType, +) -> AsyncMessageCallbackType | MessageCallbackType: """Wrap an MQTT message callback to support deprecated signature.""" # Check for partials to properly determine if coroutine function check_func = msg_callback while isinstance(check_func, partial): check_func = check_func.func - wrapper_func = None + wrapper_func: AsyncMessageCallbackType | MessageCallbackType if asyncio.iscoroutinefunction(check_func): @wraps(msg_callback) - async def async_wrapper(msg: Any) -> None: + async def async_wrapper(msg: Message) -> None: """Call with deprecated signature.""" - await msg_callback(msg.topic, msg.payload, msg.qos) + await cast(AsyncDeprecatedMessageCallbackType, msg_callback)( + msg.topic, msg.payload, msg.qos + ) wrapper_func = async_wrapper else: @wraps(msg_callback) - def wrapper(msg: Any) -> None: + def wrapper(msg: Message) -> None: """Call with deprecated signature.""" msg_callback(msg.topic, msg.payload, msg.qos) @@ -315,7 +330,10 @@ def wrap_msg_callback(msg_callback: MessageCallbackType) -> MessageCallbackType: async def async_subscribe( hass: HomeAssistant, topic: str, - msg_callback: MessageCallbackType, + msg_callback: AsyncMessageCallbackType + | MessageCallbackType + | DeprecatedMessageCallbackType + | AsyncDeprecatedMessageCallbackType, qos: int = DEFAULT_QOS, encoding: str | None = "utf-8", ): @@ -334,12 +352,15 @@ async def async_subscribe( wrapped_msg_callback = msg_callback # If we have 3 parameters with no default value, wrap the callback if non_default == 3: + module = inspect.getmodule(msg_callback) _LOGGER.warning( "Signature of MQTT msg_callback '%s.%s' is deprecated", - inspect.getmodule(msg_callback).__name__, + module.__name__ if module else "", msg_callback.__name__, ) - wrapped_msg_callback = wrap_msg_callback(msg_callback) + wrapped_msg_callback = wrap_msg_callback( + cast(DeprecatedMessageCallbackType, msg_callback) + ) async_remove = await hass.data[DATA_MQTT].async_subscribe( topic, @@ -378,16 +399,12 @@ def subscribe( async def _async_setup_discovery( hass: HomeAssistant, conf: ConfigType, config_entry -) -> bool: +) -> None: """Try to start the discovery of MQTT devices. This method is a coroutine. """ - success: bool = await discovery.async_start( - hass, conf[CONF_DISCOVERY_PREFIX], config_entry - ) - - return success + await discovery.async_start(hass, conf[CONF_DISCOVERY_PREFIX], config_entry) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: @@ -539,7 +556,7 @@ class Subscription: matcher: Any = attr.ib() job: HassJob = attr.ib() qos: int = attr.ib(default=0) - encoding: str = attr.ib(default="utf-8") + encoding: str | None = attr.ib(default="utf-8") class MQTT: @@ -566,7 +583,7 @@ class MQTT: self._mqttc: mqtt.Client = None self._paho_lock = asyncio.Lock() - self._pending_operations = {} + self._pending_operations: dict[str, asyncio.Event] = {} if self.hass.state == CoreState.running: self._ha_started.set() @@ -688,12 +705,12 @@ class MQTT: _raise_on_error(msg_info.rc) await self._wait_for_mid(msg_info.mid) - async def async_connect(self) -> str: + async def async_connect(self) -> None: """Connect to the host. Does not process messages yet.""" # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt - result: int = None + result: int | None = None try: result = await self.hass.async_add_executor_job( self._mqttc.connect, @@ -770,7 +787,7 @@ class MQTT: This method is a coroutine. """ async with self._paho_lock: - result: int = None + result: int | None = None result, mid = await self.hass.async_add_executor_job( self._mqttc.unsubscribe, topic ) @@ -781,7 +798,7 @@ class MQTT: async def _async_perform_subscription(self, topic: str, qos: int) -> None: """Perform a paho-mqtt subscription.""" async with self._paho_lock: - result: int = None + result: int | None = None result, mid = await self.hass.async_add_executor_job( self._mqttc.subscribe, topic, qos ) @@ -952,12 +969,12 @@ class MQTT: ) -def _raise_on_error(result_code: int) -> None: +def _raise_on_error(result_code: int | None) -> None: """Raise error if error result.""" # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt - if result_code != 0: + if result_code is not None and result_code != 0: raise HomeAssistantError( f"Error talking to MQTT: {mqtt.error_string(result_code)}" ) @@ -1014,13 +1031,13 @@ async def websocket_remove_device(hass, connection, msg): ) -@websocket_api.async_response @websocket_api.websocket_command( { vol.Required("type"): "mqtt/subscribe", vol.Required("topic"): valid_subscribe_topic, } ) +@websocket_api.async_response async def websocket_subscribe(hass, connection, msg): """Subscribe to a MQTT topic.""" if not connection.user.is_admin: diff --git a/homeassistant/components/mqtt/binary_sensor.py b/homeassistant/components/mqtt/binary_sensor.py index e24abc27028..66dea3e3aa0 100644 --- a/homeassistant/components/mqtt/binary_sensor.py +++ b/homeassistant/components/mqtt/binary_sensor.py @@ -226,6 +226,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity): def available(self) -> bool: """Return true if the device is available and value has not expired.""" expire_after = self._config.get(CONF_EXPIRE_AFTER) - return MqttAvailability.available.fget(self) and ( + # mypy doesn't know about fget: https://github.com/python/mypy/issues/6185 + return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined] expire_after is None or not self._expired ) diff --git a/homeassistant/components/mqtt/debug_info.py b/homeassistant/components/mqtt/debug_info.py index d00d65c2451..57cb88e65e3 100644 --- a/homeassistant/components/mqtt/debug_info.py +++ b/homeassistant/components/mqtt/debug_info.py @@ -1,7 +1,7 @@ """Helper to handle a set of topics to subscribe to.""" from collections import deque from functools import wraps -from typing import Any +from typing import Any, Callable from homeassistant.core import HomeAssistant @@ -12,7 +12,9 @@ DATA_MQTT_DEBUG_INFO = "mqtt_debug_info" STORED_MESSAGES = 10 -def log_messages(hass: HomeAssistant, entity_id: str) -> MessageCallbackType: +def log_messages( + hass: HomeAssistant, entity_id: str +) -> Callable[[MessageCallbackType], MessageCallbackType]: """Wrap an MQTT message callback to support message logging.""" def _log_message(msg): @@ -24,7 +26,7 @@ def log_messages(hass: HomeAssistant, entity_id: str) -> MessageCallbackType: if msg not in messages: messages.append(msg) - def _decorator(msg_callback: MessageCallbackType): + def _decorator(msg_callback: MessageCallbackType) -> MessageCallbackType: @wraps(msg_callback) def wrapper(msg: Any) -> None: """Log message.""" diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py index d9413b80c06..89246406de3 100644 --- a/homeassistant/components/mqtt/device_trigger.py +++ b/homeassistant/components/mqtt/device_trigger.py @@ -119,15 +119,15 @@ class Trigger: """Device trigger settings.""" device_id: str = attr.ib() - discovery_data: dict = attr.ib() + discovery_data: dict | None = attr.ib() hass: HomeAssistant = attr.ib() - payload: str = attr.ib() - qos: int = attr.ib() - remove_signal: Callable[[], None] = attr.ib() + payload: str | None = attr.ib() + qos: int | None = attr.ib() + remove_signal: Callable[[], None] | None = attr.ib() subtype: str = attr.ib() - topic: str = attr.ib() + topic: str | None = attr.ib() type: str = attr.ib() - value_template: str = attr.ib() + value_template: str | None = attr.ib() trigger_instances: list[TriggerInstance] = attr.ib(factory=list) async def add_trigger(self, action, automation_info): @@ -289,7 +289,7 @@ async def async_device_removed(hass: HomeAssistant, device_id: str): async def async_get_triggers(hass: HomeAssistant, device_id: str) -> list[dict]: """List device triggers for MQTT devices.""" - triggers = [] + triggers: list[dict] = [] if DEVICE_TRIGGERS not in hass.data: return triggers diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index d35065e30a8..e0d1d0eb4dd 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -83,7 +83,7 @@ class MQTTConfig(dict): async def async_start( # noqa: C901 hass: HomeAssistant, discovery_topic, config_entry=None -) -> bool: +) -> None: """Start MQTT Discovery.""" mqtt_integrations = {} @@ -298,10 +298,8 @@ async def async_start( # noqa: C901 0, ) - return True - -async def async_stop(hass: HomeAssistant) -> bool: +async def async_stop(hass: HomeAssistant) -> None: """Stop MQTT Discovery.""" if DISCOVERY_UNSUBSCRIBE in hass.data: for unsub in hass.data[DISCOVERY_UNSUBSCRIBE]: diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index b0c8b573b37..9e45d6d4f27 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -4,6 +4,7 @@ from __future__ import annotations from abc import abstractmethod import json import logging +from typing import Callable import voluptuous as vol @@ -194,11 +195,11 @@ async def async_setup_entry_helper(hass, domain, async_setup, schema): class MqttAttributes(Entity): """Mixin used for platforms that support JSON attributes.""" - _attributes_extra_blocked = frozenset() + _attributes_extra_blocked: frozenset[str] = frozenset() def __init__(self, config: dict) -> None: """Initialize the JSON attributes mixin.""" - self._attributes = None + self._attributes: dict | None = None self._attributes_sub_state = None self._attributes_config = config @@ -225,7 +226,7 @@ class MqttAttributes(Entity): payload = msg.payload if attr_tpl is not None: payload = attr_tpl.async_render_with_possible_json_value(payload) - json_dict = json.loads(payload) + json_dict = json.loads(payload) if isinstance(payload, str) else None if isinstance(json_dict, dict): filtered_dict = { k: v @@ -272,7 +273,7 @@ class MqttAvailability(Entity): def __init__(self, config: dict) -> None: """Initialize the availability mixin.""" self._availability_sub_state = None - self._available = {} + self._available: dict = {} self._available_latest = False self._availability_setup_from_config(config) @@ -397,7 +398,7 @@ class MqttDiscoveryUpdate(Entity): """Initialize the discovery update mixin.""" self._discovery_data = discovery_data self._discovery_update = discovery_update - self._remove_signal = None + self._remove_signal: Callable | None = None self._removed_from_hass = False async def async_added_to_hass(self) -> None: diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index 7cdafeef98d..0c8c311d768 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -2,7 +2,7 @@ from __future__ import annotations import datetime as dt -from typing import Callable, Union +from typing import Awaitable, Callable, Union import attr @@ -21,4 +21,5 @@ class Message: timestamp: dt.datetime | None = attr.ib(default=None) +AsyncMessageCallbackType = Callable[[Message], Awaitable[None]] MessageCallbackType = Callable[[Message], None] diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index 777a15b639a..0c234fbbbea 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -242,6 +242,7 @@ class MqttSensor(MqttEntity, SensorEntity): def available(self) -> bool: """Return true if the device is available and value has not expired.""" expire_after = self._config.get(CONF_EXPIRE_AFTER) - return MqttAvailability.available.fget(self) and ( + # mypy doesn't know about fget: https://github.com/python/mypy/issues/6185 + return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined] expire_after is None or not self._expired ) diff --git a/mypy.ini b/mypy.ini index eca6f699022..5ed805a1b59 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1345,9 +1345,6 @@ ignore_errors = true [mypy-homeassistant.components.motion_blinds.*] ignore_errors = true -[mypy-homeassistant.components.mqtt.*] -ignore_errors = true - [mypy-homeassistant.components.mullvad.*] ignore_errors = true diff --git a/script/hassfest/mypy_config.py b/script/hassfest/mypy_config.py index b09fbbe98a9..8b9b15d35aa 100644 --- a/script/hassfest/mypy_config.py +++ b/script/hassfest/mypy_config.py @@ -125,7 +125,6 @@ IGNORED_MODULES: Final[list[str]] = [ "homeassistant.components.minecraft_server.*", "homeassistant.components.mobile_app.*", "homeassistant.components.motion_blinds.*", - "homeassistant.components.mqtt.*", "homeassistant.components.mullvad.*", "homeassistant.components.neato.*", "homeassistant.components.ness_alarm.*",