From b504665b56dbbf19c616d0f21c32b1b1d5a4b1d9 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Fri, 21 Jul 2023 06:35:58 +0200 Subject: [PATCH] Do not override extra_state_attributes property for MqttEntity (#96890) --- homeassistant/components/mqtt/mixins.py | 10 +-------- homeassistant/components/mqtt/siren.py | 30 +++++++++++++++---------- tests/components/mqtt/test_common.py | 2 +- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 314800f33f2..57ec933cd58 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -342,7 +342,6 @@ class MqttAttributes(Entity): def __init__(self, config: ConfigType) -> None: """Initialize the JSON attributes mixin.""" - self._attributes: dict[str, Any] | None = None self._attributes_sub_state: dict[str, EntitySubscription] = {} self._attributes_config = config @@ -380,16 +379,14 @@ class MqttAttributes(Entity): if k not in MQTT_ATTRIBUTES_BLOCKED and k not in self._attributes_extra_blocked } - self._attributes = filtered_dict + self._attr_extra_state_attributes = filtered_dict get_mqtt_data(self.hass).state_write_requests.write_state_request( self ) else: _LOGGER.warning("JSON result was not a dictionary") - self._attributes = None except ValueError: _LOGGER.warning("Erroneous JSON: %s", payload) - self._attributes = None self._attributes_sub_state = async_prepare_subscribe_topics( self.hass, @@ -414,11 +411,6 @@ class MqttAttributes(Entity): self.hass, self._attributes_sub_state ) - @property - def extra_state_attributes(self) -> dict[str, Any] | None: - """Return the state attributes.""" - return self._attributes - class MqttAvailability(Entity): """Mixin used for platforms that report availability.""" diff --git a/homeassistant/components/mqtt/siren.py b/homeassistant/components/mqtt/siren.py index 4134dd97148..d30080f4647 100644 --- a/homeassistant/components/mqtt/siren.py +++ b/homeassistant/components/mqtt/siren.py @@ -2,7 +2,6 @@ from __future__ import annotations from collections.abc import Callable -import copy import functools import logging from typing import Any, cast @@ -141,6 +140,7 @@ class MqttSiren(MqttEntity, SirenEntity): _entity_id_format = ENTITY_ID_FORMAT _attributes_extra_blocked = MQTT_SIREN_ATTRIBUTES_BLOCKED + _extra_attributes: dict[str, Any] _command_templates: dict[ str, Callable[[PublishPayloadType, TemplateVarsType], PublishPayloadType] | None @@ -158,6 +158,7 @@ class MqttSiren(MqttEntity, SirenEntity): discovery_data: DiscoveryInfoType | None, ) -> None: """Initialize the MQTT siren.""" + self._extra_attributes: dict[str, Any] = {} MqttEntity.__init__(self, hass, config, config_entry, discovery_data) @staticmethod @@ -174,21 +175,21 @@ class MqttSiren(MqttEntity, SirenEntity): state_off: str | None = config.get(CONF_STATE_OFF) self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF] - self._attr_extra_state_attributes = {} + self._extra_attributes = {} _supported_features = SUPPORTED_BASE if config[CONF_SUPPORT_DURATION]: _supported_features |= SirenEntityFeature.DURATION - self._attr_extra_state_attributes[ATTR_DURATION] = None + self._extra_attributes[ATTR_DURATION] = None if config.get(CONF_AVAILABLE_TONES): _supported_features |= SirenEntityFeature.TONES self._attr_available_tones = config[CONF_AVAILABLE_TONES] - self._attr_extra_state_attributes[ATTR_TONE] = None + self._extra_attributes[ATTR_TONE] = None if config[CONF_SUPPORT_VOLUME_SET]: _supported_features |= SirenEntityFeature.VOLUME_SET - self._attr_extra_state_attributes[ATTR_VOLUME_LEVEL] = None + self._extra_attributes[ATTR_VOLUME_LEVEL] = None self._attr_supported_features = _supported_features self._optimistic = config[CONF_OPTIMISTIC] or CONF_STATE_TOPIC not in config @@ -305,14 +306,19 @@ class MqttSiren(MqttEntity, SirenEntity): return self._optimistic @property - def extra_state_attributes(self) -> dict[str, Any]: + def extra_state_attributes(self) -> dict[str, Any] | None: """Return the state attributes.""" - mqtt_attributes = super().extra_state_attributes - attributes = ( - copy.deepcopy(mqtt_attributes) if mqtt_attributes is not None else {} + extra_attributes = ( + self._attr_extra_state_attributes + if hasattr(self, "_attr_extra_state_attributes") + else {} ) - attributes.update(self._attr_extra_state_attributes) - return attributes + if extra_attributes: + return ( + dict({*self._extra_attributes.items(), *extra_attributes.items()}) + or None + ) + return self._extra_attributes or None async def _async_publish( self, @@ -376,6 +382,6 @@ class MqttSiren(MqttEntity, SirenEntity): """Update the extra siren state attributes.""" for attribute, support in SUPPORTED_ATTRIBUTES.items(): if self._attr_supported_features & support and attribute in data: - self._attr_extra_state_attributes[attribute] = data[ + self._extra_attributes[attribute] = data[ attribute # type: ignore[literal-required] ] diff --git a/tests/components/mqtt/test_common.py b/tests/components/mqtt/test_common.py index cfd714725c4..fd760044f3c 100644 --- a/tests/components/mqtt/test_common.py +++ b/tests/components/mqtt/test_common.py @@ -682,7 +682,7 @@ async def help_test_discovery_update_attr( # Verify we are no longer subscribing to the old topic async_fire_mqtt_message(hass, "attr-topic1", '{ "val": "50" }') state = hass.states.get(f"{domain}.test") - assert state and state.attributes.get("val") == "100" + assert state and state.attributes.get("val") != "50" # Verify we are subscribing to the new topic async_fire_mqtt_message(hass, "attr-topic2", '{ "val": "75" }')