Minor improvements of MQTT typing (#52578)

* Minor improvements of MQTT typing

* Tweak
This commit is contained in:
Erik Montnemery 2021-07-06 14:38:48 +02:00 committed by GitHub
parent dc72c6c606
commit 4d32e1ed01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 56 additions and 32 deletions

View File

@ -12,7 +12,7 @@ from axis.streammanager import SIGNAL_PLAYING, STATE_STOPPED
from homeassistant.components import mqtt
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
from homeassistant.components.mqtt.models import Message
from homeassistant.components.mqtt.models import ReceiveMessage
from homeassistant.const import (
CONF_HOST,
CONF_NAME,
@ -195,7 +195,7 @@ class AxisNetworkDevice:
)
@callback
def mqtt_message(self, message: Message) -> None:
def mqtt_message(self, message: ReceiveMessage) -> None:
"""Receive Axis MQTT message."""
self.disconnect_from_stream()

View File

@ -75,9 +75,11 @@ from .const import (
from .discovery import LAST_DISCOVERY
from .models import (
AsyncMessageCallbackType,
Message,
MessageCallbackType,
PublishMessage,
PublishPayloadType,
ReceiveMessage,
ReceivePayloadType,
)
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
@ -290,9 +292,9 @@ def async_publish_template(
AsyncDeprecatedMessageCallbackType = Callable[
[str, PublishPayloadType, int], Awaitable[None]
[str, ReceivePayloadType, int], Awaitable[None]
]
DeprecatedMessageCallbackType = Callable[[str, PublishPayloadType, int], None]
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
def wrap_msg_callback(
@ -308,7 +310,7 @@ def wrap_msg_callback(
if asyncio.iscoroutinefunction(check_func):
@wraps(msg_callback)
async def async_wrapper(msg: Message) -> None:
async def async_wrapper(msg: ReceiveMessage) -> None:
"""Call with deprecated signature."""
await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
msg.topic, msg.payload, msg.qos
@ -318,7 +320,7 @@ def wrap_msg_callback(
else:
@wraps(msg_callback)
def wrapper(msg: Message) -> None:
def wrapper(msg: ReceiveMessage) -> None:
"""Call with deprecated signature."""
msg_callback(msg.topic, msg.payload, msg.qos)
@ -676,7 +678,7 @@ class MQTT:
CONF_WILL_MESSAGE in self.conf
and ATTR_TOPIC in self.conf[CONF_WILL_MESSAGE]
):
will_message = Message(**self.conf[CONF_WILL_MESSAGE])
will_message = PublishMessage(**self.conf[CONF_WILL_MESSAGE])
else:
will_message = None
@ -853,7 +855,7 @@ class MQTT:
retain=birth_message.retain,
)
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
birth_message = PublishMessage(**self.conf[CONF_BIRTH_MESSAGE])
asyncio.run_coroutine_threadsafe(
publish_birth_message(birth_message), self.hass.loop
)
@ -900,7 +902,7 @@ class MQTT:
self.hass.async_run_hass_job(
subscription.job,
Message(
ReceiveMessage(
msg.topic,
payload,
msg.qos,
@ -1043,7 +1045,7 @@ async def websocket_subscribe(hass, connection, msg):
if not connection.user.is_admin:
raise Unauthorized
async def forward_messages(mqttmsg: Message):
async def forward_messages(mqttmsg: ReceiveMessage):
"""Forward events to websocket."""
connection.send_message(
websocket_api.event_message(
@ -1064,8 +1066,13 @@ async def websocket_subscribe(hass, connection, msg):
connection.send_message(websocket_api.result_message(msg["id"]))
ConnectionStatusCallback = Callable[[bool], None]
@callback
def async_subscribe_connection_status(hass, connection_status_callback):
def async_subscribe_connection_status(
hass: HomeAssistant, connection_status_callback: ConnectionStatusCallback
) -> Callable[[], None]:
"""Subscribe to MQTT connection changes."""
connection_status_callback_job = HassJob(connection_status_callback)
@ -1092,6 +1099,6 @@ def async_subscribe_connection_status(hass, connection_status_callback):
return unsubscribe
def is_connected(hass):
def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected."""
return hass.data[DATA_MQTT].connected

View File

@ -38,7 +38,7 @@ from .discovery import (
clear_discovery_hash,
set_discovery_hash,
)
from .models import Message
from .models import ReceiveMessage
from .subscription import async_subscribe_topics, async_unsubscribe_topics
from .util import valid_subscribe_topic
@ -221,7 +221,7 @@ class MqttAttributes(Entity):
@callback
@log_messages(self.hass, self.entity_id)
def attributes_message_received(msg: Message) -> None:
def attributes_message_received(msg: ReceiveMessage) -> None:
try:
payload = msg.payload
if attr_tpl is not None:
@ -318,7 +318,7 @@ class MqttAvailability(Entity):
@callback
@log_messages(self.hass, self.entity_id)
def availability_message_received(msg: Message) -> None:
def availability_message_received(msg: ReceiveMessage) -> None:
"""Handle a new received MQTT availability message."""
topic = msg.topic
if msg.payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]:

View File

@ -7,19 +7,30 @@ from typing import Awaitable, Callable, Union
import attr
PublishPayloadType = Union[str, bytes, int, float, None]
ReceivePayloadType = Union[str, bytes]
@attr.s(slots=True, frozen=True)
class Message:
class PublishMessage:
"""MQTT Message."""
topic: str = attr.ib()
payload: PublishPayloadType = attr.ib()
qos: int = attr.ib()
retain: bool = attr.ib()
subscribed_topic: str | None = attr.ib(default=None)
timestamp: dt.datetime | None = attr.ib(default=None)
AsyncMessageCallbackType = Callable[[Message], Awaitable[None]]
MessageCallbackType = Callable[[Message], None]
@attr.s(slots=True, frozen=True)
class ReceiveMessage:
"""MQTT Message."""
topic: str = attr.ib()
payload: ReceivePayloadType = attr.ib()
qos: int = attr.ib()
retain: bool = attr.ib()
subscribed_topic: str = attr.ib(default=None)
timestamp: dt.datetime = attr.ib(default=None)
AsyncMessageCallbackType = Callable[[ReceiveMessage], Awaitable[None]]
MessageCallbackType = Callable[[ReceiveMessage], None]

View File

@ -66,7 +66,7 @@ async def async_subscribe_topics(
hass: HomeAssistant,
new_state: dict[str, EntitySubscription] | None,
topics: dict[str, Any],
):
) -> dict[str, EntitySubscription]:
"""(Re)Subscribe to a set of MQTT topics.
State is kept in sub_state and a dictionary mapping from the subscription
@ -106,6 +106,8 @@ async def async_subscribe_topics(
@bind_hass
async def async_unsubscribe_topics(hass: HomeAssistant, sub_state: dict):
async def async_unsubscribe_topics(
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
return await async_subscribe_topics(hass, sub_state, {})

View File

@ -15,8 +15,8 @@ import voluptuous as vol
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
from homeassistant.components.mqtt.models import (
Message as MQTTMessage,
PublishPayloadType,
ReceiveMessage as MQTTReceiveMessage,
ReceivePayloadType,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
@ -188,12 +188,12 @@ async def _get_gateway(
mqtt.async_publish(topic, payload, qos, retain)
def sub_callback(
topic: str, sub_cb: Callable[[str, PublishPayloadType, int], None], qos: int
topic: str, sub_cb: Callable[[str, ReceivePayloadType, int], None], qos: int
) -> None:
"""Call MQTT subscribe function."""
@callback
def internal_callback(msg: MQTTMessage) -> None:
def internal_callback(msg: MQTTReceiveMessage) -> None:
"""Call callback."""
sub_cb(msg.topic, msg.payload, msg.qos)

View File

@ -34,7 +34,7 @@ from homeassistant.components.device_automation import ( # noqa: F401
_async_get_device_automation_capabilities as async_get_device_automation_capabilities,
_async_get_device_automations as async_get_device_automations,
)
from homeassistant.components.mqtt.models import Message
from homeassistant.components.mqtt.models import ReceiveMessage
from homeassistant.config import async_process_component_config
from homeassistant.const import (
DEVICE_DEFAULT_NAME,
@ -353,7 +353,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
"""Fire the MQTT message."""
if isinstance(payload, str):
payload = payload.encode("utf-8")
msg = Message(topic, payload, qos, retain)
msg = ReceiveMessage(topic, payload, qos, retain)
hass.data["mqtt"]._mqtt_handle_message(msg)

View File

@ -1,6 +1,6 @@
"""Test config flow."""
from homeassistant import config_entries
from homeassistant.components.mqtt.models import Message
from homeassistant.components.mqtt.models import ReceiveMessage
from tests.common import MockConfigEntry
@ -19,7 +19,9 @@ async def test_mqtt_abort_if_existing_entry(hass, mqtt_mock):
async def test_mqtt_abort_invalid_topic(hass, mqtt_mock):
"""Check MQTT flow aborts if discovery topic is invalid."""
discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/##")
discovery_info = ReceiveMessage(
"", "", 0, False, subscribed_topic="custom_prefix/##"
)
result = await hass.config_entries.flow.async_init(
"tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info
)
@ -29,7 +31,9 @@ async def test_mqtt_abort_invalid_topic(hass, mqtt_mock):
async def test_mqtt_setup(hass, mqtt_mock) -> None:
"""Test we can finish a config flow through MQTT with custom prefix."""
discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/123/#")
discovery_info = ReceiveMessage(
"", "", 0, False, subscribed_topic="custom_prefix/123/#"
)
result = await hass.config_entries.flow.async_init(
"tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info
)