From 4d32e1ed01d6bd6cbf52e193acb33fb8e14d8462 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 6 Jul 2021 14:38:48 +0200 Subject: [PATCH] Minor improvements of MQTT typing (#52578) * Minor improvements of MQTT typing * Tweak --- homeassistant/components/axis/device.py | 4 +-- homeassistant/components/mqtt/__init__.py | 29 ++++++++++++------- homeassistant/components/mqtt/mixins.py | 6 ++-- homeassistant/components/mqtt/models.py | 21 ++++++++++---- homeassistant/components/mqtt/subscription.py | 6 ++-- homeassistant/components/mysensors/gateway.py | 8 ++--- tests/common.py | 4 +-- tests/components/tasmota/test_config_flow.py | 10 +++++-- 8 files changed, 56 insertions(+), 32 deletions(-) diff --git a/homeassistant/components/axis/device.py b/homeassistant/components/axis/device.py index f1a57eec33c..e4987c77139 100644 --- a/homeassistant/components/axis/device.py +++ b/homeassistant/components/axis/device.py @@ -12,7 +12,7 @@ from axis.streammanager import SIGNAL_PLAYING, STATE_STOPPED from homeassistant.components import mqtt from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN -from homeassistant.components.mqtt.models import Message +from homeassistant.components.mqtt.models import ReceiveMessage from homeassistant.const import ( CONF_HOST, CONF_NAME, @@ -195,7 +195,7 @@ class AxisNetworkDevice: ) @callback - def mqtt_message(self, message: Message) -> None: + def mqtt_message(self, message: ReceiveMessage) -> None: """Receive Axis MQTT message.""" self.disconnect_from_stream() diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 9883e7b6ec8..e524502dd8d 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -75,9 +75,11 @@ from .const import ( from .discovery import LAST_DISCOVERY from .models import ( AsyncMessageCallbackType, - Message, MessageCallbackType, + PublishMessage, PublishPayloadType, + ReceiveMessage, + ReceivePayloadType, ) from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic @@ -290,9 +292,9 @@ def async_publish_template( AsyncDeprecatedMessageCallbackType = Callable[ - [str, PublishPayloadType, int], Awaitable[None] + [str, ReceivePayloadType, int], Awaitable[None] ] -DeprecatedMessageCallbackType = Callable[[str, PublishPayloadType, int], None] +DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None] def wrap_msg_callback( @@ -308,7 +310,7 @@ def wrap_msg_callback( if asyncio.iscoroutinefunction(check_func): @wraps(msg_callback) - async def async_wrapper(msg: Message) -> None: + async def async_wrapper(msg: ReceiveMessage) -> None: """Call with deprecated signature.""" await cast(AsyncDeprecatedMessageCallbackType, msg_callback)( msg.topic, msg.payload, msg.qos @@ -318,7 +320,7 @@ def wrap_msg_callback( else: @wraps(msg_callback) - def wrapper(msg: Message) -> None: + def wrapper(msg: ReceiveMessage) -> None: """Call with deprecated signature.""" msg_callback(msg.topic, msg.payload, msg.qos) @@ -676,7 +678,7 @@ class MQTT: CONF_WILL_MESSAGE in self.conf and ATTR_TOPIC in self.conf[CONF_WILL_MESSAGE] ): - will_message = Message(**self.conf[CONF_WILL_MESSAGE]) + will_message = PublishMessage(**self.conf[CONF_WILL_MESSAGE]) else: will_message = None @@ -853,7 +855,7 @@ class MQTT: retain=birth_message.retain, ) - birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE]) + birth_message = PublishMessage(**self.conf[CONF_BIRTH_MESSAGE]) asyncio.run_coroutine_threadsafe( publish_birth_message(birth_message), self.hass.loop ) @@ -900,7 +902,7 @@ class MQTT: self.hass.async_run_hass_job( subscription.job, - Message( + ReceiveMessage( msg.topic, payload, msg.qos, @@ -1043,7 +1045,7 @@ async def websocket_subscribe(hass, connection, msg): if not connection.user.is_admin: raise Unauthorized - async def forward_messages(mqttmsg: Message): + async def forward_messages(mqttmsg: ReceiveMessage): """Forward events to websocket.""" connection.send_message( websocket_api.event_message( @@ -1064,8 +1066,13 @@ async def websocket_subscribe(hass, connection, msg): connection.send_message(websocket_api.result_message(msg["id"])) +ConnectionStatusCallback = Callable[[bool], None] + + @callback -def async_subscribe_connection_status(hass, connection_status_callback): +def async_subscribe_connection_status( + hass: HomeAssistant, connection_status_callback: ConnectionStatusCallback +) -> Callable[[], None]: """Subscribe to MQTT connection changes.""" connection_status_callback_job = HassJob(connection_status_callback) @@ -1092,6 +1099,6 @@ def async_subscribe_connection_status(hass, connection_status_callback): return unsubscribe -def is_connected(hass): +def is_connected(hass: HomeAssistant) -> bool: """Return if MQTT client is connected.""" return hass.data[DATA_MQTT].connected diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 9e45d6d4f27..a40f06a3bb6 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -38,7 +38,7 @@ from .discovery import ( clear_discovery_hash, set_discovery_hash, ) -from .models import Message +from .models import ReceiveMessage from .subscription import async_subscribe_topics, async_unsubscribe_topics from .util import valid_subscribe_topic @@ -221,7 +221,7 @@ class MqttAttributes(Entity): @callback @log_messages(self.hass, self.entity_id) - def attributes_message_received(msg: Message) -> None: + def attributes_message_received(msg: ReceiveMessage) -> None: try: payload = msg.payload if attr_tpl is not None: @@ -318,7 +318,7 @@ class MqttAvailability(Entity): @callback @log_messages(self.hass, self.entity_id) - def availability_message_received(msg: Message) -> None: + def availability_message_received(msg: ReceiveMessage) -> None: """Handle a new received MQTT availability message.""" topic = msg.topic if msg.payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]: diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index 0c8c311d768..5c320ac0827 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -7,19 +7,30 @@ from typing import Awaitable, Callable, Union import attr PublishPayloadType = Union[str, bytes, int, float, None] +ReceivePayloadType = Union[str, bytes] @attr.s(slots=True, frozen=True) -class Message: +class PublishMessage: """MQTT Message.""" topic: str = attr.ib() payload: PublishPayloadType = attr.ib() qos: int = attr.ib() retain: bool = attr.ib() - subscribed_topic: str | None = attr.ib(default=None) - timestamp: dt.datetime | None = attr.ib(default=None) -AsyncMessageCallbackType = Callable[[Message], Awaitable[None]] -MessageCallbackType = Callable[[Message], None] +@attr.s(slots=True, frozen=True) +class ReceiveMessage: + """MQTT Message.""" + + topic: str = attr.ib() + payload: ReceivePayloadType = attr.ib() + qos: int = attr.ib() + retain: bool = attr.ib() + subscribed_topic: str = attr.ib(default=None) + timestamp: dt.datetime = attr.ib(default=None) + + +AsyncMessageCallbackType = Callable[[ReceiveMessage], Awaitable[None]] +MessageCallbackType = Callable[[ReceiveMessage], None] diff --git a/homeassistant/components/mqtt/subscription.py b/homeassistant/components/mqtt/subscription.py index 6c711600b2c..03259a37380 100644 --- a/homeassistant/components/mqtt/subscription.py +++ b/homeassistant/components/mqtt/subscription.py @@ -66,7 +66,7 @@ async def async_subscribe_topics( hass: HomeAssistant, new_state: dict[str, EntitySubscription] | None, topics: dict[str, Any], -): +) -> dict[str, EntitySubscription]: """(Re)Subscribe to a set of MQTT topics. State is kept in sub_state and a dictionary mapping from the subscription @@ -106,6 +106,8 @@ async def async_subscribe_topics( @bind_hass -async def async_unsubscribe_topics(hass: HomeAssistant, sub_state: dict): +async def async_unsubscribe_topics( + hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None +) -> dict[str, EntitySubscription]: """Unsubscribe from all MQTT topics managed by async_subscribe_topics.""" return await async_subscribe_topics(hass, sub_state, {}) diff --git a/homeassistant/components/mysensors/gateway.py b/homeassistant/components/mysensors/gateway.py index f1e2cd0a4e1..f9410f66e8f 100644 --- a/homeassistant/components/mysensors/gateway.py +++ b/homeassistant/components/mysensors/gateway.py @@ -15,8 +15,8 @@ import voluptuous as vol from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN from homeassistant.components.mqtt.models import ( - Message as MQTTMessage, - PublishPayloadType, + ReceiveMessage as MQTTReceiveMessage, + ReceivePayloadType, ) from homeassistant.config_entries import ConfigEntry from homeassistant.const import EVENT_HOMEASSISTANT_STOP @@ -188,12 +188,12 @@ async def _get_gateway( mqtt.async_publish(topic, payload, qos, retain) def sub_callback( - topic: str, sub_cb: Callable[[str, PublishPayloadType, int], None], qos: int + topic: str, sub_cb: Callable[[str, ReceivePayloadType, int], None], qos: int ) -> None: """Call MQTT subscribe function.""" @callback - def internal_callback(msg: MQTTMessage) -> None: + def internal_callback(msg: MQTTReceiveMessage) -> None: """Call callback.""" sub_cb(msg.topic, msg.payload, msg.qos) diff --git a/tests/common.py b/tests/common.py index 03b53294db0..5de58a08472 100644 --- a/tests/common.py +++ b/tests/common.py @@ -34,7 +34,7 @@ from homeassistant.components.device_automation import ( # noqa: F401 _async_get_device_automation_capabilities as async_get_device_automation_capabilities, _async_get_device_automations as async_get_device_automations, ) -from homeassistant.components.mqtt.models import Message +from homeassistant.components.mqtt.models import ReceiveMessage from homeassistant.config import async_process_component_config from homeassistant.const import ( DEVICE_DEFAULT_NAME, @@ -353,7 +353,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False): """Fire the MQTT message.""" if isinstance(payload, str): payload = payload.encode("utf-8") - msg = Message(topic, payload, qos, retain) + msg = ReceiveMessage(topic, payload, qos, retain) hass.data["mqtt"]._mqtt_handle_message(msg) diff --git a/tests/components/tasmota/test_config_flow.py b/tests/components/tasmota/test_config_flow.py index 9f199f0aa66..767d6b9cfcf 100644 --- a/tests/components/tasmota/test_config_flow.py +++ b/tests/components/tasmota/test_config_flow.py @@ -1,6 +1,6 @@ """Test config flow.""" from homeassistant import config_entries -from homeassistant.components.mqtt.models import Message +from homeassistant.components.mqtt.models import ReceiveMessage from tests.common import MockConfigEntry @@ -19,7 +19,9 @@ async def test_mqtt_abort_if_existing_entry(hass, mqtt_mock): async def test_mqtt_abort_invalid_topic(hass, mqtt_mock): """Check MQTT flow aborts if discovery topic is invalid.""" - discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/##") + discovery_info = ReceiveMessage( + "", "", 0, False, subscribed_topic="custom_prefix/##" + ) result = await hass.config_entries.flow.async_init( "tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info ) @@ -29,7 +31,9 @@ async def test_mqtt_abort_invalid_topic(hass, mqtt_mock): async def test_mqtt_setup(hass, mqtt_mock) -> None: """Test we can finish a config flow through MQTT with custom prefix.""" - discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/123/#") + discovery_info = ReceiveMessage( + "", "", 0, False, subscribed_topic="custom_prefix/123/#" + ) result = await hass.config_entries.flow.async_init( "tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info )