mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Minor improvements of MQTT typing (#52578)
* Minor improvements of MQTT typing * Tweak
This commit is contained in:
parent
dc72c6c606
commit
4d32e1ed01
@ -12,7 +12,7 @@ from axis.streammanager import SIGNAL_PLAYING, STATE_STOPPED
|
||||
|
||||
from homeassistant.components import mqtt
|
||||
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
|
||||
from homeassistant.components.mqtt.models import Message
|
||||
from homeassistant.components.mqtt.models import ReceiveMessage
|
||||
from homeassistant.const import (
|
||||
CONF_HOST,
|
||||
CONF_NAME,
|
||||
@ -195,7 +195,7 @@ class AxisNetworkDevice:
|
||||
)
|
||||
|
||||
@callback
|
||||
def mqtt_message(self, message: Message) -> None:
|
||||
def mqtt_message(self, message: ReceiveMessage) -> None:
|
||||
"""Receive Axis MQTT message."""
|
||||
self.disconnect_from_stream()
|
||||
|
||||
|
@ -75,9 +75,11 @@ from .const import (
|
||||
from .discovery import LAST_DISCOVERY
|
||||
from .models import (
|
||||
AsyncMessageCallbackType,
|
||||
Message,
|
||||
MessageCallbackType,
|
||||
PublishMessage,
|
||||
PublishPayloadType,
|
||||
ReceiveMessage,
|
||||
ReceivePayloadType,
|
||||
)
|
||||
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
|
||||
|
||||
@ -290,9 +292,9 @@ def async_publish_template(
|
||||
|
||||
|
||||
AsyncDeprecatedMessageCallbackType = Callable[
|
||||
[str, PublishPayloadType, int], Awaitable[None]
|
||||
[str, ReceivePayloadType, int], Awaitable[None]
|
||||
]
|
||||
DeprecatedMessageCallbackType = Callable[[str, PublishPayloadType, int], None]
|
||||
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
|
||||
|
||||
|
||||
def wrap_msg_callback(
|
||||
@ -308,7 +310,7 @@ def wrap_msg_callback(
|
||||
if asyncio.iscoroutinefunction(check_func):
|
||||
|
||||
@wraps(msg_callback)
|
||||
async def async_wrapper(msg: Message) -> None:
|
||||
async def async_wrapper(msg: ReceiveMessage) -> None:
|
||||
"""Call with deprecated signature."""
|
||||
await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
|
||||
msg.topic, msg.payload, msg.qos
|
||||
@ -318,7 +320,7 @@ def wrap_msg_callback(
|
||||
else:
|
||||
|
||||
@wraps(msg_callback)
|
||||
def wrapper(msg: Message) -> None:
|
||||
def wrapper(msg: ReceiveMessage) -> None:
|
||||
"""Call with deprecated signature."""
|
||||
msg_callback(msg.topic, msg.payload, msg.qos)
|
||||
|
||||
@ -676,7 +678,7 @@ class MQTT:
|
||||
CONF_WILL_MESSAGE in self.conf
|
||||
and ATTR_TOPIC in self.conf[CONF_WILL_MESSAGE]
|
||||
):
|
||||
will_message = Message(**self.conf[CONF_WILL_MESSAGE])
|
||||
will_message = PublishMessage(**self.conf[CONF_WILL_MESSAGE])
|
||||
else:
|
||||
will_message = None
|
||||
|
||||
@ -853,7 +855,7 @@ class MQTT:
|
||||
retain=birth_message.retain,
|
||||
)
|
||||
|
||||
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
|
||||
birth_message = PublishMessage(**self.conf[CONF_BIRTH_MESSAGE])
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
publish_birth_message(birth_message), self.hass.loop
|
||||
)
|
||||
@ -900,7 +902,7 @@ class MQTT:
|
||||
|
||||
self.hass.async_run_hass_job(
|
||||
subscription.job,
|
||||
Message(
|
||||
ReceiveMessage(
|
||||
msg.topic,
|
||||
payload,
|
||||
msg.qos,
|
||||
@ -1043,7 +1045,7 @@ async def websocket_subscribe(hass, connection, msg):
|
||||
if not connection.user.is_admin:
|
||||
raise Unauthorized
|
||||
|
||||
async def forward_messages(mqttmsg: Message):
|
||||
async def forward_messages(mqttmsg: ReceiveMessage):
|
||||
"""Forward events to websocket."""
|
||||
connection.send_message(
|
||||
websocket_api.event_message(
|
||||
@ -1064,8 +1066,13 @@ async def websocket_subscribe(hass, connection, msg):
|
||||
connection.send_message(websocket_api.result_message(msg["id"]))
|
||||
|
||||
|
||||
ConnectionStatusCallback = Callable[[bool], None]
|
||||
|
||||
|
||||
@callback
|
||||
def async_subscribe_connection_status(hass, connection_status_callback):
|
||||
def async_subscribe_connection_status(
|
||||
hass: HomeAssistant, connection_status_callback: ConnectionStatusCallback
|
||||
) -> Callable[[], None]:
|
||||
"""Subscribe to MQTT connection changes."""
|
||||
connection_status_callback_job = HassJob(connection_status_callback)
|
||||
|
||||
@ -1092,6 +1099,6 @@ def async_subscribe_connection_status(hass, connection_status_callback):
|
||||
return unsubscribe
|
||||
|
||||
|
||||
def is_connected(hass):
|
||||
def is_connected(hass: HomeAssistant) -> bool:
|
||||
"""Return if MQTT client is connected."""
|
||||
return hass.data[DATA_MQTT].connected
|
||||
|
@ -38,7 +38,7 @@ from .discovery import (
|
||||
clear_discovery_hash,
|
||||
set_discovery_hash,
|
||||
)
|
||||
from .models import Message
|
||||
from .models import ReceiveMessage
|
||||
from .subscription import async_subscribe_topics, async_unsubscribe_topics
|
||||
from .util import valid_subscribe_topic
|
||||
|
||||
@ -221,7 +221,7 @@ class MqttAttributes(Entity):
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def attributes_message_received(msg: Message) -> None:
|
||||
def attributes_message_received(msg: ReceiveMessage) -> None:
|
||||
try:
|
||||
payload = msg.payload
|
||||
if attr_tpl is not None:
|
||||
@ -318,7 +318,7 @@ class MqttAvailability(Entity):
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def availability_message_received(msg: Message) -> None:
|
||||
def availability_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle a new received MQTT availability message."""
|
||||
topic = msg.topic
|
||||
if msg.payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]:
|
||||
|
@ -7,19 +7,30 @@ from typing import Awaitable, Callable, Union
|
||||
import attr
|
||||
|
||||
PublishPayloadType = Union[str, bytes, int, float, None]
|
||||
ReceivePayloadType = Union[str, bytes]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class Message:
|
||||
class PublishMessage:
|
||||
"""MQTT Message."""
|
||||
|
||||
topic: str = attr.ib()
|
||||
payload: PublishPayloadType = attr.ib()
|
||||
qos: int = attr.ib()
|
||||
retain: bool = attr.ib()
|
||||
subscribed_topic: str | None = attr.ib(default=None)
|
||||
timestamp: dt.datetime | None = attr.ib(default=None)
|
||||
|
||||
|
||||
AsyncMessageCallbackType = Callable[[Message], Awaitable[None]]
|
||||
MessageCallbackType = Callable[[Message], None]
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class ReceiveMessage:
|
||||
"""MQTT Message."""
|
||||
|
||||
topic: str = attr.ib()
|
||||
payload: ReceivePayloadType = attr.ib()
|
||||
qos: int = attr.ib()
|
||||
retain: bool = attr.ib()
|
||||
subscribed_topic: str = attr.ib(default=None)
|
||||
timestamp: dt.datetime = attr.ib(default=None)
|
||||
|
||||
|
||||
AsyncMessageCallbackType = Callable[[ReceiveMessage], Awaitable[None]]
|
||||
MessageCallbackType = Callable[[ReceiveMessage], None]
|
||||
|
@ -66,7 +66,7 @@ async def async_subscribe_topics(
|
||||
hass: HomeAssistant,
|
||||
new_state: dict[str, EntitySubscription] | None,
|
||||
topics: dict[str, Any],
|
||||
):
|
||||
) -> dict[str, EntitySubscription]:
|
||||
"""(Re)Subscribe to a set of MQTT topics.
|
||||
|
||||
State is kept in sub_state and a dictionary mapping from the subscription
|
||||
@ -106,6 +106,8 @@ async def async_subscribe_topics(
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_unsubscribe_topics(hass: HomeAssistant, sub_state: dict):
|
||||
async def async_unsubscribe_topics(
|
||||
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
|
||||
) -> dict[str, EntitySubscription]:
|
||||
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
|
||||
return await async_subscribe_topics(hass, sub_state, {})
|
||||
|
@ -15,8 +15,8 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
|
||||
from homeassistant.components.mqtt.models import (
|
||||
Message as MQTTMessage,
|
||||
PublishPayloadType,
|
||||
ReceiveMessage as MQTTReceiveMessage,
|
||||
ReceivePayloadType,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
@ -188,12 +188,12 @@ async def _get_gateway(
|
||||
mqtt.async_publish(topic, payload, qos, retain)
|
||||
|
||||
def sub_callback(
|
||||
topic: str, sub_cb: Callable[[str, PublishPayloadType, int], None], qos: int
|
||||
topic: str, sub_cb: Callable[[str, ReceivePayloadType, int], None], qos: int
|
||||
) -> None:
|
||||
"""Call MQTT subscribe function."""
|
||||
|
||||
@callback
|
||||
def internal_callback(msg: MQTTMessage) -> None:
|
||||
def internal_callback(msg: MQTTReceiveMessage) -> None:
|
||||
"""Call callback."""
|
||||
sub_cb(msg.topic, msg.payload, msg.qos)
|
||||
|
||||
|
@ -34,7 +34,7 @@ from homeassistant.components.device_automation import ( # noqa: F401
|
||||
_async_get_device_automation_capabilities as async_get_device_automation_capabilities,
|
||||
_async_get_device_automations as async_get_device_automations,
|
||||
)
|
||||
from homeassistant.components.mqtt.models import Message
|
||||
from homeassistant.components.mqtt.models import ReceiveMessage
|
||||
from homeassistant.config import async_process_component_config
|
||||
from homeassistant.const import (
|
||||
DEVICE_DEFAULT_NAME,
|
||||
@ -353,7 +353,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
|
||||
"""Fire the MQTT message."""
|
||||
if isinstance(payload, str):
|
||||
payload = payload.encode("utf-8")
|
||||
msg = Message(topic, payload, qos, retain)
|
||||
msg = ReceiveMessage(topic, payload, qos, retain)
|
||||
hass.data["mqtt"]._mqtt_handle_message(msg)
|
||||
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Test config flow."""
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.mqtt.models import Message
|
||||
from homeassistant.components.mqtt.models import ReceiveMessage
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
@ -19,7 +19,9 @@ async def test_mqtt_abort_if_existing_entry(hass, mqtt_mock):
|
||||
|
||||
async def test_mqtt_abort_invalid_topic(hass, mqtt_mock):
|
||||
"""Check MQTT flow aborts if discovery topic is invalid."""
|
||||
discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/##")
|
||||
discovery_info = ReceiveMessage(
|
||||
"", "", 0, False, subscribed_topic="custom_prefix/##"
|
||||
)
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info
|
||||
)
|
||||
@ -29,7 +31,9 @@ async def test_mqtt_abort_invalid_topic(hass, mqtt_mock):
|
||||
|
||||
async def test_mqtt_setup(hass, mqtt_mock) -> None:
|
||||
"""Test we can finish a config flow through MQTT with custom prefix."""
|
||||
discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/123/#")
|
||||
discovery_info = ReceiveMessage(
|
||||
"", "", 0, False, subscribed_topic="custom_prefix/123/#"
|
||||
)
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user