Minor improvements of MQTT typing (#52578)

* Minor improvements of MQTT typing

* Tweak
This commit is contained in:
Erik Montnemery 2021-07-06 14:38:48 +02:00 committed by GitHub
parent dc72c6c606
commit 4d32e1ed01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 56 additions and 32 deletions

View File

@ -12,7 +12,7 @@ from axis.streammanager import SIGNAL_PLAYING, STATE_STOPPED
from homeassistant.components import mqtt from homeassistant.components import mqtt
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
from homeassistant.components.mqtt.models import Message from homeassistant.components.mqtt.models import ReceiveMessage
from homeassistant.const import ( from homeassistant.const import (
CONF_HOST, CONF_HOST,
CONF_NAME, CONF_NAME,
@ -195,7 +195,7 @@ class AxisNetworkDevice:
) )
@callback @callback
def mqtt_message(self, message: Message) -> None: def mqtt_message(self, message: ReceiveMessage) -> None:
"""Receive Axis MQTT message.""" """Receive Axis MQTT message."""
self.disconnect_from_stream() self.disconnect_from_stream()

View File

@ -75,9 +75,11 @@ from .const import (
from .discovery import LAST_DISCOVERY from .discovery import LAST_DISCOVERY
from .models import ( from .models import (
AsyncMessageCallbackType, AsyncMessageCallbackType,
Message,
MessageCallbackType, MessageCallbackType,
PublishMessage,
PublishPayloadType, PublishPayloadType,
ReceiveMessage,
ReceivePayloadType,
) )
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
@ -290,9 +292,9 @@ def async_publish_template(
AsyncDeprecatedMessageCallbackType = Callable[ AsyncDeprecatedMessageCallbackType = Callable[
[str, PublishPayloadType, int], Awaitable[None] [str, ReceivePayloadType, int], Awaitable[None]
] ]
DeprecatedMessageCallbackType = Callable[[str, PublishPayloadType, int], None] DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
def wrap_msg_callback( def wrap_msg_callback(
@ -308,7 +310,7 @@ def wrap_msg_callback(
if asyncio.iscoroutinefunction(check_func): if asyncio.iscoroutinefunction(check_func):
@wraps(msg_callback) @wraps(msg_callback)
async def async_wrapper(msg: Message) -> None: async def async_wrapper(msg: ReceiveMessage) -> None:
"""Call with deprecated signature.""" """Call with deprecated signature."""
await cast(AsyncDeprecatedMessageCallbackType, msg_callback)( await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
msg.topic, msg.payload, msg.qos msg.topic, msg.payload, msg.qos
@ -318,7 +320,7 @@ def wrap_msg_callback(
else: else:
@wraps(msg_callback) @wraps(msg_callback)
def wrapper(msg: Message) -> None: def wrapper(msg: ReceiveMessage) -> None:
"""Call with deprecated signature.""" """Call with deprecated signature."""
msg_callback(msg.topic, msg.payload, msg.qos) msg_callback(msg.topic, msg.payload, msg.qos)
@ -676,7 +678,7 @@ class MQTT:
CONF_WILL_MESSAGE in self.conf CONF_WILL_MESSAGE in self.conf
and ATTR_TOPIC in self.conf[CONF_WILL_MESSAGE] and ATTR_TOPIC in self.conf[CONF_WILL_MESSAGE]
): ):
will_message = Message(**self.conf[CONF_WILL_MESSAGE]) will_message = PublishMessage(**self.conf[CONF_WILL_MESSAGE])
else: else:
will_message = None will_message = None
@ -853,7 +855,7 @@ class MQTT:
retain=birth_message.retain, retain=birth_message.retain,
) )
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE]) birth_message = PublishMessage(**self.conf[CONF_BIRTH_MESSAGE])
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
publish_birth_message(birth_message), self.hass.loop publish_birth_message(birth_message), self.hass.loop
) )
@ -900,7 +902,7 @@ class MQTT:
self.hass.async_run_hass_job( self.hass.async_run_hass_job(
subscription.job, subscription.job,
Message( ReceiveMessage(
msg.topic, msg.topic,
payload, payload,
msg.qos, msg.qos,
@ -1043,7 +1045,7 @@ async def websocket_subscribe(hass, connection, msg):
if not connection.user.is_admin: if not connection.user.is_admin:
raise Unauthorized raise Unauthorized
async def forward_messages(mqttmsg: Message): async def forward_messages(mqttmsg: ReceiveMessage):
"""Forward events to websocket.""" """Forward events to websocket."""
connection.send_message( connection.send_message(
websocket_api.event_message( websocket_api.event_message(
@ -1064,8 +1066,13 @@ async def websocket_subscribe(hass, connection, msg):
connection.send_message(websocket_api.result_message(msg["id"])) connection.send_message(websocket_api.result_message(msg["id"]))
ConnectionStatusCallback = Callable[[bool], None]
@callback @callback
def async_subscribe_connection_status(hass, connection_status_callback): def async_subscribe_connection_status(
hass: HomeAssistant, connection_status_callback: ConnectionStatusCallback
) -> Callable[[], None]:
"""Subscribe to MQTT connection changes.""" """Subscribe to MQTT connection changes."""
connection_status_callback_job = HassJob(connection_status_callback) connection_status_callback_job = HassJob(connection_status_callback)
@ -1092,6 +1099,6 @@ def async_subscribe_connection_status(hass, connection_status_callback):
return unsubscribe return unsubscribe
def is_connected(hass): def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected.""" """Return if MQTT client is connected."""
return hass.data[DATA_MQTT].connected return hass.data[DATA_MQTT].connected

View File

@ -38,7 +38,7 @@ from .discovery import (
clear_discovery_hash, clear_discovery_hash,
set_discovery_hash, set_discovery_hash,
) )
from .models import Message from .models import ReceiveMessage
from .subscription import async_subscribe_topics, async_unsubscribe_topics from .subscription import async_subscribe_topics, async_unsubscribe_topics
from .util import valid_subscribe_topic from .util import valid_subscribe_topic
@ -221,7 +221,7 @@ class MqttAttributes(Entity):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def attributes_message_received(msg: Message) -> None: def attributes_message_received(msg: ReceiveMessage) -> None:
try: try:
payload = msg.payload payload = msg.payload
if attr_tpl is not None: if attr_tpl is not None:
@ -318,7 +318,7 @@ class MqttAvailability(Entity):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def availability_message_received(msg: Message) -> None: def availability_message_received(msg: ReceiveMessage) -> None:
"""Handle a new received MQTT availability message.""" """Handle a new received MQTT availability message."""
topic = msg.topic topic = msg.topic
if msg.payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]: if msg.payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]:

View File

@ -7,19 +7,30 @@ from typing import Awaitable, Callable, Union
import attr import attr
PublishPayloadType = Union[str, bytes, int, float, None] PublishPayloadType = Union[str, bytes, int, float, None]
ReceivePayloadType = Union[str, bytes]
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class Message: class PublishMessage:
"""MQTT Message.""" """MQTT Message."""
topic: str = attr.ib() topic: str = attr.ib()
payload: PublishPayloadType = attr.ib() payload: PublishPayloadType = attr.ib()
qos: int = attr.ib() qos: int = attr.ib()
retain: bool = attr.ib() retain: bool = attr.ib()
subscribed_topic: str | None = attr.ib(default=None)
timestamp: dt.datetime | None = attr.ib(default=None)
AsyncMessageCallbackType = Callable[[Message], Awaitable[None]] @attr.s(slots=True, frozen=True)
MessageCallbackType = Callable[[Message], None] class ReceiveMessage:
"""MQTT Message."""
topic: str = attr.ib()
payload: ReceivePayloadType = attr.ib()
qos: int = attr.ib()
retain: bool = attr.ib()
subscribed_topic: str = attr.ib(default=None)
timestamp: dt.datetime = attr.ib(default=None)
AsyncMessageCallbackType = Callable[[ReceiveMessage], Awaitable[None]]
MessageCallbackType = Callable[[ReceiveMessage], None]

View File

@ -66,7 +66,7 @@ async def async_subscribe_topics(
hass: HomeAssistant, hass: HomeAssistant,
new_state: dict[str, EntitySubscription] | None, new_state: dict[str, EntitySubscription] | None,
topics: dict[str, Any], topics: dict[str, Any],
): ) -> dict[str, EntitySubscription]:
"""(Re)Subscribe to a set of MQTT topics. """(Re)Subscribe to a set of MQTT topics.
State is kept in sub_state and a dictionary mapping from the subscription State is kept in sub_state and a dictionary mapping from the subscription
@ -106,6 +106,8 @@ async def async_subscribe_topics(
@bind_hass @bind_hass
async def async_unsubscribe_topics(hass: HomeAssistant, sub_state: dict): async def async_unsubscribe_topics(
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics.""" """Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
return await async_subscribe_topics(hass, sub_state, {}) return await async_subscribe_topics(hass, sub_state, {})

View File

@ -15,8 +15,8 @@ import voluptuous as vol
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
from homeassistant.components.mqtt.models import ( from homeassistant.components.mqtt.models import (
Message as MQTTMessage, ReceiveMessage as MQTTReceiveMessage,
PublishPayloadType, ReceivePayloadType,
) )
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
@ -188,12 +188,12 @@ async def _get_gateway(
mqtt.async_publish(topic, payload, qos, retain) mqtt.async_publish(topic, payload, qos, retain)
def sub_callback( def sub_callback(
topic: str, sub_cb: Callable[[str, PublishPayloadType, int], None], qos: int topic: str, sub_cb: Callable[[str, ReceivePayloadType, int], None], qos: int
) -> None: ) -> None:
"""Call MQTT subscribe function.""" """Call MQTT subscribe function."""
@callback @callback
def internal_callback(msg: MQTTMessage) -> None: def internal_callback(msg: MQTTReceiveMessage) -> None:
"""Call callback.""" """Call callback."""
sub_cb(msg.topic, msg.payload, msg.qos) sub_cb(msg.topic, msg.payload, msg.qos)

View File

@ -34,7 +34,7 @@ from homeassistant.components.device_automation import ( # noqa: F401
_async_get_device_automation_capabilities as async_get_device_automation_capabilities, _async_get_device_automation_capabilities as async_get_device_automation_capabilities,
_async_get_device_automations as async_get_device_automations, _async_get_device_automations as async_get_device_automations,
) )
from homeassistant.components.mqtt.models import Message from homeassistant.components.mqtt.models import ReceiveMessage
from homeassistant.config import async_process_component_config from homeassistant.config import async_process_component_config
from homeassistant.const import ( from homeassistant.const import (
DEVICE_DEFAULT_NAME, DEVICE_DEFAULT_NAME,
@ -353,7 +353,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
"""Fire the MQTT message.""" """Fire the MQTT message."""
if isinstance(payload, str): if isinstance(payload, str):
payload = payload.encode("utf-8") payload = payload.encode("utf-8")
msg = Message(topic, payload, qos, retain) msg = ReceiveMessage(topic, payload, qos, retain)
hass.data["mqtt"]._mqtt_handle_message(msg) hass.data["mqtt"]._mqtt_handle_message(msg)

View File

@ -1,6 +1,6 @@
"""Test config flow.""" """Test config flow."""
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.mqtt.models import Message from homeassistant.components.mqtt.models import ReceiveMessage
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -19,7 +19,9 @@ async def test_mqtt_abort_if_existing_entry(hass, mqtt_mock):
async def test_mqtt_abort_invalid_topic(hass, mqtt_mock): async def test_mqtt_abort_invalid_topic(hass, mqtt_mock):
"""Check MQTT flow aborts if discovery topic is invalid.""" """Check MQTT flow aborts if discovery topic is invalid."""
discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/##") discovery_info = ReceiveMessage(
"", "", 0, False, subscribed_topic="custom_prefix/##"
)
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
"tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info "tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info
) )
@ -29,7 +31,9 @@ async def test_mqtt_abort_invalid_topic(hass, mqtt_mock):
async def test_mqtt_setup(hass, mqtt_mock) -> None: async def test_mqtt_setup(hass, mqtt_mock) -> None:
"""Test we can finish a config flow through MQTT with custom prefix.""" """Test we can finish a config flow through MQTT with custom prefix."""
discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/123/#") discovery_info = ReceiveMessage(
"", "", 0, False, subscribed_topic="custom_prefix/123/#"
)
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
"tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info "tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info
) )