diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 81142fadb87..192de624f17 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable, Callable, Iterable +from collections.abc import Awaitable, Callable, Coroutine, Iterable from functools import lru_cache, partial, wraps import inspect from itertools import groupby @@ -15,6 +15,7 @@ import uuid import attr import certifi +from paho.mqtt.client import MQTTMessage from homeassistant.const import ( CONF_CLIENT_ID, @@ -246,7 +247,7 @@ class Subscription: topic: str = attr.ib() matcher: Any = attr.ib() - job: HassJob = attr.ib() + job: HassJob[[ReceiveMessage], Coroutine[Any, Any, None] | None] = attr.ib() qos: int = attr.ib(default=0) encoding: str | None = attr.ib(default="utf-8") @@ -444,7 +445,7 @@ class MQTT: async def async_subscribe( self, topic: str, - msg_callback: MessageCallbackType, + msg_callback: AsyncMessageCallbackType | MessageCallbackType, qos: int, encoding: str | None = None, ) -> Callable[[], None]: @@ -597,15 +598,15 @@ class MQTT: self.hass.add_job(self._mqtt_handle_message, msg) @lru_cache(2048) - def _matching_subscriptions(self, topic): - subscriptions = [] + def _matching_subscriptions(self, topic: str) -> list[Subscription]: + subscriptions: list[Subscription] = [] for subscription in self.subscriptions: if subscription.matcher(topic): subscriptions.append(subscription) return subscriptions @callback - def _mqtt_handle_message(self, msg) -> None: + def _mqtt_handle_message(self, msg: MQTTMessage) -> None: _LOGGER.debug( "Received message on %s%s: %s", msg.topic, diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index d5560f6954e..84bf704a262 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -2,7 +2,7 @@ from __future__ import annotations from ast import literal_eval -from collections.abc import Awaitable, Callable +from collections.abc import Callable, Coroutine import datetime as dt from typing import Any, Union @@ -42,7 +42,7 @@ class ReceiveMessage: timestamp: dt.datetime = attr.ib(default=None) -AsyncMessageCallbackType = Callable[[ReceiveMessage], Awaitable[None]] +AsyncMessageCallbackType = Callable[[ReceiveMessage], Coroutine[Any, Any, None]] MessageCallbackType = Callable[[ReceiveMessage], None]