mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Minor improvements of MQTT typing (#52578)
* Minor improvements of MQTT typing * Tweak
This commit is contained in:
parent
dc72c6c606
commit
4d32e1ed01
@ -12,7 +12,7 @@ from axis.streammanager import SIGNAL_PLAYING, STATE_STOPPED
|
|||||||
|
|
||||||
from homeassistant.components import mqtt
|
from homeassistant.components import mqtt
|
||||||
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
|
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
|
||||||
from homeassistant.components.mqtt.models import Message
|
from homeassistant.components.mqtt.models import ReceiveMessage
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
CONF_HOST,
|
CONF_HOST,
|
||||||
CONF_NAME,
|
CONF_NAME,
|
||||||
@ -195,7 +195,7 @@ class AxisNetworkDevice:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def mqtt_message(self, message: Message) -> None:
|
def mqtt_message(self, message: ReceiveMessage) -> None:
|
||||||
"""Receive Axis MQTT message."""
|
"""Receive Axis MQTT message."""
|
||||||
self.disconnect_from_stream()
|
self.disconnect_from_stream()
|
||||||
|
|
||||||
|
@ -75,9 +75,11 @@ from .const import (
|
|||||||
from .discovery import LAST_DISCOVERY
|
from .discovery import LAST_DISCOVERY
|
||||||
from .models import (
|
from .models import (
|
||||||
AsyncMessageCallbackType,
|
AsyncMessageCallbackType,
|
||||||
Message,
|
|
||||||
MessageCallbackType,
|
MessageCallbackType,
|
||||||
|
PublishMessage,
|
||||||
PublishPayloadType,
|
PublishPayloadType,
|
||||||
|
ReceiveMessage,
|
||||||
|
ReceivePayloadType,
|
||||||
)
|
)
|
||||||
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
|
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
|
||||||
|
|
||||||
@ -290,9 +292,9 @@ def async_publish_template(
|
|||||||
|
|
||||||
|
|
||||||
AsyncDeprecatedMessageCallbackType = Callable[
|
AsyncDeprecatedMessageCallbackType = Callable[
|
||||||
[str, PublishPayloadType, int], Awaitable[None]
|
[str, ReceivePayloadType, int], Awaitable[None]
|
||||||
]
|
]
|
||||||
DeprecatedMessageCallbackType = Callable[[str, PublishPayloadType, int], None]
|
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
|
||||||
|
|
||||||
|
|
||||||
def wrap_msg_callback(
|
def wrap_msg_callback(
|
||||||
@ -308,7 +310,7 @@ def wrap_msg_callback(
|
|||||||
if asyncio.iscoroutinefunction(check_func):
|
if asyncio.iscoroutinefunction(check_func):
|
||||||
|
|
||||||
@wraps(msg_callback)
|
@wraps(msg_callback)
|
||||||
async def async_wrapper(msg: Message) -> None:
|
async def async_wrapper(msg: ReceiveMessage) -> None:
|
||||||
"""Call with deprecated signature."""
|
"""Call with deprecated signature."""
|
||||||
await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
|
await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
|
||||||
msg.topic, msg.payload, msg.qos
|
msg.topic, msg.payload, msg.qos
|
||||||
@ -318,7 +320,7 @@ def wrap_msg_callback(
|
|||||||
else:
|
else:
|
||||||
|
|
||||||
@wraps(msg_callback)
|
@wraps(msg_callback)
|
||||||
def wrapper(msg: Message) -> None:
|
def wrapper(msg: ReceiveMessage) -> None:
|
||||||
"""Call with deprecated signature."""
|
"""Call with deprecated signature."""
|
||||||
msg_callback(msg.topic, msg.payload, msg.qos)
|
msg_callback(msg.topic, msg.payload, msg.qos)
|
||||||
|
|
||||||
@ -676,7 +678,7 @@ class MQTT:
|
|||||||
CONF_WILL_MESSAGE in self.conf
|
CONF_WILL_MESSAGE in self.conf
|
||||||
and ATTR_TOPIC in self.conf[CONF_WILL_MESSAGE]
|
and ATTR_TOPIC in self.conf[CONF_WILL_MESSAGE]
|
||||||
):
|
):
|
||||||
will_message = Message(**self.conf[CONF_WILL_MESSAGE])
|
will_message = PublishMessage(**self.conf[CONF_WILL_MESSAGE])
|
||||||
else:
|
else:
|
||||||
will_message = None
|
will_message = None
|
||||||
|
|
||||||
@ -853,7 +855,7 @@ class MQTT:
|
|||||||
retain=birth_message.retain,
|
retain=birth_message.retain,
|
||||||
)
|
)
|
||||||
|
|
||||||
birth_message = Message(**self.conf[CONF_BIRTH_MESSAGE])
|
birth_message = PublishMessage(**self.conf[CONF_BIRTH_MESSAGE])
|
||||||
asyncio.run_coroutine_threadsafe(
|
asyncio.run_coroutine_threadsafe(
|
||||||
publish_birth_message(birth_message), self.hass.loop
|
publish_birth_message(birth_message), self.hass.loop
|
||||||
)
|
)
|
||||||
@ -900,7 +902,7 @@ class MQTT:
|
|||||||
|
|
||||||
self.hass.async_run_hass_job(
|
self.hass.async_run_hass_job(
|
||||||
subscription.job,
|
subscription.job,
|
||||||
Message(
|
ReceiveMessage(
|
||||||
msg.topic,
|
msg.topic,
|
||||||
payload,
|
payload,
|
||||||
msg.qos,
|
msg.qos,
|
||||||
@ -1043,7 +1045,7 @@ async def websocket_subscribe(hass, connection, msg):
|
|||||||
if not connection.user.is_admin:
|
if not connection.user.is_admin:
|
||||||
raise Unauthorized
|
raise Unauthorized
|
||||||
|
|
||||||
async def forward_messages(mqttmsg: Message):
|
async def forward_messages(mqttmsg: ReceiveMessage):
|
||||||
"""Forward events to websocket."""
|
"""Forward events to websocket."""
|
||||||
connection.send_message(
|
connection.send_message(
|
||||||
websocket_api.event_message(
|
websocket_api.event_message(
|
||||||
@ -1064,8 +1066,13 @@ async def websocket_subscribe(hass, connection, msg):
|
|||||||
connection.send_message(websocket_api.result_message(msg["id"]))
|
connection.send_message(websocket_api.result_message(msg["id"]))
|
||||||
|
|
||||||
|
|
||||||
|
ConnectionStatusCallback = Callable[[bool], None]
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_subscribe_connection_status(hass, connection_status_callback):
|
def async_subscribe_connection_status(
|
||||||
|
hass: HomeAssistant, connection_status_callback: ConnectionStatusCallback
|
||||||
|
) -> Callable[[], None]:
|
||||||
"""Subscribe to MQTT connection changes."""
|
"""Subscribe to MQTT connection changes."""
|
||||||
connection_status_callback_job = HassJob(connection_status_callback)
|
connection_status_callback_job = HassJob(connection_status_callback)
|
||||||
|
|
||||||
@ -1092,6 +1099,6 @@ def async_subscribe_connection_status(hass, connection_status_callback):
|
|||||||
return unsubscribe
|
return unsubscribe
|
||||||
|
|
||||||
|
|
||||||
def is_connected(hass):
|
def is_connected(hass: HomeAssistant) -> bool:
|
||||||
"""Return if MQTT client is connected."""
|
"""Return if MQTT client is connected."""
|
||||||
return hass.data[DATA_MQTT].connected
|
return hass.data[DATA_MQTT].connected
|
||||||
|
@ -38,7 +38,7 @@ from .discovery import (
|
|||||||
clear_discovery_hash,
|
clear_discovery_hash,
|
||||||
set_discovery_hash,
|
set_discovery_hash,
|
||||||
)
|
)
|
||||||
from .models import Message
|
from .models import ReceiveMessage
|
||||||
from .subscription import async_subscribe_topics, async_unsubscribe_topics
|
from .subscription import async_subscribe_topics, async_unsubscribe_topics
|
||||||
from .util import valid_subscribe_topic
|
from .util import valid_subscribe_topic
|
||||||
|
|
||||||
@ -221,7 +221,7 @@ class MqttAttributes(Entity):
|
|||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
def attributes_message_received(msg: Message) -> None:
|
def attributes_message_received(msg: ReceiveMessage) -> None:
|
||||||
try:
|
try:
|
||||||
payload = msg.payload
|
payload = msg.payload
|
||||||
if attr_tpl is not None:
|
if attr_tpl is not None:
|
||||||
@ -318,7 +318,7 @@ class MqttAvailability(Entity):
|
|||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
def availability_message_received(msg: Message) -> None:
|
def availability_message_received(msg: ReceiveMessage) -> None:
|
||||||
"""Handle a new received MQTT availability message."""
|
"""Handle a new received MQTT availability message."""
|
||||||
topic = msg.topic
|
topic = msg.topic
|
||||||
if msg.payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]:
|
if msg.payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]:
|
||||||
|
@ -7,19 +7,30 @@ from typing import Awaitable, Callable, Union
|
|||||||
import attr
|
import attr
|
||||||
|
|
||||||
PublishPayloadType = Union[str, bytes, int, float, None]
|
PublishPayloadType = Union[str, bytes, int, float, None]
|
||||||
|
ReceivePayloadType = Union[str, bytes]
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(slots=True, frozen=True)
|
||||||
class Message:
|
class PublishMessage:
|
||||||
"""MQTT Message."""
|
"""MQTT Message."""
|
||||||
|
|
||||||
topic: str = attr.ib()
|
topic: str = attr.ib()
|
||||||
payload: PublishPayloadType = attr.ib()
|
payload: PublishPayloadType = attr.ib()
|
||||||
qos: int = attr.ib()
|
qos: int = attr.ib()
|
||||||
retain: bool = attr.ib()
|
retain: bool = attr.ib()
|
||||||
subscribed_topic: str | None = attr.ib(default=None)
|
|
||||||
timestamp: dt.datetime | None = attr.ib(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
AsyncMessageCallbackType = Callable[[Message], Awaitable[None]]
|
@attr.s(slots=True, frozen=True)
|
||||||
MessageCallbackType = Callable[[Message], None]
|
class ReceiveMessage:
|
||||||
|
"""MQTT Message."""
|
||||||
|
|
||||||
|
topic: str = attr.ib()
|
||||||
|
payload: ReceivePayloadType = attr.ib()
|
||||||
|
qos: int = attr.ib()
|
||||||
|
retain: bool = attr.ib()
|
||||||
|
subscribed_topic: str = attr.ib(default=None)
|
||||||
|
timestamp: dt.datetime = attr.ib(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
AsyncMessageCallbackType = Callable[[ReceiveMessage], Awaitable[None]]
|
||||||
|
MessageCallbackType = Callable[[ReceiveMessage], None]
|
||||||
|
@ -66,7 +66,7 @@ async def async_subscribe_topics(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
new_state: dict[str, EntitySubscription] | None,
|
new_state: dict[str, EntitySubscription] | None,
|
||||||
topics: dict[str, Any],
|
topics: dict[str, Any],
|
||||||
):
|
) -> dict[str, EntitySubscription]:
|
||||||
"""(Re)Subscribe to a set of MQTT topics.
|
"""(Re)Subscribe to a set of MQTT topics.
|
||||||
|
|
||||||
State is kept in sub_state and a dictionary mapping from the subscription
|
State is kept in sub_state and a dictionary mapping from the subscription
|
||||||
@ -106,6 +106,8 @@ async def async_subscribe_topics(
|
|||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
async def async_unsubscribe_topics(hass: HomeAssistant, sub_state: dict):
|
async def async_unsubscribe_topics(
|
||||||
|
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
|
||||||
|
) -> dict[str, EntitySubscription]:
|
||||||
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
|
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
|
||||||
return await async_subscribe_topics(hass, sub_state, {})
|
return await async_subscribe_topics(hass, sub_state, {})
|
||||||
|
@ -15,8 +15,8 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
|
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
|
||||||
from homeassistant.components.mqtt.models import (
|
from homeassistant.components.mqtt.models import (
|
||||||
Message as MQTTMessage,
|
ReceiveMessage as MQTTReceiveMessage,
|
||||||
PublishPayloadType,
|
ReceivePayloadType,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||||
@ -188,12 +188,12 @@ async def _get_gateway(
|
|||||||
mqtt.async_publish(topic, payload, qos, retain)
|
mqtt.async_publish(topic, payload, qos, retain)
|
||||||
|
|
||||||
def sub_callback(
|
def sub_callback(
|
||||||
topic: str, sub_cb: Callable[[str, PublishPayloadType, int], None], qos: int
|
topic: str, sub_cb: Callable[[str, ReceivePayloadType, int], None], qos: int
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Call MQTT subscribe function."""
|
"""Call MQTT subscribe function."""
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def internal_callback(msg: MQTTMessage) -> None:
|
def internal_callback(msg: MQTTReceiveMessage) -> None:
|
||||||
"""Call callback."""
|
"""Call callback."""
|
||||||
sub_cb(msg.topic, msg.payload, msg.qos)
|
sub_cb(msg.topic, msg.payload, msg.qos)
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ from homeassistant.components.device_automation import ( # noqa: F401
|
|||||||
_async_get_device_automation_capabilities as async_get_device_automation_capabilities,
|
_async_get_device_automation_capabilities as async_get_device_automation_capabilities,
|
||||||
_async_get_device_automations as async_get_device_automations,
|
_async_get_device_automations as async_get_device_automations,
|
||||||
)
|
)
|
||||||
from homeassistant.components.mqtt.models import Message
|
from homeassistant.components.mqtt.models import ReceiveMessage
|
||||||
from homeassistant.config import async_process_component_config
|
from homeassistant.config import async_process_component_config
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
DEVICE_DEFAULT_NAME,
|
DEVICE_DEFAULT_NAME,
|
||||||
@ -353,7 +353,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
|
|||||||
"""Fire the MQTT message."""
|
"""Fire the MQTT message."""
|
||||||
if isinstance(payload, str):
|
if isinstance(payload, str):
|
||||||
payload = payload.encode("utf-8")
|
payload = payload.encode("utf-8")
|
||||||
msg = Message(topic, payload, qos, retain)
|
msg = ReceiveMessage(topic, payload, qos, retain)
|
||||||
hass.data["mqtt"]._mqtt_handle_message(msg)
|
hass.data["mqtt"]._mqtt_handle_message(msg)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Test config flow."""
|
"""Test config flow."""
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components.mqtt.models import Message
|
from homeassistant.components.mqtt.models import ReceiveMessage
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
@ -19,7 +19,9 @@ async def test_mqtt_abort_if_existing_entry(hass, mqtt_mock):
|
|||||||
|
|
||||||
async def test_mqtt_abort_invalid_topic(hass, mqtt_mock):
|
async def test_mqtt_abort_invalid_topic(hass, mqtt_mock):
|
||||||
"""Check MQTT flow aborts if discovery topic is invalid."""
|
"""Check MQTT flow aborts if discovery topic is invalid."""
|
||||||
discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/##")
|
discovery_info = ReceiveMessage(
|
||||||
|
"", "", 0, False, subscribed_topic="custom_prefix/##"
|
||||||
|
)
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
"tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info
|
"tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info
|
||||||
)
|
)
|
||||||
@ -29,7 +31,9 @@ async def test_mqtt_abort_invalid_topic(hass, mqtt_mock):
|
|||||||
|
|
||||||
async def test_mqtt_setup(hass, mqtt_mock) -> None:
|
async def test_mqtt_setup(hass, mqtt_mock) -> None:
|
||||||
"""Test we can finish a config flow through MQTT with custom prefix."""
|
"""Test we can finish a config flow through MQTT with custom prefix."""
|
||||||
discovery_info = Message("", "", 0, False, subscribed_topic="custom_prefix/123/#")
|
discovery_info = ReceiveMessage(
|
||||||
|
"", "", 0, False, subscribed_topic="custom_prefix/123/#"
|
||||||
|
)
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
"tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info
|
"tasmota", context={"source": config_entries.SOURCE_MQTT}, data=discovery_info
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user