diff --git a/homeassistant/components/mqtt/alarm_control_panel.py b/homeassistant/components/mqtt/alarm_control_panel.py index 63c4a79b96f..eaea908e358 100644 --- a/homeassistant/components/mqtt/alarm_control_panel.py +++ b/homeassistant/components/mqtt/alarm_control_panel.py @@ -173,7 +173,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity): self._config[CONF_COMMAND_TEMPLATE], entity=self ).async_render - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" @callback @@ -198,7 +198,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity): self._state = payload self.async_write_ha_state() - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -211,6 +211,10 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @property def state(self): """Return the state of the device.""" diff --git a/homeassistant/components/mqtt/binary_sensor.py b/homeassistant/components/mqtt/binary_sensor.py index c500d52dd70..b84ddaad404 100644 --- a/homeassistant/components/mqtt/binary_sensor.py +++ b/homeassistant/components/mqtt/binary_sensor.py @@ -164,7 +164,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity): entity=self, ).async_render_with_possible_json_value - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" @callback @@ -241,7 +241,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity): self.async_write_ha_state() - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -254,6 +254,10 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @callback def _value_is_expired(self, *_): """Triggered when value is expired.""" diff --git a/homeassistant/components/mqtt/button.py b/homeassistant/components/mqtt/button.py index 7143b65ed9e..993251606ed 100644 --- a/homeassistant/components/mqtt/button.py +++ b/homeassistant/components/mqtt/button.py @@ -95,6 +95,9 @@ class MqttButton(MqttEntity, ButtonEntity): config.get(CONF_COMMAND_TEMPLATE), entity=self ).async_render + def _prepare_subscribe_topics(self): + """(Re)Subscribe to topics.""" + async def _subscribe_topics(self): """(Re)Subscribe to topics.""" diff --git a/homeassistant/components/mqtt/camera.py b/homeassistant/components/mqtt/camera.py index 5c2b8258f01..1c060f7f32a 100644 --- a/homeassistant/components/mqtt/camera.py +++ b/homeassistant/components/mqtt/camera.py @@ -90,7 +90,7 @@ class MqttCamera(MqttEntity, Camera): """Return the config schema.""" return DISCOVERY_SCHEMA - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" @callback @@ -99,7 +99,7 @@ class MqttCamera(MqttEntity, Camera): """Handle new MQTT messages.""" self._last_image = msg.payload - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -112,6 +112,10 @@ class MqttCamera(MqttEntity, Camera): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + async def async_camera_image( self, width: int | None = None, height: int | None = None ) -> bytes | None: diff --git a/homeassistant/components/mqtt/climate.py b/homeassistant/components/mqtt/climate.py index 2e19a345bc3..02d4f267fe8 100644 --- a/homeassistant/components/mqtt/climate.py +++ b/homeassistant/components/mqtt/climate.py @@ -358,11 +358,6 @@ class MqttClimate(MqttEntity, ClimateEntity): """Return the config schema.""" return DISCOVERY_SCHEMA - async def async_added_to_hass(self): - """Handle being added to Home Assistant.""" - await super().async_added_to_hass() - await self._subscribe_topics() - def _setup_from_config(self, config): """(Re)Setup the entity.""" self._topic = {key: config.get(key) for key in TOPIC_KEYS} @@ -417,7 +412,7 @@ class MqttClimate(MqttEntity, ClimateEntity): self._command_templates = command_templates - async def _subscribe_topics(self): # noqa: C901 + def _prepare_subscribe_topics(self): # noqa: C901 """(Re)Subscribe to topics.""" topics = {} qos = self._config[CONF_QOS] @@ -615,10 +610,14 @@ class MqttClimate(MqttEntity, ClimateEntity): add_subscription(topics, CONF_HOLD_STATE_TOPIC, handle_hold_mode_received) - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @property def temperature_unit(self): """Return the unit of measurement.""" diff --git a/homeassistant/components/mqtt/cover.py b/homeassistant/components/mqtt/cover.py index 95ea6182bf1..dfb48fb89e2 100644 --- a/homeassistant/components/mqtt/cover.py +++ b/homeassistant/components/mqtt/cover.py @@ -335,7 +335,7 @@ class MqttCover(MqttEntity, CoverEntity): config_attributes=template_config_attributes, ).async_render_with_possible_json_value - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" topics = {} @@ -460,10 +460,14 @@ class MqttCover(MqttEntity, CoverEntity): "encoding": self._config[CONF_ENCODING] or None, } - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @property def assumed_state(self): """Return true if we do optimistic updates.""" diff --git a/homeassistant/components/mqtt/device_tracker/schema_discovery.py b/homeassistant/components/mqtt/device_tracker/schema_discovery.py index 3ee5f22be90..a7b597d0689 100644 --- a/homeassistant/components/mqtt/device_tracker/schema_discovery.py +++ b/homeassistant/components/mqtt/device_tracker/schema_discovery.py @@ -77,7 +77,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity): self._config.get(CONF_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" @callback @@ -94,7 +94,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity): self.async_write_ha_state() - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -106,6 +106,10 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @property def latitude(self): """Return latitude if provided in extra_state_attributes or None.""" diff --git a/homeassistant/components/mqtt/fan.py b/homeassistant/components/mqtt/fan.py index fb6d21c8538..f6d60ae8cbe 100644 --- a/homeassistant/components/mqtt/fan.py +++ b/homeassistant/components/mqtt/fan.py @@ -351,7 +351,7 @@ class MqttFan(MqttEntity, FanEntity): entity=self, ).async_render_with_possible_json_value - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" topics = {} @@ -479,10 +479,14 @@ class MqttFan(MqttEntity, FanEntity): } self._oscillation = False - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @property def assumed_state(self): """Return true if we do optimistic updates.""" diff --git a/homeassistant/components/mqtt/humidifier.py b/homeassistant/components/mqtt/humidifier.py index df1b7667ef7..b2c4ed4b916 100644 --- a/homeassistant/components/mqtt/humidifier.py +++ b/homeassistant/components/mqtt/humidifier.py @@ -267,7 +267,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): entity=self, ).async_render_with_possible_json_value - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" topics = {} @@ -373,10 +373,14 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): } self._mode = None - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @property def assumed_state(self): """Return true if we do optimistic updates.""" diff --git a/homeassistant/components/mqtt/light/schema_basic.py b/homeassistant/components/mqtt/light/schema_basic.py index 4f692c8063c..f164abe5297 100644 --- a/homeassistant/components/mqtt/light/schema_basic.py +++ b/homeassistant/components/mqtt/light/schema_basic.py @@ -417,12 +417,10 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): """Return True if the attribute is optimistically updated.""" return getattr(self, f"_optimistic_{attribute}") - async def _subscribe_topics(self): # noqa: C901 + def _prepare_subscribe_topics(self): # noqa: C901 """(Re)Subscribe to topics.""" topics = {} - last_state = await self.async_get_last_state() - def add_topic(topic, msg_callback): """Add a topic.""" if self._topic[topic] is not None: @@ -433,14 +431,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): "encoding": self._config[CONF_ENCODING] or None, } - def restore_state(attribute, condition_attribute=None): - """Restore a state attribute.""" - if condition_attribute is None: - condition_attribute = attribute - optimistic = self._is_optimistic(condition_attribute) - if optimistic and last_state and last_state.attributes.get(attribute): - setattr(self, f"_{attribute}", last_state.attributes[attribute]) - @callback @log_messages(self.hass, self.entity_id) def state_received(msg): @@ -465,8 +455,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): "qos": self._config[CONF_QOS], "encoding": self._config[CONF_ENCODING] or None, } - elif self._optimistic and last_state: - self._state = last_state.state == STATE_ON @callback @log_messages(self.hass, self.entity_id) @@ -485,7 +473,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): self.async_write_ha_state() add_topic(CONF_BRIGHTNESS_STATE_TOPIC, brightness_received) - restore_state(ATTR_BRIGHTNESS) def _rgbx_received(msg, template, color_mode, convert_color): """Handle new MQTT messages for RGBW and RGBWW.""" @@ -520,8 +507,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): self.async_write_ha_state() add_topic(CONF_RGB_STATE_TOPIC, rgb_received) - restore_state(ATTR_RGB_COLOR) - restore_state(ATTR_HS_COLOR, ATTR_RGB_COLOR) @callback @log_messages(self.hass, self.entity_id) @@ -539,7 +524,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): self.async_write_ha_state() add_topic(CONF_RGBW_STATE_TOPIC, rgbw_received) - restore_state(ATTR_RGBW_COLOR) @callback @log_messages(self.hass, self.entity_id) @@ -557,7 +541,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): self.async_write_ha_state() add_topic(CONF_RGBWW_STATE_TOPIC, rgbww_received) - restore_state(ATTR_RGBWW_COLOR) @callback @log_messages(self.hass, self.entity_id) @@ -574,7 +557,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): self.async_write_ha_state() add_topic(CONF_COLOR_MODE_STATE_TOPIC, color_mode_received) - restore_state(ATTR_COLOR_MODE) @callback @log_messages(self.hass, self.entity_id) @@ -593,7 +575,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): self.async_write_ha_state() add_topic(CONF_COLOR_TEMP_STATE_TOPIC, color_temp_received) - restore_state(ATTR_COLOR_TEMP) @callback @log_messages(self.hass, self.entity_id) @@ -610,7 +591,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): self.async_write_ha_state() add_topic(CONF_EFFECT_STATE_TOPIC, effect_received) - restore_state(ATTR_EFFECT) @callback @log_messages(self.hass, self.entity_id) @@ -630,7 +610,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): _LOGGER.debug("Failed to parse hs state update: '%s'", payload) add_topic(CONF_HS_STATE_TOPIC, hs_received) - restore_state(ATTR_HS_COLOR) @callback @log_messages(self.hass, self.entity_id) @@ -649,7 +628,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): self.async_write_ha_state() add_topic(CONF_WHITE_VALUE_STATE_TOPIC, white_value_received) - restore_state(ATTR_WHITE_VALUE) @callback @log_messages(self.hass, self.entity_id) @@ -670,13 +648,39 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): self.async_write_ha_state() add_topic(CONF_XY_STATE_TOPIC, xy_received) - restore_state(ATTR_XY_COLOR) - restore_state(ATTR_HS_COLOR, ATTR_XY_COLOR) - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + last_state = await self.async_get_last_state() + + def restore_state(attribute, condition_attribute=None): + """Restore a state attribute.""" + if condition_attribute is None: + condition_attribute = attribute + optimistic = self._is_optimistic(condition_attribute) + if optimistic and last_state and last_state.attributes.get(attribute): + setattr(self, f"_{attribute}", last_state.attributes[attribute]) + + if self._topic[CONF_STATE_TOPIC] is None and self._optimistic and last_state: + self._state = last_state.state == STATE_ON + restore_state(ATTR_BRIGHTNESS) + restore_state(ATTR_RGB_COLOR) + restore_state(ATTR_HS_COLOR, ATTR_RGB_COLOR) + restore_state(ATTR_RGBW_COLOR) + restore_state(ATTR_RGBWW_COLOR) + restore_state(ATTR_COLOR_MODE) + restore_state(ATTR_COLOR_TEMP) + restore_state(ATTR_EFFECT) + restore_state(ATTR_HS_COLOR) + restore_state(ATTR_WHITE_VALUE) + restore_state(ATTR_XY_COLOR) + restore_state(ATTR_HS_COLOR, ATTR_XY_COLOR) + @property def brightness(self): """Return the brightness of this light between 0..255.""" diff --git a/homeassistant/components/mqtt/light/schema_json.py b/homeassistant/components/mqtt/light/schema_json.py index 3adaec38adb..2d9b8f6a388 100644 --- a/homeassistant/components/mqtt/light/schema_json.py +++ b/homeassistant/components/mqtt/light/schema_json.py @@ -304,9 +304,8 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity): except (KeyError, ValueError): _LOGGER.warning("Invalid or incomplete color value received") - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" - last_state = await self.async_get_last_state() @callback @log_messages(self.hass, self.entity_id) @@ -370,7 +369,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity): self.async_write_ha_state() if self._topic[CONF_STATE_TOPIC] is not None: - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -383,6 +382,11 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + + last_state = await self.async_get_last_state() if self._optimistic and last_state: self._state = last_state.state == STATE_ON last_attributes = last_state.attributes diff --git a/homeassistant/components/mqtt/light/schema_template.py b/homeassistant/components/mqtt/light/schema_template.py index 54252ebc0b0..736ff98f321 100644 --- a/homeassistant/components/mqtt/light/schema_template.py +++ b/homeassistant/components/mqtt/light/schema_template.py @@ -156,14 +156,12 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity): or self._templates[CONF_STATE_TEMPLATE] is None ) - async def _subscribe_topics(self): # noqa: C901 + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" for tpl in self._templates.values(): if tpl is not None: tpl = MqttValueTemplate(tpl, entity=self) - last_state = await self.async_get_last_state() - @callback @log_messages(self.hass, self.entity_id) def state_received(msg): @@ -246,7 +244,7 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity): self.async_write_ha_state() if self._topics[CONF_STATE_TOPIC] is not None: - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -259,6 +257,11 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + + last_state = await self.async_get_last_state() if self._optimistic and last_state: self._state = last_state.state == STATE_ON if last_state.attributes.get(ATTR_BRIGHTNESS): diff --git a/homeassistant/components/mqtt/lock.py b/homeassistant/components/mqtt/lock.py index 1c280405522..89917f4cc5c 100644 --- a/homeassistant/components/mqtt/lock.py +++ b/homeassistant/components/mqtt/lock.py @@ -123,7 +123,7 @@ class MqttLock(MqttEntity, LockEntity): entity=self, ).async_render_with_possible_json_value - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" @callback @@ -142,7 +142,7 @@ class MqttLock(MqttEntity, LockEntity): # Force into optimistic mode. self._optimistic = True else: - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -155,6 +155,10 @@ class MqttLock(MqttEntity, LockEntity): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @property def is_locked(self): """Return true if lock is locked.""" diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index f49af86360d..1f4bbe9d949 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -66,7 +66,11 @@ from .discovery import ( set_discovery_hash, ) from .models import ReceiveMessage -from .subscription import async_subscribe_topics, async_unsubscribe_topics +from .subscription import ( + async_prepare_subscribe_topics, + async_subscribe_topics, + async_unsubscribe_topics, +) from .util import valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -249,14 +253,19 @@ class MqttAttributes(Entity): async def async_added_to_hass(self) -> None: """Subscribe MQTT events.""" await super().async_added_to_hass() + self._attributes_prepare_subscribe_topics() await self._attributes_subscribe_topics() + def attributes_prepare_discovery_update(self, config: dict): + """Handle updated discovery message.""" + self._attributes_config = config + self._attributes_prepare_subscribe_topics() + async def attributes_discovery_update(self, config: dict): """Handle updated discovery message.""" - self._attributes_config = config await self._attributes_subscribe_topics() - async def _attributes_subscribe_topics(self): + def _attributes_prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" attr_tpl = MqttValueTemplate( self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE), entity=self @@ -284,7 +293,7 @@ class MqttAttributes(Entity): _LOGGER.warning("Erroneous JSON: %s", payload) self._attributes = None - self._attributes_sub_state = await async_subscribe_topics( + self._attributes_sub_state = async_prepare_subscribe_topics( self.hass, self._attributes_sub_state, { @@ -297,9 +306,13 @@ class MqttAttributes(Entity): }, ) + async def _attributes_subscribe_topics(self): + """(Re)Subscribe to topics.""" + await async_subscribe_topics(self.hass, self._attributes_sub_state) + async def async_will_remove_from_hass(self): """Unsubscribe when removed.""" - self._attributes_sub_state = await async_unsubscribe_topics( + self._attributes_sub_state = async_unsubscribe_topics( self.hass, self._attributes_sub_state ) @@ -322,6 +335,7 @@ class MqttAvailability(Entity): async def async_added_to_hass(self) -> None: """Subscribe MQTT events.""" await super().async_added_to_hass() + self._availability_prepare_subscribe_topics() await self._availability_subscribe_topics() self.async_on_remove( async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect) @@ -332,9 +346,13 @@ class MqttAvailability(Entity): ) ) - async def availability_discovery_update(self, config: dict): + def availability_prepare_discovery_update(self, config: dict): """Handle updated discovery message.""" self._availability_setup_from_config(config) + self._availability_prepare_subscribe_topics() + + async def availability_discovery_update(self, config: dict): + """Handle updated discovery message.""" await self._availability_subscribe_topics() def _availability_setup_from_config(self, config): @@ -366,7 +384,7 @@ class MqttAvailability(Entity): self._avail_config = config - async def _availability_subscribe_topics(self): + def _availability_prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" @callback @@ -398,12 +416,16 @@ class MqttAvailability(Entity): for topic in self._avail_topics } - self._availability_sub_state = await async_subscribe_topics( + self._availability_sub_state = async_prepare_subscribe_topics( self.hass, self._availability_sub_state, topics, ) + async def _availability_subscribe_topics(self): + """(Re)Subscribe to topics.""" + await async_subscribe_topics(self.hass, self._availability_sub_state) + @callback def async_mqtt_connect(self): """Update state on connection/disconnection to MQTT broker.""" @@ -412,7 +434,7 @@ class MqttAvailability(Entity): async def async_will_remove_from_hass(self): """Unsubscribe when removed.""" - self._availability_sub_state = await async_unsubscribe_topics( + self._availability_sub_state = async_unsubscribe_topics( self.hass, self._availability_sub_state ) @@ -601,7 +623,7 @@ class MqttEntityDeviceInfo(Entity): self._device_config = device_config self._config_entry = config_entry - async def device_info_discovery_update(self, config: dict): + def device_info_discovery_update(self, config: dict): """Handle updated discovery message.""" self._device_config = config.get(CONF_DEVICE) device_registry = dr.async_get(self.hass) @@ -657,6 +679,7 @@ class MqttEntity( async def async_added_to_hass(self): """Subscribe mqtt events.""" await super().async_added_to_hass() + self._prepare_subscribe_topics() await self._subscribe_topics() async def discovery_update(self, discovery_payload): @@ -664,15 +687,22 @@ class MqttEntity( config = self.config_schema()(discovery_payload) self._config = config self._setup_from_config(self._config) + + # Prepare MQTT subscriptions + self.attributes_prepare_discovery_update(config) + self.availability_prepare_discovery_update(config) + self.device_info_discovery_update(config) + self._prepare_subscribe_topics() + + # Finalize MQTT subscriptions await self.attributes_discovery_update(config) await self.availability_discovery_update(config) - await self.device_info_discovery_update(config) await self._subscribe_topics() self.async_write_ha_state() async def async_will_remove_from_hass(self): """Unsubscribe when removed.""" - self._sub_state = await subscription.async_unsubscribe_topics( + self._sub_state = subscription.async_unsubscribe_topics( self.hass, self._sub_state ) await MqttAttributes.async_will_remove_from_hass(self) @@ -687,6 +717,10 @@ class MqttEntity( def _setup_from_config(self, config): """(Re)Setup the entity.""" + @abstractmethod + def _prepare_subscribe_topics(self): + """(Re)Subscribe to topics.""" + @abstractmethod async def _subscribe_topics(self): """(Re)Subscribe to topics.""" diff --git a/homeassistant/components/mqtt/number.py b/homeassistant/components/mqtt/number.py index 0020a020411..511b6e470ac 100644 --- a/homeassistant/components/mqtt/number.py +++ b/homeassistant/components/mqtt/number.py @@ -162,7 +162,7 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity): ).async_render_with_possible_json_value, } - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" @callback @@ -200,7 +200,7 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity): # Force into optimistic mode. self._optimistic = True else: - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -213,6 +213,10 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + if self._optimistic and (last_state := await self.async_get_last_state()): self._current_number = last_state.state diff --git a/homeassistant/components/mqtt/select.py b/homeassistant/components/mqtt/select.py index 810c6126d52..24bed158eda 100644 --- a/homeassistant/components/mqtt/select.py +++ b/homeassistant/components/mqtt/select.py @@ -128,7 +128,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): ).async_render_with_possible_json_value, } - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" @callback @@ -156,7 +156,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): # Force into optimistic mode. self._optimistic = True else: - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -169,6 +169,10 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + if self._optimistic and (last_state := await self.async_get_last_state()): self._attr_current_option = last_state.state diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index 59f124155d3..6cf72279546 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -213,7 +213,7 @@ class MqttSensor(MqttEntity, SensorEntity, RestoreEntity): self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" topics = {} @@ -304,10 +304,14 @@ class MqttSensor(MqttEntity, SensorEntity, RestoreEntity): "encoding": self._config[CONF_ENCODING] or None, } - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @callback def _value_is_expired(self, *_): """Triggered when value is expired.""" diff --git a/homeassistant/components/mqtt/siren.py b/homeassistant/components/mqtt/siren.py index 6a268881593..e33a13545b3 100644 --- a/homeassistant/components/mqtt/siren.py +++ b/homeassistant/components/mqtt/siren.py @@ -214,7 +214,7 @@ class MqttSiren(MqttEntity, SirenEntity): entity=self, ).async_render_with_possible_json_value - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" @callback @@ -274,7 +274,7 @@ class MqttSiren(MqttEntity, SirenEntity): # Force into optimistic mode. self._optimistic = True else: - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -287,6 +287,10 @@ class MqttSiren(MqttEntity, SirenEntity): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @property def assumed_state(self): """Return true if we do optimistic updates.""" diff --git a/homeassistant/components/mqtt/subscription.py b/homeassistant/components/mqtt/subscription.py index 6d132b28a98..d0af533f294 100644 --- a/homeassistant/components/mqtt/subscription.py +++ b/homeassistant/components/mqtt/subscription.py @@ -1,13 +1,12 @@ """Helper to handle a set of topics to subscribe to.""" from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Coroutine from typing import Any import attr from homeassistant.core import HomeAssistant -from homeassistant.loader import bind_hass from . import debug_info from .. import mqtt @@ -22,11 +21,12 @@ class EntitySubscription: hass: HomeAssistant = attr.ib() topic: str = attr.ib() message_callback: MessageCallbackType = attr.ib() + subscribe_task: Coroutine | None = attr.ib() unsubscribe_callback: Callable[[], None] | None = attr.ib() qos: int = attr.ib(default=0) encoding: str = attr.ib(default="utf-8") - async def resubscribe_if_necessary(self, hass, other): + def resubscribe_if_necessary(self, hass, other): """Re-subscribe to the new topic if necessary.""" if not self._should_resubscribe(other): self.unsubscribe_callback = other.unsubscribe_callback @@ -46,33 +46,41 @@ class EntitySubscription: # Prepare debug data debug_info.add_subscription(self.hass, self.message_callback, self.topic) - self.unsubscribe_callback = await mqtt.async_subscribe( + self.subscribe_task = mqtt.async_subscribe( hass, self.topic, self.message_callback, self.qos, self.encoding ) + async def subscribe(self): + """Subscribe to a topic.""" + if not self.subscribe_task: + return + self.unsubscribe_callback = await self.subscribe_task + def _should_resubscribe(self, other): """Check if we should re-subscribe to the topic using the old state.""" if other is None: return True - return (self.topic, self.qos, self.encoding) != ( + return (self.topic, self.qos, self.encoding,) != ( other.topic, other.qos, other.encoding, ) -@bind_hass -async def async_subscribe_topics( +def async_prepare_subscribe_topics( hass: HomeAssistant, new_state: dict[str, EntitySubscription] | None, topics: dict[str, Any], ) -> dict[str, EntitySubscription]: - """(Re)Subscribe to a set of MQTT topics. + """Prepare (re)subscribe to a set of MQTT topics. State is kept in sub_state and a dictionary mapping from the subscription key to the subscription state. + After this function has been called, async_subscribe_topics must be called to + finalize any new subscriptions. + Please note that the sub state must not be shared between multiple sets of topics. Every call to async_subscribe_topics must always contain _all_ the topics the subscription state should manage. @@ -88,10 +96,11 @@ async def async_subscribe_topics( qos=value.get("qos", DEFAULT_QOS), encoding=value.get("encoding", "utf-8"), hass=hass, + subscribe_task=None, ) # Get the current subscription state current = current_subscriptions.pop(key, None) - await requested.resubscribe_if_necessary(hass, current) + requested.resubscribe_if_necessary(hass, current) new_state[key] = requested # Go through all remaining subscriptions and unsubscribe them @@ -106,9 +115,19 @@ async def async_subscribe_topics( return new_state -@bind_hass -async def async_unsubscribe_topics( +async def async_subscribe_topics( + hass: HomeAssistant, + sub_state: dict[str, EntitySubscription] | None, +) -> None: + """(Re)Subscribe to a set of MQTT topics.""" + if sub_state is None: + return + for sub in sub_state.values(): + await sub.subscribe() + + +def async_unsubscribe_topics( hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None ) -> dict[str, EntitySubscription]: """Unsubscribe from all MQTT topics managed by async_subscribe_topics.""" - return await async_subscribe_topics(hass, sub_state, {}) + return async_prepare_subscribe_topics(hass, sub_state, {}) diff --git a/homeassistant/components/mqtt/switch.py b/homeassistant/components/mqtt/switch.py index 9feba2c1d25..6576072e407 100644 --- a/homeassistant/components/mqtt/switch.py +++ b/homeassistant/components/mqtt/switch.py @@ -132,7 +132,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity): self._config.get(CONF_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" @callback @@ -151,7 +151,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity): # Force into optimistic mode. self._optimistic = True else: - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -164,6 +164,10 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + if self._optimistic and (last_state := await self.async_get_last_state()): self._state = last_state.state == STATE_ON diff --git a/homeassistant/components/mqtt/tag.py b/homeassistant/components/mqtt/tag.py index b2638f8ac4b..e4152250802 100644 --- a/homeassistant/components/mqtt/tag.py +++ b/homeassistant/components/mqtt/tag.py @@ -175,7 +175,7 @@ class MQTTTagScanner: await self.hass.components.tag.async_scan_tag(tag_id, self.device_id) - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -186,6 +186,7 @@ class MQTTTagScanner: } }, ) + await subscription.async_subscribe_topics(self.hass, self._sub_state) async def device_removed(self, event): """Handle the removal of a device.""" @@ -207,7 +208,7 @@ class MQTTTagScanner: self._remove_discovery() mqtt.publish(self.hass, discovery_topic, "", retain=True) - self._sub_state = await subscription.async_unsubscribe_topics( + self._sub_state = subscription.async_unsubscribe_topics( self.hass, self._sub_state ) if self.device_id: diff --git a/homeassistant/components/mqtt/vacuum/schema_legacy.py b/homeassistant/components/mqtt/vacuum/schema_legacy.py index 5f85acc75ca..3a764ca9e45 100644 --- a/homeassistant/components/mqtt/vacuum/schema_legacy.py +++ b/homeassistant/components/mqtt/vacuum/schema_legacy.py @@ -240,7 +240,7 @@ class MqttVacuum(MqttEntity, VacuumEntity): ) } - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" for tpl in self._templates.values(): if tpl is not None: @@ -325,7 +325,7 @@ class MqttVacuum(MqttEntity, VacuumEntity): self.async_write_ha_state() topics_list = {topic for topic in self._state_topics.values() if topic} - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { @@ -339,6 +339,10 @@ class MqttVacuum(MqttEntity, VacuumEntity): }, ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @property def is_on(self): """Return true if vacuum is on.""" diff --git a/homeassistant/components/mqtt/vacuum/schema_state.py b/homeassistant/components/mqtt/vacuum/schema_state.py index 872f4d62765..1eb763c76cd 100644 --- a/homeassistant/components/mqtt/vacuum/schema_state.py +++ b/homeassistant/components/mqtt/vacuum/schema_state.py @@ -197,7 +197,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity): ) } - async def _subscribe_topics(self): + def _prepare_subscribe_topics(self): """(Re)Subscribe to topics.""" topics = {} @@ -219,10 +219,14 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity): "qos": self._config[CONF_QOS], "encoding": self._config[CONF_ENCODING] or None, } - self._sub_state = await subscription.async_subscribe_topics( + self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics ) + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + await subscription.async_subscribe_topics(self.hass, self._sub_state) + @property def state(self): """Return state of vacuum.""" diff --git a/homeassistant/components/tasmota/__init__.py b/homeassistant/components/tasmota/__init__.py index f8dcd4035df..2d664bb46ee 100644 --- a/homeassistant/components/tasmota/__init__.py +++ b/homeassistant/components/tasmota/__init__.py @@ -19,6 +19,7 @@ import voluptuous as vol from homeassistant.components import mqtt, websocket_api from homeassistant.components.mqtt.subscription import ( + async_prepare_subscribe_topics, async_subscribe_topics, async_unsubscribe_topics, ) @@ -62,10 +63,12 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: for topic in topics.values(): if "msg_callback" in topic and "event_loop_safe" in topic: topic["msg_callback"] = callback(topic["msg_callback"]) - return await async_subscribe_topics(hass, sub_state, topics) + sub_state = async_prepare_subscribe_topics(hass, sub_state, topics) + await async_subscribe_topics(hass, sub_state) + return sub_state async def _unsubscribe_topics(sub_state: dict | None) -> dict: - return await async_unsubscribe_topics(hass, sub_state) + return async_unsubscribe_topics(hass, sub_state) tasmota_mqtt = TasmotaMQTTClient(_publish, _subscribe_topics, _unsubscribe_topics) diff --git a/tests/components/mqtt/test_subscription.py b/tests/components/mqtt/test_subscription.py index 36d8946be0b..e2ffc602ddd 100644 --- a/tests/components/mqtt/test_subscription.py +++ b/tests/components/mqtt/test_subscription.py @@ -2,6 +2,7 @@ from unittest.mock import ANY from homeassistant.components.mqtt.subscription import ( + async_prepare_subscribe_topics, async_subscribe_topics, async_unsubscribe_topics, ) @@ -27,7 +28,7 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog): calls2.append(args) sub_state = None - sub_state = await async_subscribe_topics( + sub_state = async_prepare_subscribe_topics( hass, sub_state, { @@ -35,6 +36,7 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog): "test_topic2": {"topic": "test-topic2", "msg_callback": record_calls2}, }, ) + await async_subscribe_topics(hass, sub_state) async_fire_mqtt_message(hass, "test-topic1", "test-payload1") assert len(calls1) == 1 @@ -48,7 +50,7 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog): assert calls2[0][0].topic == "test-topic2" assert calls2[0][0].payload == "test-payload2" - await async_unsubscribe_topics(hass, sub_state) + async_unsubscribe_topics(hass, sub_state) async_fire_mqtt_message(hass, "test-topic1", "test-payload") async_fire_mqtt_message(hass, "test-topic2", "test-payload") @@ -74,7 +76,7 @@ async def test_modify_topics(hass, mqtt_mock, caplog): calls2.append(args) sub_state = None - sub_state = await async_subscribe_topics( + sub_state = async_prepare_subscribe_topics( hass, sub_state, { @@ -82,6 +84,7 @@ async def test_modify_topics(hass, mqtt_mock, caplog): "test_topic2": {"topic": "test-topic2", "msg_callback": record_calls2}, }, ) + await async_subscribe_topics(hass, sub_state) async_fire_mqtt_message(hass, "test-topic1", "test-payload") assert len(calls1) == 1 @@ -91,11 +94,12 @@ async def test_modify_topics(hass, mqtt_mock, caplog): assert len(calls1) == 1 assert len(calls2) == 1 - sub_state = await async_subscribe_topics( + sub_state = async_prepare_subscribe_topics( hass, sub_state, {"test_topic1": {"topic": "test-topic1_1", "msg_callback": record_calls1}}, ) + await async_subscribe_topics(hass, sub_state) async_fire_mqtt_message(hass, "test-topic1", "test-payload") async_fire_mqtt_message(hass, "test-topic2", "test-payload") @@ -108,7 +112,7 @@ async def test_modify_topics(hass, mqtt_mock, caplog): assert calls1[1][0].payload == "test-payload" assert len(calls2) == 1 - await async_unsubscribe_topics(hass, sub_state) + async_unsubscribe_topics(hass, sub_state) async_fire_mqtt_message(hass, "test-topic1_1", "test-payload") async_fire_mqtt_message(hass, "test-topic2", "test-payload") @@ -126,11 +130,12 @@ async def test_qos_encoding_default(hass, mqtt_mock, caplog): pass sub_state = None - sub_state = await async_subscribe_topics( + sub_state = async_prepare_subscribe_topics( hass, sub_state, {"test_topic1": {"topic": "test-topic1", "msg_callback": msg_callback}}, ) + await async_subscribe_topics(hass, sub_state) mqtt_mock.async_subscribe.assert_called_once_with("test-topic1", ANY, 0, "utf-8") @@ -143,7 +148,7 @@ async def test_qos_encoding_custom(hass, mqtt_mock, caplog): pass sub_state = None - sub_state = await async_subscribe_topics( + sub_state = async_prepare_subscribe_topics( hass, sub_state, { @@ -155,6 +160,7 @@ async def test_qos_encoding_custom(hass, mqtt_mock, caplog): } }, ) + await async_subscribe_topics(hass, sub_state) mqtt_mock.async_subscribe.assert_called_once_with("test-topic1", ANY, 1, "utf-16") @@ -169,27 +175,29 @@ async def test_no_change(hass, mqtt_mock, caplog): calls.append(args) sub_state = None - sub_state = await async_subscribe_topics( + sub_state = async_prepare_subscribe_topics( hass, sub_state, {"test_topic1": {"topic": "test-topic1", "msg_callback": record_calls}}, ) + await async_subscribe_topics(hass, sub_state) subscribe_call_count = mqtt_mock.async_subscribe.call_count async_fire_mqtt_message(hass, "test-topic1", "test-payload") assert len(calls) == 1 - sub_state = await async_subscribe_topics( + sub_state = async_prepare_subscribe_topics( hass, sub_state, {"test_topic1": {"topic": "test-topic1", "msg_callback": record_calls}}, ) + await async_subscribe_topics(hass, sub_state) assert subscribe_call_count == mqtt_mock.async_subscribe.call_count async_fire_mqtt_message(hass, "test-topic1", "test-payload") assert len(calls) == 2 - await async_unsubscribe_topics(hass, sub_state) + async_unsubscribe_topics(hass, sub_state) async_fire_mqtt_message(hass, "test-topic1", "test-payload") assert len(calls) == 2