Fix race when handling updated MQTT discovery data (#65415)

This commit is contained in:
Erik Montnemery 2022-02-03 02:12:22 +01:00 committed by GitHub
parent 4e7cf19b5f
commit f3a89de71f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 246 additions and 108 deletions

View File

@ -173,7 +173,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
self._config[CONF_COMMAND_TEMPLATE], entity=self
).async_render
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
@ -198,7 +198,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
self._state = payload
self.async_write_ha_state()
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -211,6 +211,10 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property
def state(self):
"""Return the state of the device."""

View File

@ -164,7 +164,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
entity=self,
).async_render_with_possible_json_value
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
@ -241,7 +241,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
self.async_write_ha_state()
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -254,6 +254,10 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@callback
def _value_is_expired(self, *_):
"""Triggered when value is expired."""

View File

@ -95,6 +95,9 @@ class MqttButton(MqttEntity, ButtonEntity):
config.get(CONF_COMMAND_TEMPLATE), entity=self
).async_render
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""

View File

@ -90,7 +90,7 @@ class MqttCamera(MqttEntity, Camera):
"""Return the config schema."""
return DISCOVERY_SCHEMA
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
@ -99,7 +99,7 @@ class MqttCamera(MqttEntity, Camera):
"""Handle new MQTT messages."""
self._last_image = msg.payload
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -112,6 +112,10 @@ class MqttCamera(MqttEntity, Camera):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
async def async_camera_image(
self, width: int | None = None, height: int | None = None
) -> bytes | None:

View File

@ -358,11 +358,6 @@ class MqttClimate(MqttEntity, ClimateEntity):
"""Return the config schema."""
return DISCOVERY_SCHEMA
async def async_added_to_hass(self):
"""Handle being added to Home Assistant."""
await super().async_added_to_hass()
await self._subscribe_topics()
def _setup_from_config(self, config):
"""(Re)Setup the entity."""
self._topic = {key: config.get(key) for key in TOPIC_KEYS}
@ -417,7 +412,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
self._command_templates = command_templates
async def _subscribe_topics(self): # noqa: C901
def _prepare_subscribe_topics(self): # noqa: C901
"""(Re)Subscribe to topics."""
topics = {}
qos = self._config[CONF_QOS]
@ -615,10 +610,14 @@ class MqttClimate(MqttEntity, ClimateEntity):
add_subscription(topics, CONF_HOLD_STATE_TOPIC, handle_hold_mode_received)
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property
def temperature_unit(self):
"""Return the unit of measurement."""

View File

@ -335,7 +335,7 @@ class MqttCover(MqttEntity, CoverEntity):
config_attributes=template_config_attributes,
).async_render_with_possible_json_value
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
topics = {}
@ -460,10 +460,14 @@ class MqttCover(MqttEntity, CoverEntity):
"encoding": self._config[CONF_ENCODING] or None,
}
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property
def assumed_state(self):
"""Return true if we do optimistic updates."""

View File

@ -77,7 +77,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
@ -94,7 +94,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
self.async_write_ha_state()
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -106,6 +106,10 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property
def latitude(self):
"""Return latitude if provided in extra_state_attributes or None."""

View File

@ -351,7 +351,7 @@ class MqttFan(MqttEntity, FanEntity):
entity=self,
).async_render_with_possible_json_value
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
topics = {}
@ -479,10 +479,14 @@ class MqttFan(MqttEntity, FanEntity):
}
self._oscillation = False
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property
def assumed_state(self):
"""Return true if we do optimistic updates."""

View File

@ -267,7 +267,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
entity=self,
).async_render_with_possible_json_value
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
topics = {}
@ -373,10 +373,14 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
}
self._mode = None
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property
def assumed_state(self):
"""Return true if we do optimistic updates."""

View File

@ -417,12 +417,10 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
"""Return True if the attribute is optimistically updated."""
return getattr(self, f"_optimistic_{attribute}")
async def _subscribe_topics(self): # noqa: C901
def _prepare_subscribe_topics(self): # noqa: C901
"""(Re)Subscribe to topics."""
topics = {}
last_state = await self.async_get_last_state()
def add_topic(topic, msg_callback):
"""Add a topic."""
if self._topic[topic] is not None:
@ -433,14 +431,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
"encoding": self._config[CONF_ENCODING] or None,
}
def restore_state(attribute, condition_attribute=None):
"""Restore a state attribute."""
if condition_attribute is None:
condition_attribute = attribute
optimistic = self._is_optimistic(condition_attribute)
if optimistic and last_state and last_state.attributes.get(attribute):
setattr(self, f"_{attribute}", last_state.attributes[attribute])
@callback
@log_messages(self.hass, self.entity_id)
def state_received(msg):
@ -465,8 +455,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
elif self._optimistic and last_state:
self._state = last_state.state == STATE_ON
@callback
@log_messages(self.hass, self.entity_id)
@ -485,7 +473,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state()
add_topic(CONF_BRIGHTNESS_STATE_TOPIC, brightness_received)
restore_state(ATTR_BRIGHTNESS)
def _rgbx_received(msg, template, color_mode, convert_color):
"""Handle new MQTT messages for RGBW and RGBWW."""
@ -520,8 +507,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state()
add_topic(CONF_RGB_STATE_TOPIC, rgb_received)
restore_state(ATTR_RGB_COLOR)
restore_state(ATTR_HS_COLOR, ATTR_RGB_COLOR)
@callback
@log_messages(self.hass, self.entity_id)
@ -539,7 +524,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state()
add_topic(CONF_RGBW_STATE_TOPIC, rgbw_received)
restore_state(ATTR_RGBW_COLOR)
@callback
@log_messages(self.hass, self.entity_id)
@ -557,7 +541,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state()
add_topic(CONF_RGBWW_STATE_TOPIC, rgbww_received)
restore_state(ATTR_RGBWW_COLOR)
@callback
@log_messages(self.hass, self.entity_id)
@ -574,7 +557,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state()
add_topic(CONF_COLOR_MODE_STATE_TOPIC, color_mode_received)
restore_state(ATTR_COLOR_MODE)
@callback
@log_messages(self.hass, self.entity_id)
@ -593,7 +575,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state()
add_topic(CONF_COLOR_TEMP_STATE_TOPIC, color_temp_received)
restore_state(ATTR_COLOR_TEMP)
@callback
@log_messages(self.hass, self.entity_id)
@ -610,7 +591,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state()
add_topic(CONF_EFFECT_STATE_TOPIC, effect_received)
restore_state(ATTR_EFFECT)
@callback
@log_messages(self.hass, self.entity_id)
@ -630,7 +610,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
_LOGGER.debug("Failed to parse hs state update: '%s'", payload)
add_topic(CONF_HS_STATE_TOPIC, hs_received)
restore_state(ATTR_HS_COLOR)
@callback
@log_messages(self.hass, self.entity_id)
@ -649,7 +628,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state()
add_topic(CONF_WHITE_VALUE_STATE_TOPIC, white_value_received)
restore_state(ATTR_WHITE_VALUE)
@callback
@log_messages(self.hass, self.entity_id)
@ -670,13 +648,39 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state()
add_topic(CONF_XY_STATE_TOPIC, xy_received)
restore_state(ATTR_XY_COLOR)
restore_state(ATTR_HS_COLOR, ATTR_XY_COLOR)
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
def restore_state(attribute, condition_attribute=None):
"""Restore a state attribute."""
if condition_attribute is None:
condition_attribute = attribute
optimistic = self._is_optimistic(condition_attribute)
if optimistic and last_state and last_state.attributes.get(attribute):
setattr(self, f"_{attribute}", last_state.attributes[attribute])
if self._topic[CONF_STATE_TOPIC] is None and self._optimistic and last_state:
self._state = last_state.state == STATE_ON
restore_state(ATTR_BRIGHTNESS)
restore_state(ATTR_RGB_COLOR)
restore_state(ATTR_HS_COLOR, ATTR_RGB_COLOR)
restore_state(ATTR_RGBW_COLOR)
restore_state(ATTR_RGBWW_COLOR)
restore_state(ATTR_COLOR_MODE)
restore_state(ATTR_COLOR_TEMP)
restore_state(ATTR_EFFECT)
restore_state(ATTR_HS_COLOR)
restore_state(ATTR_WHITE_VALUE)
restore_state(ATTR_XY_COLOR)
restore_state(ATTR_HS_COLOR, ATTR_XY_COLOR)
@property
def brightness(self):
"""Return the brightness of this light between 0..255."""

View File

@ -304,9 +304,8 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
except (KeyError, ValueError):
_LOGGER.warning("Invalid or incomplete color value received")
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
last_state = await self.async_get_last_state()
@callback
@log_messages(self.hass, self.entity_id)
@ -370,7 +369,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state()
if self._topic[CONF_STATE_TOPIC] is not None:
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -383,6 +382,11 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
if self._optimistic and last_state:
self._state = last_state.state == STATE_ON
last_attributes = last_state.attributes

View File

@ -156,14 +156,12 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
or self._templates[CONF_STATE_TEMPLATE] is None
)
async def _subscribe_topics(self): # noqa: C901
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
for tpl in self._templates.values():
if tpl is not None:
tpl = MqttValueTemplate(tpl, entity=self)
last_state = await self.async_get_last_state()
@callback
@log_messages(self.hass, self.entity_id)
def state_received(msg):
@ -246,7 +244,7 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state()
if self._topics[CONF_STATE_TOPIC] is not None:
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -259,6 +257,11 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
if self._optimistic and last_state:
self._state = last_state.state == STATE_ON
if last_state.attributes.get(ATTR_BRIGHTNESS):

View File

@ -123,7 +123,7 @@ class MqttLock(MqttEntity, LockEntity):
entity=self,
).async_render_with_possible_json_value
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
@ -142,7 +142,7 @@ class MqttLock(MqttEntity, LockEntity):
# Force into optimistic mode.
self._optimistic = True
else:
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -155,6 +155,10 @@ class MqttLock(MqttEntity, LockEntity):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property
def is_locked(self):
"""Return true if lock is locked."""

View File

@ -66,7 +66,11 @@ from .discovery import (
set_discovery_hash,
)
from .models import ReceiveMessage
from .subscription import async_subscribe_topics, async_unsubscribe_topics
from .subscription import (
async_prepare_subscribe_topics,
async_subscribe_topics,
async_unsubscribe_topics,
)
from .util import valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -249,14 +253,19 @@ class MqttAttributes(Entity):
async def async_added_to_hass(self) -> None:
"""Subscribe MQTT events."""
await super().async_added_to_hass()
self._attributes_prepare_subscribe_topics()
await self._attributes_subscribe_topics()
def attributes_prepare_discovery_update(self, config: dict):
"""Handle updated discovery message."""
self._attributes_config = config
self._attributes_prepare_subscribe_topics()
async def attributes_discovery_update(self, config: dict):
"""Handle updated discovery message."""
self._attributes_config = config
await self._attributes_subscribe_topics()
async def _attributes_subscribe_topics(self):
def _attributes_prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
attr_tpl = MqttValueTemplate(
self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE), entity=self
@ -284,7 +293,7 @@ class MqttAttributes(Entity):
_LOGGER.warning("Erroneous JSON: %s", payload)
self._attributes = None
self._attributes_sub_state = await async_subscribe_topics(
self._attributes_sub_state = async_prepare_subscribe_topics(
self.hass,
self._attributes_sub_state,
{
@ -297,9 +306,13 @@ class MqttAttributes(Entity):
},
)
async def _attributes_subscribe_topics(self):
"""(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._attributes_sub_state)
async def async_will_remove_from_hass(self):
"""Unsubscribe when removed."""
self._attributes_sub_state = await async_unsubscribe_topics(
self._attributes_sub_state = async_unsubscribe_topics(
self.hass, self._attributes_sub_state
)
@ -322,6 +335,7 @@ class MqttAvailability(Entity):
async def async_added_to_hass(self) -> None:
"""Subscribe MQTT events."""
await super().async_added_to_hass()
self._availability_prepare_subscribe_topics()
await self._availability_subscribe_topics()
self.async_on_remove(
async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect)
@ -332,9 +346,13 @@ class MqttAvailability(Entity):
)
)
async def availability_discovery_update(self, config: dict):
def availability_prepare_discovery_update(self, config: dict):
"""Handle updated discovery message."""
self._availability_setup_from_config(config)
self._availability_prepare_subscribe_topics()
async def availability_discovery_update(self, config: dict):
"""Handle updated discovery message."""
await self._availability_subscribe_topics()
def _availability_setup_from_config(self, config):
@ -366,7 +384,7 @@ class MqttAvailability(Entity):
self._avail_config = config
async def _availability_subscribe_topics(self):
def _availability_prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
@ -398,12 +416,16 @@ class MqttAvailability(Entity):
for topic in self._avail_topics
}
self._availability_sub_state = await async_subscribe_topics(
self._availability_sub_state = async_prepare_subscribe_topics(
self.hass,
self._availability_sub_state,
topics,
)
async def _availability_subscribe_topics(self):
"""(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._availability_sub_state)
@callback
def async_mqtt_connect(self):
"""Update state on connection/disconnection to MQTT broker."""
@ -412,7 +434,7 @@ class MqttAvailability(Entity):
async def async_will_remove_from_hass(self):
"""Unsubscribe when removed."""
self._availability_sub_state = await async_unsubscribe_topics(
self._availability_sub_state = async_unsubscribe_topics(
self.hass, self._availability_sub_state
)
@ -601,7 +623,7 @@ class MqttEntityDeviceInfo(Entity):
self._device_config = device_config
self._config_entry = config_entry
async def device_info_discovery_update(self, config: dict):
def device_info_discovery_update(self, config: dict):
"""Handle updated discovery message."""
self._device_config = config.get(CONF_DEVICE)
device_registry = dr.async_get(self.hass)
@ -657,6 +679,7 @@ class MqttEntity(
async def async_added_to_hass(self):
"""Subscribe mqtt events."""
await super().async_added_to_hass()
self._prepare_subscribe_topics()
await self._subscribe_topics()
async def discovery_update(self, discovery_payload):
@ -664,15 +687,22 @@ class MqttEntity(
config = self.config_schema()(discovery_payload)
self._config = config
self._setup_from_config(self._config)
# Prepare MQTT subscriptions
self.attributes_prepare_discovery_update(config)
self.availability_prepare_discovery_update(config)
self.device_info_discovery_update(config)
self._prepare_subscribe_topics()
# Finalize MQTT subscriptions
await self.attributes_discovery_update(config)
await self.availability_discovery_update(config)
await self.device_info_discovery_update(config)
await self._subscribe_topics()
self.async_write_ha_state()
async def async_will_remove_from_hass(self):
"""Unsubscribe when removed."""
self._sub_state = await subscription.async_unsubscribe_topics(
self._sub_state = subscription.async_unsubscribe_topics(
self.hass, self._sub_state
)
await MqttAttributes.async_will_remove_from_hass(self)
@ -687,6 +717,10 @@ class MqttEntity(
def _setup_from_config(self, config):
"""(Re)Setup the entity."""
@abstractmethod
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@abstractmethod
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""

View File

@ -162,7 +162,7 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity):
).async_render_with_possible_json_value,
}
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
@ -200,7 +200,7 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity):
# Force into optimistic mode.
self._optimistic = True
else:
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -213,6 +213,10 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
if self._optimistic and (last_state := await self.async_get_last_state()):
self._current_number = last_state.state

View File

@ -128,7 +128,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
).async_render_with_possible_json_value,
}
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
@ -156,7 +156,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
# Force into optimistic mode.
self._optimistic = True
else:
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -169,6 +169,10 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
if self._optimistic and (last_state := await self.async_get_last_state()):
self._attr_current_option = last_state.state

View File

@ -213,7 +213,7 @@ class MqttSensor(MqttEntity, SensorEntity, RestoreEntity):
self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
topics = {}
@ -304,10 +304,14 @@ class MqttSensor(MqttEntity, SensorEntity, RestoreEntity):
"encoding": self._config[CONF_ENCODING] or None,
}
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@callback
def _value_is_expired(self, *_):
"""Triggered when value is expired."""

View File

@ -214,7 +214,7 @@ class MqttSiren(MqttEntity, SirenEntity):
entity=self,
).async_render_with_possible_json_value
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
@ -274,7 +274,7 @@ class MqttSiren(MqttEntity, SirenEntity):
# Force into optimistic mode.
self._optimistic = True
else:
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -287,6 +287,10 @@ class MqttSiren(MqttEntity, SirenEntity):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property
def assumed_state(self):
"""Return true if we do optimistic updates."""

View File

@ -1,13 +1,12 @@
"""Helper to handle a set of topics to subscribe to."""
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Coroutine
from typing import Any
import attr
from homeassistant.core import HomeAssistant
from homeassistant.loader import bind_hass
from . import debug_info
from .. import mqtt
@ -22,11 +21,12 @@ class EntitySubscription:
hass: HomeAssistant = attr.ib()
topic: str = attr.ib()
message_callback: MessageCallbackType = attr.ib()
subscribe_task: Coroutine | None = attr.ib()
unsubscribe_callback: Callable[[], None] | None = attr.ib()
qos: int = attr.ib(default=0)
encoding: str = attr.ib(default="utf-8")
async def resubscribe_if_necessary(self, hass, other):
def resubscribe_if_necessary(self, hass, other):
"""Re-subscribe to the new topic if necessary."""
if not self._should_resubscribe(other):
self.unsubscribe_callback = other.unsubscribe_callback
@ -46,33 +46,41 @@ class EntitySubscription:
# Prepare debug data
debug_info.add_subscription(self.hass, self.message_callback, self.topic)
self.unsubscribe_callback = await mqtt.async_subscribe(
self.subscribe_task = mqtt.async_subscribe(
hass, self.topic, self.message_callback, self.qos, self.encoding
)
async def subscribe(self):
"""Subscribe to a topic."""
if not self.subscribe_task:
return
self.unsubscribe_callback = await self.subscribe_task
def _should_resubscribe(self, other):
"""Check if we should re-subscribe to the topic using the old state."""
if other is None:
return True
return (self.topic, self.qos, self.encoding) != (
return (self.topic, self.qos, self.encoding,) != (
other.topic,
other.qos,
other.encoding,
)
@bind_hass
async def async_subscribe_topics(
def async_prepare_subscribe_topics(
hass: HomeAssistant,
new_state: dict[str, EntitySubscription] | None,
topics: dict[str, Any],
) -> dict[str, EntitySubscription]:
"""(Re)Subscribe to a set of MQTT topics.
"""Prepare (re)subscribe to a set of MQTT topics.
State is kept in sub_state and a dictionary mapping from the subscription
key to the subscription state.
After this function has been called, async_subscribe_topics must be called to
finalize any new subscriptions.
Please note that the sub state must not be shared between multiple
sets of topics. Every call to async_subscribe_topics must always
contain _all_ the topics the subscription state should manage.
@ -88,10 +96,11 @@ async def async_subscribe_topics(
qos=value.get("qos", DEFAULT_QOS),
encoding=value.get("encoding", "utf-8"),
hass=hass,
subscribe_task=None,
)
# Get the current subscription state
current = current_subscriptions.pop(key, None)
await requested.resubscribe_if_necessary(hass, current)
requested.resubscribe_if_necessary(hass, current)
new_state[key] = requested
# Go through all remaining subscriptions and unsubscribe them
@ -106,9 +115,19 @@ async def async_subscribe_topics(
return new_state
@bind_hass
async def async_unsubscribe_topics(
async def async_subscribe_topics(
hass: HomeAssistant,
sub_state: dict[str, EntitySubscription] | None,
) -> None:
"""(Re)Subscribe to a set of MQTT topics."""
if sub_state is None:
return
for sub in sub_state.values():
await sub.subscribe()
def async_unsubscribe_topics(
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
return await async_subscribe_topics(hass, sub_state, {})
return async_prepare_subscribe_topics(hass, sub_state, {})

View File

@ -132,7 +132,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
@ -151,7 +151,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
# Force into optimistic mode.
self._optimistic = True
else:
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -164,6 +164,10 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
if self._optimistic and (last_state := await self.async_get_last_state()):
self._state = last_state.state == STATE_ON

View File

@ -175,7 +175,7 @@ class MQTTTagScanner:
await self.hass.components.tag.async_scan_tag(tag_id, self.device_id)
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -186,6 +186,7 @@ class MQTTTagScanner:
}
},
)
await subscription.async_subscribe_topics(self.hass, self._sub_state)
async def device_removed(self, event):
"""Handle the removal of a device."""
@ -207,7 +208,7 @@ class MQTTTagScanner:
self._remove_discovery()
mqtt.publish(self.hass, discovery_topic, "", retain=True)
self._sub_state = await subscription.async_unsubscribe_topics(
self._sub_state = subscription.async_unsubscribe_topics(
self.hass, self._sub_state
)
if self.device_id:

View File

@ -240,7 +240,7 @@ class MqttVacuum(MqttEntity, VacuumEntity):
)
}
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
for tpl in self._templates.values():
if tpl is not None:
@ -325,7 +325,7 @@ class MqttVacuum(MqttEntity, VacuumEntity):
self.async_write_ha_state()
topics_list = {topic for topic in self._state_topics.values() if topic}
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
@ -339,6 +339,10 @@ class MqttVacuum(MqttEntity, VacuumEntity):
},
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property
def is_on(self):
"""Return true if vacuum is on."""

View File

@ -197,7 +197,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
)
}
async def _subscribe_topics(self):
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
topics = {}
@ -219,10 +219,14 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
self._sub_state = await subscription.async_subscribe_topics(
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property
def state(self):
"""Return state of vacuum."""

View File

@ -19,6 +19,7 @@ import voluptuous as vol
from homeassistant.components import mqtt, websocket_api
from homeassistant.components.mqtt.subscription import (
async_prepare_subscribe_topics,
async_subscribe_topics,
async_unsubscribe_topics,
)
@ -62,10 +63,12 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
for topic in topics.values():
if "msg_callback" in topic and "event_loop_safe" in topic:
topic["msg_callback"] = callback(topic["msg_callback"])
return await async_subscribe_topics(hass, sub_state, topics)
sub_state = async_prepare_subscribe_topics(hass, sub_state, topics)
await async_subscribe_topics(hass, sub_state)
return sub_state
async def _unsubscribe_topics(sub_state: dict | None) -> dict:
return await async_unsubscribe_topics(hass, sub_state)
return async_unsubscribe_topics(hass, sub_state)
tasmota_mqtt = TasmotaMQTTClient(_publish, _subscribe_topics, _unsubscribe_topics)

View File

@ -2,6 +2,7 @@
from unittest.mock import ANY
from homeassistant.components.mqtt.subscription import (
async_prepare_subscribe_topics,
async_subscribe_topics,
async_unsubscribe_topics,
)
@ -27,7 +28,7 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog):
calls2.append(args)
sub_state = None
sub_state = await async_subscribe_topics(
sub_state = async_prepare_subscribe_topics(
hass,
sub_state,
{
@ -35,6 +36,7 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog):
"test_topic2": {"topic": "test-topic2", "msg_callback": record_calls2},
},
)
await async_subscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1", "test-payload1")
assert len(calls1) == 1
@ -48,7 +50,7 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog):
assert calls2[0][0].topic == "test-topic2"
assert calls2[0][0].payload == "test-payload2"
await async_unsubscribe_topics(hass, sub_state)
async_unsubscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1", "test-payload")
async_fire_mqtt_message(hass, "test-topic2", "test-payload")
@ -74,7 +76,7 @@ async def test_modify_topics(hass, mqtt_mock, caplog):
calls2.append(args)
sub_state = None
sub_state = await async_subscribe_topics(
sub_state = async_prepare_subscribe_topics(
hass,
sub_state,
{
@ -82,6 +84,7 @@ async def test_modify_topics(hass, mqtt_mock, caplog):
"test_topic2": {"topic": "test-topic2", "msg_callback": record_calls2},
},
)
await async_subscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1", "test-payload")
assert len(calls1) == 1
@ -91,11 +94,12 @@ async def test_modify_topics(hass, mqtt_mock, caplog):
assert len(calls1) == 1
assert len(calls2) == 1
sub_state = await async_subscribe_topics(
sub_state = async_prepare_subscribe_topics(
hass,
sub_state,
{"test_topic1": {"topic": "test-topic1_1", "msg_callback": record_calls1}},
)
await async_subscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1", "test-payload")
async_fire_mqtt_message(hass, "test-topic2", "test-payload")
@ -108,7 +112,7 @@ async def test_modify_topics(hass, mqtt_mock, caplog):
assert calls1[1][0].payload == "test-payload"
assert len(calls2) == 1
await async_unsubscribe_topics(hass, sub_state)
async_unsubscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1_1", "test-payload")
async_fire_mqtt_message(hass, "test-topic2", "test-payload")
@ -126,11 +130,12 @@ async def test_qos_encoding_default(hass, mqtt_mock, caplog):
pass
sub_state = None
sub_state = await async_subscribe_topics(
sub_state = async_prepare_subscribe_topics(
hass,
sub_state,
{"test_topic1": {"topic": "test-topic1", "msg_callback": msg_callback}},
)
await async_subscribe_topics(hass, sub_state)
mqtt_mock.async_subscribe.assert_called_once_with("test-topic1", ANY, 0, "utf-8")
@ -143,7 +148,7 @@ async def test_qos_encoding_custom(hass, mqtt_mock, caplog):
pass
sub_state = None
sub_state = await async_subscribe_topics(
sub_state = async_prepare_subscribe_topics(
hass,
sub_state,
{
@ -155,6 +160,7 @@ async def test_qos_encoding_custom(hass, mqtt_mock, caplog):
}
},
)
await async_subscribe_topics(hass, sub_state)
mqtt_mock.async_subscribe.assert_called_once_with("test-topic1", ANY, 1, "utf-16")
@ -169,27 +175,29 @@ async def test_no_change(hass, mqtt_mock, caplog):
calls.append(args)
sub_state = None
sub_state = await async_subscribe_topics(
sub_state = async_prepare_subscribe_topics(
hass,
sub_state,
{"test_topic1": {"topic": "test-topic1", "msg_callback": record_calls}},
)
await async_subscribe_topics(hass, sub_state)
subscribe_call_count = mqtt_mock.async_subscribe.call_count
async_fire_mqtt_message(hass, "test-topic1", "test-payload")
assert len(calls) == 1
sub_state = await async_subscribe_topics(
sub_state = async_prepare_subscribe_topics(
hass,
sub_state,
{"test_topic1": {"topic": "test-topic1", "msg_callback": record_calls}},
)
await async_subscribe_topics(hass, sub_state)
assert subscribe_call_count == mqtt_mock.async_subscribe.call_count
async_fire_mqtt_message(hass, "test-topic1", "test-payload")
assert len(calls) == 2
await async_unsubscribe_topics(hass, sub_state)
async_unsubscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1", "test-payload")
assert len(calls) == 2