diff --git a/homeassistant/components/mqtt/update.py b/homeassistant/components/mqtt/update.py index f6db0d3fd64..cf3237c1b1c 100644 --- a/homeassistant/components/mqtt/update.py +++ b/homeassistant/components/mqtt/update.py @@ -33,9 +33,14 @@ from .const import ( PAYLOAD_EMPTY_JSON, ) from .debug_info import log_messages -from .mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, async_setup_entry_helper +from .mixins import ( + MQTT_ENTITY_COMMON_SCHEMA, + MqttEntity, + async_setup_entry_helper, + write_state_on_attr_change, +) from .models import MessageCallbackType, MqttValueTemplate, ReceiveMessage -from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic +from .util import valid_publish_topic, valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -171,6 +176,17 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change( + self, + { + "_attr_installed_version", + "_attr_latest_version", + "_attr_title", + "_attr_release_summary", + "_attr_release_url", + "_entity_picture", + }, + ) def handle_state_message_received(msg: ReceiveMessage) -> None: """Handle receiving state message via MQTT.""" payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload) @@ -219,39 +235,33 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): if "installed_version" in json_payload: self._attr_installed_version = json_payload["installed_version"] - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) if "latest_version" in json_payload: self._attr_latest_version = json_payload["latest_version"] - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) if "title" in json_payload: self._attr_title = json_payload["title"] - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) if "release_summary" in json_payload: self._attr_release_summary = json_payload["release_summary"] - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) if "release_url" in json_payload: self._attr_release_url = json_payload["release_url"] - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) if "entity_picture" in json_payload: self._entity_picture = json_payload["entity_picture"] - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received) @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change(self, {"_attr_latest_version"}) def handle_latest_version_received(msg: ReceiveMessage) -> None: """Handle receiving latest version via MQTT.""" latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload) if isinstance(latest_version, str) and latest_version != "": self._attr_latest_version = latest_version - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) add_subscription( topics, CONF_LATEST_VERSION_TOPIC, handle_latest_version_received @@ -279,8 +289,6 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): self._config[CONF_ENCODING], ) - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) - @property def supported_features(self) -> UpdateEntityFeature: """Return the list of supported features.""" diff --git a/tests/components/mqtt/test_update.py b/tests/components/mqtt/test_update.py index 9c881352f8c..c5fe5abd8c4 100644 --- a/tests/components/mqtt/test_update.py +++ b/tests/components/mqtt/test_update.py @@ -16,6 +16,7 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant from .test_common import ( + help_custom_config, help_test_availability_when_connection_lost, help_test_availability_without_topic, help_test_custom_availability_payload, @@ -33,6 +34,7 @@ from .test_common import ( help_test_reloadable, help_test_setting_attribute_via_mqtt_json_message, help_test_setting_attribute_with_template, + help_test_skipped_async_ha_write_state, help_test_unique_id, help_test_unload_config_entry_with_platform, help_test_update_with_json_attrs_bad_json, @@ -47,7 +49,7 @@ DEFAULT_CONFIG = { update.DOMAIN: { "name": "test", "state_topic": "test-topic", - "latest_version_topic": "test-topic", + "latest_version_topic": "latest-version-topic", "command_topic": "test-topic", "payload_install": "install", } @@ -730,3 +732,53 @@ async def test_reloadable( domain = update.DOMAIN config = DEFAULT_CONFIG await help_test_reloadable(hass, mqtt_client_mock, domain, config) + + +@pytest.mark.parametrize( + "hass_config", + [ + help_custom_config( + update.DOMAIN, + DEFAULT_CONFIG, + ( + { + "availability_topic": "availability-topic", + "json_attributes_topic": "json-attributes-topic", + }, + ), + ) + ], +) +@pytest.mark.parametrize( + ("topic", "payload1", "payload2"), + [ + ("latest-version-topic", "1.1", "1.2"), + ("test-topic", "1.1", "1.2"), + ("test-topic", '{"installed_version": "1.1"}', '{"installed_version": "1.2"}'), + ("test-topic", '{"latest_version": "1.1"}', '{"latest_version": "1.2"}'), + ("test-topic", '{"title": "Update"}', '{"title": "Patch"}'), + ("test-topic", '{"release_summary": "bla1"}', '{"release_summary": "bla2"}'), + ( + "test-topic", + '{"release_url": "https://example.com/update?r=1"}', + '{"release_url": "https://example.com/update?r=2"}', + ), + ( + "test-topic", + '{"entity_picture": "https://example.com/icon1.png"}', + '{"entity_picture": "https://example.com/icon2.png"}', + ), + ("availability-topic", "online", "offline"), + ("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'), + ], +) +async def test_skipped_async_ha_write_state( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + topic: str, + payload1: str, + payload2: str, +) -> None: + """Test a write state command is only called when there is change.""" + await mqtt_mock_entry() + await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)