Do not override extra_state_attributes property for MqttEntity (#96890)

This commit is contained in:
Jan Bouwhuis 2023-07-21 06:35:58 +02:00 committed by GitHub
parent c067c52cf4
commit b504665b56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 22 deletions

View File

@ -342,7 +342,6 @@ class MqttAttributes(Entity):
def __init__(self, config: ConfigType) -> None: def __init__(self, config: ConfigType) -> None:
"""Initialize the JSON attributes mixin.""" """Initialize the JSON attributes mixin."""
self._attributes: dict[str, Any] | None = None
self._attributes_sub_state: dict[str, EntitySubscription] = {} self._attributes_sub_state: dict[str, EntitySubscription] = {}
self._attributes_config = config self._attributes_config = config
@ -380,16 +379,14 @@ class MqttAttributes(Entity):
if k not in MQTT_ATTRIBUTES_BLOCKED if k not in MQTT_ATTRIBUTES_BLOCKED
and k not in self._attributes_extra_blocked and k not in self._attributes_extra_blocked
} }
self._attributes = filtered_dict self._attr_extra_state_attributes = filtered_dict
get_mqtt_data(self.hass).state_write_requests.write_state_request( get_mqtt_data(self.hass).state_write_requests.write_state_request(
self self
) )
else: else:
_LOGGER.warning("JSON result was not a dictionary") _LOGGER.warning("JSON result was not a dictionary")
self._attributes = None
except ValueError: except ValueError:
_LOGGER.warning("Erroneous JSON: %s", payload) _LOGGER.warning("Erroneous JSON: %s", payload)
self._attributes = None
self._attributes_sub_state = async_prepare_subscribe_topics( self._attributes_sub_state = async_prepare_subscribe_topics(
self.hass, self.hass,
@ -414,11 +411,6 @@ class MqttAttributes(Entity):
self.hass, self._attributes_sub_state self.hass, self._attributes_sub_state
) )
@property
def extra_state_attributes(self) -> dict[str, Any] | None:
"""Return the state attributes."""
return self._attributes
class MqttAvailability(Entity): class MqttAvailability(Entity):
"""Mixin used for platforms that report availability.""" """Mixin used for platforms that report availability."""

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
import copy
import functools import functools
import logging import logging
from typing import Any, cast from typing import Any, cast
@ -141,6 +140,7 @@ class MqttSiren(MqttEntity, SirenEntity):
_entity_id_format = ENTITY_ID_FORMAT _entity_id_format = ENTITY_ID_FORMAT
_attributes_extra_blocked = MQTT_SIREN_ATTRIBUTES_BLOCKED _attributes_extra_blocked = MQTT_SIREN_ATTRIBUTES_BLOCKED
_extra_attributes: dict[str, Any]
_command_templates: dict[ _command_templates: dict[
str, Callable[[PublishPayloadType, TemplateVarsType], PublishPayloadType] | None str, Callable[[PublishPayloadType, TemplateVarsType], PublishPayloadType] | None
@ -158,6 +158,7 @@ class MqttSiren(MqttEntity, SirenEntity):
discovery_data: DiscoveryInfoType | None, discovery_data: DiscoveryInfoType | None,
) -> None: ) -> None:
"""Initialize the MQTT siren.""" """Initialize the MQTT siren."""
self._extra_attributes: dict[str, Any] = {}
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
@ -174,21 +175,21 @@ class MqttSiren(MqttEntity, SirenEntity):
state_off: str | None = config.get(CONF_STATE_OFF) state_off: str | None = config.get(CONF_STATE_OFF)
self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF] self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF]
self._attr_extra_state_attributes = {} self._extra_attributes = {}
_supported_features = SUPPORTED_BASE _supported_features = SUPPORTED_BASE
if config[CONF_SUPPORT_DURATION]: if config[CONF_SUPPORT_DURATION]:
_supported_features |= SirenEntityFeature.DURATION _supported_features |= SirenEntityFeature.DURATION
self._attr_extra_state_attributes[ATTR_DURATION] = None self._extra_attributes[ATTR_DURATION] = None
if config.get(CONF_AVAILABLE_TONES): if config.get(CONF_AVAILABLE_TONES):
_supported_features |= SirenEntityFeature.TONES _supported_features |= SirenEntityFeature.TONES
self._attr_available_tones = config[CONF_AVAILABLE_TONES] self._attr_available_tones = config[CONF_AVAILABLE_TONES]
self._attr_extra_state_attributes[ATTR_TONE] = None self._extra_attributes[ATTR_TONE] = None
if config[CONF_SUPPORT_VOLUME_SET]: if config[CONF_SUPPORT_VOLUME_SET]:
_supported_features |= SirenEntityFeature.VOLUME_SET _supported_features |= SirenEntityFeature.VOLUME_SET
self._attr_extra_state_attributes[ATTR_VOLUME_LEVEL] = None self._extra_attributes[ATTR_VOLUME_LEVEL] = None
self._attr_supported_features = _supported_features self._attr_supported_features = _supported_features
self._optimistic = config[CONF_OPTIMISTIC] or CONF_STATE_TOPIC not in config self._optimistic = config[CONF_OPTIMISTIC] or CONF_STATE_TOPIC not in config
@ -305,14 +306,19 @@ class MqttSiren(MqttEntity, SirenEntity):
return self._optimistic return self._optimistic
@property @property
def extra_state_attributes(self) -> dict[str, Any]: def extra_state_attributes(self) -> dict[str, Any] | None:
"""Return the state attributes.""" """Return the state attributes."""
mqtt_attributes = super().extra_state_attributes extra_attributes = (
attributes = ( self._attr_extra_state_attributes
copy.deepcopy(mqtt_attributes) if mqtt_attributes is not None else {} if hasattr(self, "_attr_extra_state_attributes")
else {}
) )
attributes.update(self._attr_extra_state_attributes) if extra_attributes:
return attributes return (
dict({*self._extra_attributes.items(), *extra_attributes.items()})
or None
)
return self._extra_attributes or None
async def _async_publish( async def _async_publish(
self, self,
@ -376,6 +382,6 @@ class MqttSiren(MqttEntity, SirenEntity):
"""Update the extra siren state attributes.""" """Update the extra siren state attributes."""
for attribute, support in SUPPORTED_ATTRIBUTES.items(): for attribute, support in SUPPORTED_ATTRIBUTES.items():
if self._attr_supported_features & support and attribute in data: if self._attr_supported_features & support and attribute in data:
self._attr_extra_state_attributes[attribute] = data[ self._extra_attributes[attribute] = data[
attribute # type: ignore[literal-required] attribute # type: ignore[literal-required]
] ]

View File

@ -682,7 +682,7 @@ async def help_test_discovery_update_attr(
# Verify we are no longer subscribing to the old topic # Verify we are no longer subscribing to the old topic
async_fire_mqtt_message(hass, "attr-topic1", '{ "val": "50" }') async_fire_mqtt_message(hass, "attr-topic1", '{ "val": "50" }')
state = hass.states.get(f"{domain}.test") state = hass.states.get(f"{domain}.test")
assert state and state.attributes.get("val") == "100" assert state and state.attributes.get("val") != "50"
# Verify we are subscribing to the new topic # Verify we are subscribing to the new topic
async_fire_mqtt_message(hass, "attr-topic2", '{ "val": "75" }') async_fire_mqtt_message(hass, "attr-topic2", '{ "val": "75" }')