From 9333965b23e8367bccbc006b117d75c118dd4882 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Fri, 24 May 2024 11:18:25 +0200 Subject: [PATCH] Create bound callback_message_received method for handling mqtt callbacks (#117951) * Create bound callback_message_received method for handling mqtt callbacks * refactor a bit * fix ruff * reduce overhead * cleanup * cleanup * Revert changes alarm_control_panel * Add sensor and binary sensor * use same pattern for MqttAttributes/MqttAvailability * remove useless function since we did not need to add to it * code cleanup * collapse --------- Co-authored-by: J. Nick Koston --- .../components/mqtt/binary_sensor.py | 166 +++++++++--------- homeassistant/components/mqtt/debug_info.py | 10 +- homeassistant/components/mqtt/mixins.py | 132 +++++++++----- homeassistant/components/mqtt/sensor.py | 39 ++-- homeassistant/components/mqtt/subscription.py | 15 +- 5 files changed, 210 insertions(+), 152 deletions(-) diff --git a/homeassistant/components/mqtt/binary_sensor.py b/homeassistant/components/mqtt/binary_sensor.py index cfc130377eb..68f0ab10a45 100644 --- a/homeassistant/components/mqtt/binary_sensor.py +++ b/homeassistant/components/mqtt/binary_sensor.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import datetime, timedelta +from functools import partial import logging from typing import Any @@ -37,13 +38,7 @@ from homeassistant.util import dt as dt_util from . import subscription from .config import MQTT_RO_SCHEMA from .const import CONF_ENCODING, CONF_QOS, CONF_STATE_TOPIC, PAYLOAD_NONE -from .debug_info import log_messages -from .mixins import ( - MqttAvailability, - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttAvailability, MqttEntity, async_setup_entity_entry_helper from .models import MqttValueTemplate, ReceiveMessage from .schemas import MQTT_ENTITY_COMMON_SCHEMA @@ -162,92 +157,95 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity): entity=self, ).async_render_with_possible_json_value + @callback + def _off_delay_listener(self, now: datetime) -> None: + """Switch device off after a delay.""" + self._delay_listener = None + self._attr_is_on = False + self.async_write_ha_state() + + def _state_message_received(self, msg: ReceiveMessage) -> None: + """Handle a new received MQTT state message.""" + + # auto-expire enabled? + if self._expire_after: + # When expire_after is set, and we receive a message, assume device is + # not expired since it has to be to receive the message + self._expired = False + + # Reset old trigger + if self._expiration_trigger: + self._expiration_trigger() + + # Set new trigger + self._expiration_trigger = async_call_later( + self.hass, self._expire_after, self._value_is_expired + ) + + payload = self._value_template(msg.payload) + if not payload.strip(): # No output from template, ignore + _LOGGER.debug( + ( + "Empty template output for entity: %s with state topic: %s." + " Payload: '%s', with value template '%s'" + ), + self.entity_id, + self._config[CONF_STATE_TOPIC], + msg.payload, + self._config.get(CONF_VALUE_TEMPLATE), + ) + return + + if payload == self._config[CONF_PAYLOAD_ON]: + self._attr_is_on = True + elif payload == self._config[CONF_PAYLOAD_OFF]: + self._attr_is_on = False + elif payload == PAYLOAD_NONE: + self._attr_is_on = None + else: # Payload is not for this entity + template_info = "" + if self._config.get(CONF_VALUE_TEMPLATE) is not None: + template_info = ( + f", template output: '{payload!s}', with value template" + f" '{self._config.get(CONF_VALUE_TEMPLATE)!s}'" + ) + _LOGGER.info( + ( + "No matching payload found for entity: %s with state topic: %s." + " Payload: '%s'%s" + ), + self.entity_id, + self._config[CONF_STATE_TOPIC], + msg.payload, + template_info, + ) + return + + if self._delay_listener is not None: + self._delay_listener() + self._delay_listener = None + + off_delay: int | None = self._config.get(CONF_OFF_DELAY) + if self._attr_is_on and off_delay is not None: + self._delay_listener = evt.async_call_later( + self.hass, off_delay, self._off_delay_listener + ) + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - @callback - def off_delay_listener(now: datetime) -> None: - """Switch device off after a delay.""" - self._delay_listener = None - self._attr_is_on = False - self.async_write_ha_state() - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_is_on", "_expired"}) - def state_message_received(msg: ReceiveMessage) -> None: - """Handle a new received MQTT state message.""" - # auto-expire enabled? - if self._expire_after: - # When expire_after is set, and we receive a message, assume device is - # not expired since it has to be to receive the message - self._expired = False - - # Reset old trigger - if self._expiration_trigger: - self._expiration_trigger() - - # Set new trigger - self._expiration_trigger = async_call_later( - self.hass, self._expire_after, self._value_is_expired - ) - - payload = self._value_template(msg.payload) - if not payload.strip(): # No output from template, ignore - _LOGGER.debug( - ( - "Empty template output for entity: %s with state topic: %s." - " Payload: '%s', with value template '%s'" - ), - self.entity_id, - self._config[CONF_STATE_TOPIC], - msg.payload, - self._config.get(CONF_VALUE_TEMPLATE), - ) - return - - if payload == self._config[CONF_PAYLOAD_ON]: - self._attr_is_on = True - elif payload == self._config[CONF_PAYLOAD_OFF]: - self._attr_is_on = False - elif payload == PAYLOAD_NONE: - self._attr_is_on = None - else: # Payload is not for this entity - template_info = "" - if self._config.get(CONF_VALUE_TEMPLATE) is not None: - template_info = ( - f", template output: '{payload!s}', with value template" - f" '{self._config.get(CONF_VALUE_TEMPLATE)!s}'" - ) - _LOGGER.info( - ( - "No matching payload found for entity: %s with state topic: %s." - " Payload: '%s'%s" - ), - self.entity_id, - self._config[CONF_STATE_TOPIC], - msg.payload, - template_info, - ) - return - - if self._delay_listener is not None: - self._delay_listener() - self._delay_listener = None - - off_delay: int | None = self._config.get(CONF_OFF_DELAY) - if self._attr_is_on and off_delay is not None: - self._delay_listener = evt.async_call_later( - self.hass, off_delay, off_delay_listener - ) - self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { "state_topic": { "topic": self._config[CONF_STATE_TOPIC], - "msg_callback": state_message_received, + "msg_callback": partial( + self._message_callback, + self._state_message_received, + {"_attr_is_on", "_expired"}, + ), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], "encoding": self._config[CONF_ENCODING] or None, } diff --git a/homeassistant/components/mqtt/debug_info.py b/homeassistant/components/mqtt/debug_info.py index bc1eddeef97..72bf1596164 100644 --- a/homeassistant/components/mqtt/debug_info.py +++ b/homeassistant/components/mqtt/debug_info.py @@ -86,9 +86,12 @@ def add_subscription( hass: HomeAssistant, message_callback: MessageCallbackType, subscription: str, + entity_id: str | None = None, ) -> None: """Prepare debug data for subscription.""" - if entity_id := getattr(message_callback, "__entity_id", None): + if not entity_id: + entity_id = getattr(message_callback, "__entity_id", None) + if entity_id: entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault( entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}} ) @@ -104,9 +107,12 @@ def remove_subscription( hass: HomeAssistant, message_callback: MessageCallbackType, subscription: str, + entity_id: str | None = None, ) -> None: """Remove debug data for subscription if it exists.""" - if (entity_id := getattr(message_callback, "__entity_id", None)) and entity_id in ( + if not entity_id: + entity_id = getattr(message_callback, "__entity_id", None) + if entity_id and entity_id in ( debug_info_entities := hass.data[DATA_MQTT].debug_info_entities ): debug_info_entities[entity_id]["subscriptions"][subscription]["count"] -= 1 diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 56bbc7b19eb..bc70c07a3fe 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -48,6 +48,7 @@ from homeassistant.helpers.event import ( async_track_entity_registry_updated_event, ) from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue +from homeassistant.helpers.service_info.mqtt import ReceivePayloadType from homeassistant.helpers.typing import ( UNDEFINED, ConfigType, @@ -93,7 +94,7 @@ from .const import ( MQTT_CONNECTED, MQTT_DISCONNECTED, ) -from .debug_info import log_message, log_messages +from .debug_info import log_message from .discovery import ( MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_NEW, @@ -401,6 +402,7 @@ class MqttAttributes(Entity): """Mixin used for platforms that support JSON attributes.""" _attributes_extra_blocked: frozenset[str] = frozenset() + _attr_tpl: Callable[[ReceivePayloadType], ReceivePayloadType] | None = None def __init__(self, config: ConfigType) -> None: """Initialize the JSON attributes mixin.""" @@ -424,38 +426,21 @@ class MqttAttributes(Entity): def _attributes_prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - attr_tpl = MqttValueTemplate( + self._attr_tpl = MqttValueTemplate( self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE), entity=self ).async_render_with_possible_json_value - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_extra_state_attributes"}) - def attributes_message_received(msg: ReceiveMessage) -> None: - """Update extra state attributes.""" - payload = attr_tpl(msg.payload) - try: - json_dict = json_loads(payload) if isinstance(payload, str) else None - if isinstance(json_dict, dict): - filtered_dict = { - k: v - for k, v in json_dict.items() - if k not in MQTT_ATTRIBUTES_BLOCKED - and k not in self._attributes_extra_blocked - } - self._attr_extra_state_attributes = filtered_dict - else: - _LOGGER.warning("JSON result was not a dictionary") - except ValueError: - _LOGGER.warning("Erroneous JSON: %s", payload) - self._attributes_sub_state = async_prepare_subscribe_topics( self.hass, self._attributes_sub_state, { CONF_JSON_ATTRS_TOPIC: { "topic": self._attributes_config.get(CONF_JSON_ATTRS_TOPIC), - "msg_callback": attributes_message_received, + "msg_callback": partial( + self._message_callback, # type: ignore[attr-defined] + self._attributes_message_received, + {"_attr_extra_state_attributes"}, + ), + "entity_id": self.entity_id, "qos": self._attributes_config.get(CONF_QOS), "encoding": self._attributes_config[CONF_ENCODING] or None, } @@ -472,6 +457,28 @@ class MqttAttributes(Entity): self.hass, self._attributes_sub_state ) + @callback + def _attributes_message_received(self, msg: ReceiveMessage) -> None: + """Update extra state attributes.""" + if TYPE_CHECKING: + assert self._attr_tpl is not None + payload = self._attr_tpl(msg.payload) + try: + json_dict = json_loads(payload) if isinstance(payload, str) else None + except ValueError: + _LOGGER.warning("Erroneous JSON: %s", payload) + else: + if isinstance(json_dict, dict): + filtered_dict = { + k: v + for k, v in json_dict.items() + if k not in MQTT_ATTRIBUTES_BLOCKED + and k not in self._attributes_extra_blocked + } + self._attr_extra_state_attributes = filtered_dict + else: + _LOGGER.warning("JSON result was not a dictionary") + class MqttAvailability(Entity): """Mixin used for platforms that report availability.""" @@ -535,28 +542,18 @@ class MqttAvailability(Entity): def _availability_prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"available"}) - def availability_message_received(msg: ReceiveMessage) -> None: - """Handle a new received MQTT availability message.""" - topic = msg.topic - payload = self._avail_topics[topic][CONF_AVAILABILITY_TEMPLATE](msg.payload) - if payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]: - self._available[topic] = True - self._available_latest = True - elif payload == self._avail_topics[topic][CONF_PAYLOAD_NOT_AVAILABLE]: - self._available[topic] = False - self._available_latest = False - self._available = { topic: (self._available.get(topic, False)) for topic in self._avail_topics } topics: dict[str, dict[str, Any]] = { f"availability_{topic}": { "topic": topic, - "msg_callback": availability_message_received, + "msg_callback": partial( + self._message_callback, # type: ignore[attr-defined] + self._availability_message_received, + {"available"}, + ), + "entity_id": self.entity_id, "qos": self._avail_config[CONF_QOS], "encoding": self._avail_config[CONF_ENCODING] or None, } @@ -569,6 +566,19 @@ class MqttAvailability(Entity): topics, ) + @callback + def _availability_message_received(self, msg: ReceiveMessage) -> None: + """Handle a new received MQTT availability message.""" + topic = msg.topic + avail_topic = self._avail_topics[topic] + payload = avail_topic[CONF_AVAILABILITY_TEMPLATE](msg.payload) + if payload == avail_topic[CONF_PAYLOAD_AVAILABLE]: + self._available[topic] = True + self._available_latest = True + elif payload == avail_topic[CONF_PAYLOAD_NOT_AVAILABLE]: + self._available[topic] = False + self._available_latest = False + async def _availability_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await async_subscribe_topics(self.hass, self._availability_sub_state) @@ -1073,6 +1083,7 @@ class MqttEntity( ): """Representation of an MQTT entity.""" + _attr_force_update = False _attr_has_entity_name = True _attr_should_poll = False _default_name: str | None @@ -1225,6 +1236,45 @@ class MqttEntity( async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" + @callback + def _attrs_have_changed( + self, attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] + ) -> bool: + """Return True if attributes on entity changed or if update is forced.""" + if self._attr_force_update: + return True + for attribute, last_value in attrs_snapshot: + if getattr(self, attribute, UNDEFINED) != last_value: + return True + return False + + @callback + def _message_callback( + self, + msg_callback: MessageCallbackType, + attributes: set[str], + msg: ReceiveMessage, + ) -> None: + """Process the message callback.""" + attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple( + (attribute, getattr(self, attribute, UNDEFINED)) for attribute in attributes + ) + mqtt_data = self.hass.data[DATA_MQTT] + messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][ + msg.subscribed_topic + ]["messages"] + if msg not in messages: + messages.append(msg) + + try: + msg_callback(msg) + except MqttValueTemplateException as exc: + _LOGGER.warning(exc) + return + + if self._attrs_have_changed(attrs_snapshot): + mqtt_data.state_write_requests.write_state_request(self) + def update_device( hass: HomeAssistant, diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index 744d7e0fdc9..cc0e8c92011 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -4,6 +4,7 @@ from __future__ import annotations from collections.abc import Callable from datetime import datetime, timedelta +from functools import partial import logging from typing import Any @@ -40,13 +41,7 @@ from homeassistant.util import dt as dt_util from . import subscription from .config import MQTT_RO_SCHEMA from .const import CONF_ENCODING, CONF_QOS, CONF_STATE_TOPIC, PAYLOAD_NONE -from .debug_info import log_messages -from .mixins import ( - MqttAvailability, - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttAvailability, MqttEntity, async_setup_entity_entry_helper from .models import ( MqttValueTemplate, PayloadSentinel, @@ -215,9 +210,9 @@ class MqttSensor(MqttEntity, RestoreSensor): self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value - def _prepare_subscribe_topics(self) -> None: - """(Re)Subscribe to topics.""" - topics: dict[str, dict[str, Any]] = {} + @callback + def _state_message_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT state messages.""" def _update_state(msg: ReceiveMessage) -> None: # auto-expire enabled? @@ -280,20 +275,22 @@ class MqttSensor(MqttEntity, RestoreSensor): "Invalid last_reset message '%s' from '%s'", msg.payload, msg.topic ) - @callback - @write_state_on_attr_change( - self, {"_attr_native_value", "_attr_last_reset", "_expired"} - ) - @log_messages(self.hass, self.entity_id) - def message_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - _update_state(msg) - if CONF_LAST_RESET_VALUE_TEMPLATE in self._config: - _update_last_reset(msg) + _update_state(msg) + if CONF_LAST_RESET_VALUE_TEMPLATE in self._config: + _update_last_reset(msg) + + def _prepare_subscribe_topics(self) -> None: + """(Re)Subscribe to topics.""" + topics: dict[str, dict[str, Any]] = {} topics["state_topic"] = { "topic": self._config[CONF_STATE_TOPIC], - "msg_callback": message_received, + "msg_callback": partial( + self._message_callback, + self._state_message_received, + {"_attr_native_value", "_attr_last_reset", "_expired"}, + ), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], "encoding": self._config[CONF_ENCODING] or None, } diff --git a/homeassistant/components/mqtt/subscription.py b/homeassistant/components/mqtt/subscription.py index 14f2999fa9c..6a8b019aee1 100644 --- a/homeassistant/components/mqtt/subscription.py +++ b/homeassistant/components/mqtt/subscription.py @@ -26,6 +26,7 @@ class EntitySubscription: unsubscribe_callback: Callable[[], None] | None = attr.ib() qos: int = attr.ib(default=0) encoding: str = attr.ib(default="utf-8") + entity_id: str | None = attr.ib(default=None) def resubscribe_if_necessary( self, hass: HomeAssistant, other: EntitySubscription | None @@ -41,7 +42,7 @@ class EntitySubscription: other.unsubscribe_callback() # Clear debug data if it exists debug_info.remove_subscription( - self.hass, other.message_callback, str(other.topic) + self.hass, other.message_callback, str(other.topic), other.entity_id ) if self.topic is None: @@ -49,7 +50,9 @@ class EntitySubscription: return # Prepare debug data - debug_info.add_subscription(self.hass, self.message_callback, self.topic) + debug_info.add_subscription( + self.hass, self.message_callback, self.topic, self.entity_id + ) self.subscribe_task = mqtt.async_subscribe( hass, self.topic, self.message_callback, self.qos, self.encoding @@ -80,7 +83,7 @@ class EntitySubscription: def async_prepare_subscribe_topics( hass: HomeAssistant, new_state: dict[str, EntitySubscription] | None, - topics: dict[str, Any], + topics: dict[str, dict[str, Any]], ) -> dict[str, EntitySubscription]: """Prepare (re)subscribe to a set of MQTT topics. @@ -106,6 +109,7 @@ def async_prepare_subscribe_topics( encoding=value.get("encoding", "utf-8"), hass=hass, subscribe_task=None, + entity_id=value.get("entity_id", None), ) # Get the current subscription state current = current_subscriptions.pop(key, None) @@ -118,7 +122,10 @@ def async_prepare_subscribe_topics( remaining.unsubscribe_callback() # Clear debug data if it exists debug_info.remove_subscription( - hass, remaining.message_callback, str(remaining.topic) + hass, + remaining.message_callback, + str(remaining.topic), + remaining.entity_id, ) return new_state