Rework mqtt callbacks for camera, image and event (#118109)

This commit is contained in:
Jan Bouwhuis 2024-05-25 23:23:45 +02:00 committed by GitHub
parent ae0c00218a
commit f21c0679b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 148 additions and 143 deletions

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from base64 import b64decode from base64 import b64decode
from functools import partial
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -20,7 +21,6 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import subscription from . import subscription
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
from .const import CONF_QOS, CONF_TOPIC from .const import CONF_QOS, CONF_TOPIC
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ReceiveMessage from .models import ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
@ -97,27 +97,31 @@ class MqttCamera(MqttEntity, Camera):
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
@callback
def _image_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
"state_topic": { "state_topic": {
"topic": self._config[CONF_TOPIC], "topic": self._config[CONF_TOPIC],
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._image_received,
None,
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": None, "encoding": None,
} }

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any from typing import Any
@ -31,7 +32,6 @@ from .const import (
PAYLOAD_EMPTY_JSON, PAYLOAD_EMPTY_JSON,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ( from .models import (
DATA_MQTT, DATA_MQTT,
@ -113,90 +113,91 @@ class MqttEvent(MqttEntity, EventEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
@callback
def _event_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if msg.retain:
_LOGGER.debug(
"Ignoring event trigger from replayed retained payload '%s' on topic %s",
msg.payload,
msg.topic,
)
return
event_attributes: dict[str, Any] = {}
event_type: str
try:
payload = self._template(msg.payload, PayloadSentinel.DEFAULT)
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
if (
not payload
or payload is PayloadSentinel.DEFAULT
or payload in (PAYLOAD_NONE, PAYLOAD_EMPTY_JSON)
):
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
try:
event_attributes = json_loads_object(payload)
event_type = str(event_attributes.pop(event.ATTR_EVENT_TYPE))
_LOGGER.debug(
(
"JSON event data detected after processing payload '%s' on"
" topic %s, type %s, attributes %s"
),
payload,
msg.topic,
event_type,
event_attributes,
)
except KeyError:
_LOGGER.warning(
("`event_type` missing in JSON event payload, " " '%s' on topic %s"),
payload,
msg.topic,
)
return
except JSON_DECODE_EXCEPTIONS:
_LOGGER.warning(
(
"No valid JSON event payload detected, "
"value after processing payload"
" '%s' on topic %s"
),
payload,
msg.topic,
)
return
try:
self._trigger_event(event_type, event_attributes)
except ValueError:
_LOGGER.warning(
"Invalid event type %s for %s received on topic %s, payload %s",
event_type,
self.entity_id,
msg.topic,
payload,
)
return
mqtt_data = self.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(self)
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {} topics: dict[str, dict[str, Any]] = {}
@callback
@log_messages(self.hass, self.entity_id)
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if msg.retain:
_LOGGER.debug(
"Ignoring event trigger from replayed retained payload '%s' on topic %s",
msg.payload,
msg.topic,
)
return
event_attributes: dict[str, Any] = {}
event_type: str
try:
payload = self._template(msg.payload, PayloadSentinel.DEFAULT)
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
if (
not payload
or payload is PayloadSentinel.DEFAULT
or payload in (PAYLOAD_NONE, PAYLOAD_EMPTY_JSON)
):
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
try:
event_attributes = json_loads_object(payload)
event_type = str(event_attributes.pop(event.ATTR_EVENT_TYPE))
_LOGGER.debug(
(
"JSON event data detected after processing payload '%s' on"
" topic %s, type %s, attributes %s"
),
payload,
msg.topic,
event_type,
event_attributes,
)
except KeyError:
_LOGGER.warning(
(
"`event_type` missing in JSON event payload, "
" '%s' on topic %s"
),
payload,
msg.topic,
)
return
except JSON_DECODE_EXCEPTIONS:
_LOGGER.warning(
(
"No valid JSON event payload detected, "
"value after processing payload"
" '%s' on topic %s"
),
payload,
msg.topic,
)
return
try:
self._trigger_event(event_type, event_attributes)
except ValueError:
_LOGGER.warning(
"Invalid event type %s for %s received on topic %s, payload %s",
event_type,
self.entity_id,
msg.topic,
payload,
)
return
mqtt_data = self.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(self)
topics["state_topic"] = { topics["state_topic"] = {
"topic": self._config[CONF_STATE_TOPIC], "topic": self._config[CONF_STATE_TOPIC],
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._event_received,
None,
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from base64 import b64decode from base64 import b64decode
import binascii import binascii
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@ -26,7 +27,6 @@ from homeassistant.util import dt as dt_util
from . import subscription from . import subscription
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
from .const import CONF_ENCODING, CONF_QOS from .const import CONF_ENCODING, CONF_QOS
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ( from .models import (
DATA_MQTT, DATA_MQTT,
@ -143,6 +143,45 @@ class MqttImage(MqttEntity, ImageEntity):
config.get(CONF_URL_TEMPLATE), entity=self config.get(CONF_URL_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
@callback
def _image_data_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
except (binascii.Error, ValueError, AssertionError) as err:
_LOGGER.error(
"Error processing image data received at topic %s: %s",
msg.topic,
err,
)
self._last_image = None
self._attr_image_last_updated = dt_util.utcnow()
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
@callback
def _image_from_url_request_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
url = cv.url(self._url_template(msg.payload))
self._attr_image_url = url
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
except vol.Invalid:
_LOGGER.error(
"Invalid image URL '%s' received at topic %s",
msg.payload,
msg.topic,
)
self._attr_image_last_updated = dt_util.utcnow()
self._cached_image = None
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@ -159,56 +198,15 @@ class MqttImage(MqttEntity, ImageEntity):
if has_topic := self._topic[topic] is not None: if has_topic := self._topic[topic] is not None:
topics[topic] = { topics[topic] = {
"topic": self._topic[topic], "topic": self._topic[topic],
"msg_callback": msg_callback, "msg_callback": partial(self._message_callback, msg_callback, None),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": encoding, "encoding": encoding,
} }
return has_topic return has_topic
@callback add_subscribe_topic(CONF_IMAGE_TOPIC, self._image_data_received)
@log_messages(self.hass, self.entity_id) add_subscribe_topic(CONF_URL_TOPIC, self._image_from_url_request_received)
def image_data_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
except (binascii.Error, ValueError, AssertionError) as err:
_LOGGER.error(
"Error processing image data received at topic %s: %s",
msg.topic,
err,
)
self._last_image = None
self._attr_image_last_updated = dt_util.utcnow()
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received)
@callback
@log_messages(self.hass, self.entity_id)
def image_from_url_request_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
url = cv.url(self._url_template(msg.payload))
self._attr_image_url = url
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
except vol.Invalid:
_LOGGER.error(
"Invalid image URL '%s' received at topic %s",
msg.payload,
msg.topic,
)
self._attr_image_last_updated = dt_util.utcnow()
self._cached_image = None
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics

View File

@ -1254,13 +1254,15 @@ class MqttEntity(
def _message_callback( def _message_callback(
self, self,
msg_callback: MessageCallbackType, msg_callback: MessageCallbackType,
attributes: set[str], attributes: set[str] | None,
msg: ReceiveMessage, msg: ReceiveMessage,
) -> None: ) -> None:
"""Process the message callback.""" """Process the message callback."""
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple( if attributes is not None:
(attribute, getattr(self, attribute, UNDEFINED)) for attribute in attributes attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
) (attribute, getattr(self, attribute, UNDEFINED))
for attribute in attributes
)
mqtt_data = self.hass.data[DATA_MQTT] mqtt_data = self.hass.data[DATA_MQTT]
messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][ messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][
msg.subscribed_topic msg.subscribed_topic
@ -1274,7 +1276,7 @@ class MqttEntity(
_LOGGER.warning(exc) _LOGGER.warning(exc)
return return
if self._attrs_have_changed(attrs_snapshot): if attributes is not None and self._attrs_have_changed(attrs_snapshot):
mqtt_data.state_write_requests.write_state_request(self) mqtt_data.state_write_requests.write_state_request(self)