Create bound callback_message_received method for handling mqtt callbacks (#117951)

* Create bound callback_message_received method for handling mqtt callbacks

* refactor a bit

* fix ruff

* reduce overhead

* cleanup

* cleanup

* Revert changes alarm_control_panel

* Add sensor and binary sensor

* use same pattern for MqttAttributes/MqttAvailability

* remove useless function since we did not need to add to it

* code cleanup

* collapse

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Jan Bouwhuis 2024-05-24 11:18:25 +02:00 committed by GitHub
parent d4df86da06
commit 9333965b23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 210 additions and 152 deletions

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial
import logging import logging
from typing import Any from typing import Any
@ -37,13 +38,7 @@ from homeassistant.util import dt as dt_util
from . import subscription from . import subscription
from .config import MQTT_RO_SCHEMA from .config import MQTT_RO_SCHEMA
from .const import CONF_ENCODING, CONF_QOS, CONF_STATE_TOPIC, PAYLOAD_NONE from .const import CONF_ENCODING, CONF_QOS, CONF_STATE_TOPIC, PAYLOAD_NONE
from .debug_info import log_messages from .mixins import MqttAvailability, MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttAvailability,
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import MqttValueTemplate, ReceiveMessage from .models import MqttValueTemplate, ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
@ -162,92 +157,95 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
@callback
def _off_delay_listener(self, now: datetime) -> None:
"""Switch device off after a delay."""
self._delay_listener = None
self._attr_is_on = False
self.async_write_ha_state()
def _state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle a new received MQTT state message."""
# auto-expire enabled?
if self._expire_after:
# When expire_after is set, and we receive a message, assume device is
# not expired since it has to be to receive the message
self._expired = False
# Reset old trigger
if self._expiration_trigger:
self._expiration_trigger()
# Set new trigger
self._expiration_trigger = async_call_later(
self.hass, self._expire_after, self._value_is_expired
)
payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
(
"Empty template output for entity: %s with state topic: %s."
" Payload: '%s', with value template '%s'"
),
self.entity_id,
self._config[CONF_STATE_TOPIC],
msg.payload,
self._config.get(CONF_VALUE_TEMPLATE),
)
return
if payload == self._config[CONF_PAYLOAD_ON]:
self._attr_is_on = True
elif payload == self._config[CONF_PAYLOAD_OFF]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
else: # Payload is not for this entity
template_info = ""
if self._config.get(CONF_VALUE_TEMPLATE) is not None:
template_info = (
f", template output: '{payload!s}', with value template"
f" '{self._config.get(CONF_VALUE_TEMPLATE)!s}'"
)
_LOGGER.info(
(
"No matching payload found for entity: %s with state topic: %s."
" Payload: '%s'%s"
),
self.entity_id,
self._config[CONF_STATE_TOPIC],
msg.payload,
template_info,
)
return
if self._delay_listener is not None:
self._delay_listener()
self._delay_listener = None
off_delay: int | None = self._config.get(CONF_OFF_DELAY)
if self._attr_is_on and off_delay is not None:
self._delay_listener = evt.async_call_later(
self.hass, off_delay, self._off_delay_listener
)
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback
def off_delay_listener(now: datetime) -> None:
"""Switch device off after a delay."""
self._delay_listener = None
self._attr_is_on = False
self.async_write_ha_state()
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on", "_expired"})
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle a new received MQTT state message."""
# auto-expire enabled?
if self._expire_after:
# When expire_after is set, and we receive a message, assume device is
# not expired since it has to be to receive the message
self._expired = False
# Reset old trigger
if self._expiration_trigger:
self._expiration_trigger()
# Set new trigger
self._expiration_trigger = async_call_later(
self.hass, self._expire_after, self._value_is_expired
)
payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
(
"Empty template output for entity: %s with state topic: %s."
" Payload: '%s', with value template '%s'"
),
self.entity_id,
self._config[CONF_STATE_TOPIC],
msg.payload,
self._config.get(CONF_VALUE_TEMPLATE),
)
return
if payload == self._config[CONF_PAYLOAD_ON]:
self._attr_is_on = True
elif payload == self._config[CONF_PAYLOAD_OFF]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
else: # Payload is not for this entity
template_info = ""
if self._config.get(CONF_VALUE_TEMPLATE) is not None:
template_info = (
f", template output: '{payload!s}', with value template"
f" '{self._config.get(CONF_VALUE_TEMPLATE)!s}'"
)
_LOGGER.info(
(
"No matching payload found for entity: %s with state topic: %s."
" Payload: '%s'%s"
),
self.entity_id,
self._config[CONF_STATE_TOPIC],
msg.payload,
template_info,
)
return
if self._delay_listener is not None:
self._delay_listener()
self._delay_listener = None
off_delay: int | None = self._config.get(CONF_OFF_DELAY)
if self._attr_is_on and off_delay is not None:
self._delay_listener = evt.async_call_later(
self.hass, off_delay, off_delay_listener
)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
"state_topic": { "state_topic": {
"topic": self._config[CONF_STATE_TOPIC], "topic": self._config[CONF_STATE_TOPIC],
"msg_callback": state_message_received, "msg_callback": partial(
self._message_callback,
self._state_message_received,
{"_attr_is_on", "_expired"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }

View File

@ -86,9 +86,12 @@ def add_subscription(
hass: HomeAssistant, hass: HomeAssistant,
message_callback: MessageCallbackType, message_callback: MessageCallbackType,
subscription: str, subscription: str,
entity_id: str | None = None,
) -> None: ) -> None:
"""Prepare debug data for subscription.""" """Prepare debug data for subscription."""
if entity_id := getattr(message_callback, "__entity_id", None): if not entity_id:
entity_id = getattr(message_callback, "__entity_id", None)
if entity_id:
entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault( entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault(
entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}} entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}}
) )
@ -104,9 +107,12 @@ def remove_subscription(
hass: HomeAssistant, hass: HomeAssistant,
message_callback: MessageCallbackType, message_callback: MessageCallbackType,
subscription: str, subscription: str,
entity_id: str | None = None,
) -> None: ) -> None:
"""Remove debug data for subscription if it exists.""" """Remove debug data for subscription if it exists."""
if (entity_id := getattr(message_callback, "__entity_id", None)) and entity_id in ( if not entity_id:
entity_id = getattr(message_callback, "__entity_id", None)
if entity_id and entity_id in (
debug_info_entities := hass.data[DATA_MQTT].debug_info_entities debug_info_entities := hass.data[DATA_MQTT].debug_info_entities
): ):
debug_info_entities[entity_id]["subscriptions"][subscription]["count"] -= 1 debug_info_entities[entity_id]["subscriptions"][subscription]["count"] -= 1

View File

@ -48,6 +48,7 @@ from homeassistant.helpers.event import (
async_track_entity_registry_updated_event, async_track_entity_registry_updated_event,
) )
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
from homeassistant.helpers.typing import ( from homeassistant.helpers.typing import (
UNDEFINED, UNDEFINED,
ConfigType, ConfigType,
@ -93,7 +94,7 @@ from .const import (
MQTT_CONNECTED, MQTT_CONNECTED,
MQTT_DISCONNECTED, MQTT_DISCONNECTED,
) )
from .debug_info import log_message, log_messages from .debug_info import log_message
from .discovery import ( from .discovery import (
MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_DONE,
MQTT_DISCOVERY_NEW, MQTT_DISCOVERY_NEW,
@ -401,6 +402,7 @@ class MqttAttributes(Entity):
"""Mixin used for platforms that support JSON attributes.""" """Mixin used for platforms that support JSON attributes."""
_attributes_extra_blocked: frozenset[str] = frozenset() _attributes_extra_blocked: frozenset[str] = frozenset()
_attr_tpl: Callable[[ReceivePayloadType], ReceivePayloadType] | None = None
def __init__(self, config: ConfigType) -> None: def __init__(self, config: ConfigType) -> None:
"""Initialize the JSON attributes mixin.""" """Initialize the JSON attributes mixin."""
@ -424,38 +426,21 @@ class MqttAttributes(Entity):
def _attributes_prepare_subscribe_topics(self) -> None: def _attributes_prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
attr_tpl = MqttValueTemplate( self._attr_tpl = MqttValueTemplate(
self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE), entity=self self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_extra_state_attributes"})
def attributes_message_received(msg: ReceiveMessage) -> None:
"""Update extra state attributes."""
payload = attr_tpl(msg.payload)
try:
json_dict = json_loads(payload) if isinstance(payload, str) else None
if isinstance(json_dict, dict):
filtered_dict = {
k: v
for k, v in json_dict.items()
if k not in MQTT_ATTRIBUTES_BLOCKED
and k not in self._attributes_extra_blocked
}
self._attr_extra_state_attributes = filtered_dict
else:
_LOGGER.warning("JSON result was not a dictionary")
except ValueError:
_LOGGER.warning("Erroneous JSON: %s", payload)
self._attributes_sub_state = async_prepare_subscribe_topics( self._attributes_sub_state = async_prepare_subscribe_topics(
self.hass, self.hass,
self._attributes_sub_state, self._attributes_sub_state,
{ {
CONF_JSON_ATTRS_TOPIC: { CONF_JSON_ATTRS_TOPIC: {
"topic": self._attributes_config.get(CONF_JSON_ATTRS_TOPIC), "topic": self._attributes_config.get(CONF_JSON_ATTRS_TOPIC),
"msg_callback": attributes_message_received, "msg_callback": partial(
self._message_callback, # type: ignore[attr-defined]
self._attributes_message_received,
{"_attr_extra_state_attributes"},
),
"entity_id": self.entity_id,
"qos": self._attributes_config.get(CONF_QOS), "qos": self._attributes_config.get(CONF_QOS),
"encoding": self._attributes_config[CONF_ENCODING] or None, "encoding": self._attributes_config[CONF_ENCODING] or None,
} }
@ -472,6 +457,28 @@ class MqttAttributes(Entity):
self.hass, self._attributes_sub_state self.hass, self._attributes_sub_state
) )
@callback
def _attributes_message_received(self, msg: ReceiveMessage) -> None:
"""Update extra state attributes."""
if TYPE_CHECKING:
assert self._attr_tpl is not None
payload = self._attr_tpl(msg.payload)
try:
json_dict = json_loads(payload) if isinstance(payload, str) else None
except ValueError:
_LOGGER.warning("Erroneous JSON: %s", payload)
else:
if isinstance(json_dict, dict):
filtered_dict = {
k: v
for k, v in json_dict.items()
if k not in MQTT_ATTRIBUTES_BLOCKED
and k not in self._attributes_extra_blocked
}
self._attr_extra_state_attributes = filtered_dict
else:
_LOGGER.warning("JSON result was not a dictionary")
class MqttAvailability(Entity): class MqttAvailability(Entity):
"""Mixin used for platforms that report availability.""" """Mixin used for platforms that report availability."""
@ -535,28 +542,18 @@ class MqttAvailability(Entity):
def _availability_prepare_subscribe_topics(self) -> None: def _availability_prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"available"})
def availability_message_received(msg: ReceiveMessage) -> None:
"""Handle a new received MQTT availability message."""
topic = msg.topic
payload = self._avail_topics[topic][CONF_AVAILABILITY_TEMPLATE](msg.payload)
if payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]:
self._available[topic] = True
self._available_latest = True
elif payload == self._avail_topics[topic][CONF_PAYLOAD_NOT_AVAILABLE]:
self._available[topic] = False
self._available_latest = False
self._available = { self._available = {
topic: (self._available.get(topic, False)) for topic in self._avail_topics topic: (self._available.get(topic, False)) for topic in self._avail_topics
} }
topics: dict[str, dict[str, Any]] = { topics: dict[str, dict[str, Any]] = {
f"availability_{topic}": { f"availability_{topic}": {
"topic": topic, "topic": topic,
"msg_callback": availability_message_received, "msg_callback": partial(
self._message_callback, # type: ignore[attr-defined]
self._availability_message_received,
{"available"},
),
"entity_id": self.entity_id,
"qos": self._avail_config[CONF_QOS], "qos": self._avail_config[CONF_QOS],
"encoding": self._avail_config[CONF_ENCODING] or None, "encoding": self._avail_config[CONF_ENCODING] or None,
} }
@ -569,6 +566,19 @@ class MqttAvailability(Entity):
topics, topics,
) )
@callback
def _availability_message_received(self, msg: ReceiveMessage) -> None:
"""Handle a new received MQTT availability message."""
topic = msg.topic
avail_topic = self._avail_topics[topic]
payload = avail_topic[CONF_AVAILABILITY_TEMPLATE](msg.payload)
if payload == avail_topic[CONF_PAYLOAD_AVAILABLE]:
self._available[topic] = True
self._available_latest = True
elif payload == avail_topic[CONF_PAYLOAD_NOT_AVAILABLE]:
self._available[topic] = False
self._available_latest = False
async def _availability_subscribe_topics(self) -> None: async def _availability_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._availability_sub_state) await async_subscribe_topics(self.hass, self._availability_sub_state)
@ -1073,6 +1083,7 @@ class MqttEntity(
): ):
"""Representation of an MQTT entity.""" """Representation of an MQTT entity."""
_attr_force_update = False
_attr_has_entity_name = True _attr_has_entity_name = True
_attr_should_poll = False _attr_should_poll = False
_default_name: str | None _default_name: str | None
@ -1225,6 +1236,45 @@ class MqttEntity(
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback
def _attrs_have_changed(
self, attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...]
) -> bool:
"""Return True if attributes on entity changed or if update is forced."""
if self._attr_force_update:
return True
for attribute, last_value in attrs_snapshot:
if getattr(self, attribute, UNDEFINED) != last_value:
return True
return False
@callback
def _message_callback(
self,
msg_callback: MessageCallbackType,
attributes: set[str],
msg: ReceiveMessage,
) -> None:
"""Process the message callback."""
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
(attribute, getattr(self, attribute, UNDEFINED)) for attribute in attributes
)
mqtt_data = self.hass.data[DATA_MQTT]
messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][
msg.subscribed_topic
]["messages"]
if msg not in messages:
messages.append(msg)
try:
msg_callback(msg)
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
if self._attrs_have_changed(attrs_snapshot):
mqtt_data.state_write_requests.write_state_request(self)
def update_device( def update_device(
hass: HomeAssistant, hass: HomeAssistant,

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial
import logging import logging
from typing import Any from typing import Any
@ -40,13 +41,7 @@ from homeassistant.util import dt as dt_util
from . import subscription from . import subscription
from .config import MQTT_RO_SCHEMA from .config import MQTT_RO_SCHEMA
from .const import CONF_ENCODING, CONF_QOS, CONF_STATE_TOPIC, PAYLOAD_NONE from .const import CONF_ENCODING, CONF_QOS, CONF_STATE_TOPIC, PAYLOAD_NONE
from .debug_info import log_messages from .mixins import MqttAvailability, MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttAvailability,
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttValueTemplate, MqttValueTemplate,
PayloadSentinel, PayloadSentinel,
@ -215,9 +210,9 @@ class MqttSensor(MqttEntity, RestoreSensor):
self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None: @callback
"""(Re)Subscribe to topics.""" def _state_message_received(self, msg: ReceiveMessage) -> None:
topics: dict[str, dict[str, Any]] = {} """Handle new MQTT state messages."""
def _update_state(msg: ReceiveMessage) -> None: def _update_state(msg: ReceiveMessage) -> None:
# auto-expire enabled? # auto-expire enabled?
@ -280,20 +275,22 @@ class MqttSensor(MqttEntity, RestoreSensor):
"Invalid last_reset message '%s' from '%s'", msg.payload, msg.topic "Invalid last_reset message '%s' from '%s'", msg.payload, msg.topic
) )
@callback _update_state(msg)
@write_state_on_attr_change( if CONF_LAST_RESET_VALUE_TEMPLATE in self._config:
self, {"_attr_native_value", "_attr_last_reset", "_expired"} _update_last_reset(msg)
)
@log_messages(self.hass, self.entity_id) def _prepare_subscribe_topics(self) -> None:
def message_received(msg: ReceiveMessage) -> None: """(Re)Subscribe to topics."""
"""Handle new MQTT messages.""" topics: dict[str, dict[str, Any]] = {}
_update_state(msg)
if CONF_LAST_RESET_VALUE_TEMPLATE in self._config:
_update_last_reset(msg)
topics["state_topic"] = { topics["state_topic"] = {
"topic": self._config[CONF_STATE_TOPIC], "topic": self._config[CONF_STATE_TOPIC],
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._state_message_received,
{"_attr_native_value", "_attr_last_reset", "_expired"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }

View File

@ -26,6 +26,7 @@ class EntitySubscription:
unsubscribe_callback: Callable[[], None] | None = attr.ib() unsubscribe_callback: Callable[[], None] | None = attr.ib()
qos: int = attr.ib(default=0) qos: int = attr.ib(default=0)
encoding: str = attr.ib(default="utf-8") encoding: str = attr.ib(default="utf-8")
entity_id: str | None = attr.ib(default=None)
def resubscribe_if_necessary( def resubscribe_if_necessary(
self, hass: HomeAssistant, other: EntitySubscription | None self, hass: HomeAssistant, other: EntitySubscription | None
@ -41,7 +42,7 @@ class EntitySubscription:
other.unsubscribe_callback() other.unsubscribe_callback()
# Clear debug data if it exists # Clear debug data if it exists
debug_info.remove_subscription( debug_info.remove_subscription(
self.hass, other.message_callback, str(other.topic) self.hass, other.message_callback, str(other.topic), other.entity_id
) )
if self.topic is None: if self.topic is None:
@ -49,7 +50,9 @@ class EntitySubscription:
return return
# Prepare debug data # Prepare debug data
debug_info.add_subscription(self.hass, self.message_callback, self.topic) debug_info.add_subscription(
self.hass, self.message_callback, self.topic, self.entity_id
)
self.subscribe_task = mqtt.async_subscribe( self.subscribe_task = mqtt.async_subscribe(
hass, self.topic, self.message_callback, self.qos, self.encoding hass, self.topic, self.message_callback, self.qos, self.encoding
@ -80,7 +83,7 @@ class EntitySubscription:
def async_prepare_subscribe_topics( def async_prepare_subscribe_topics(
hass: HomeAssistant, hass: HomeAssistant,
new_state: dict[str, EntitySubscription] | None, new_state: dict[str, EntitySubscription] | None,
topics: dict[str, Any], topics: dict[str, dict[str, Any]],
) -> dict[str, EntitySubscription]: ) -> dict[str, EntitySubscription]:
"""Prepare (re)subscribe to a set of MQTT topics. """Prepare (re)subscribe to a set of MQTT topics.
@ -106,6 +109,7 @@ def async_prepare_subscribe_topics(
encoding=value.get("encoding", "utf-8"), encoding=value.get("encoding", "utf-8"),
hass=hass, hass=hass,
subscribe_task=None, subscribe_task=None,
entity_id=value.get("entity_id", None),
) )
# Get the current subscription state # Get the current subscription state
current = current_subscriptions.pop(key, None) current = current_subscriptions.pop(key, None)
@ -118,7 +122,10 @@ def async_prepare_subscribe_topics(
remaining.unsubscribe_callback() remaining.unsubscribe_callback()
# Clear debug data if it exists # Clear debug data if it exists
debug_info.remove_subscription( debug_info.remove_subscription(
hass, remaining.message_callback, str(remaining.topic) hass,
remaining.message_callback,
str(remaining.topic),
remaining.entity_id,
) )
return new_state return new_state