Fix race when handling updated MQTT discovery data (#65415)

This commit is contained in:
Erik Montnemery 2022-02-03 02:12:22 +01:00 committed by GitHub
parent 4e7cf19b5f
commit f3a89de71f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 246 additions and 108 deletions

View File

@ -173,7 +173,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
self._config[CONF_COMMAND_TEMPLATE], entity=self self._config[CONF_COMMAND_TEMPLATE], entity=self
).async_render ).async_render
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@ -198,7 +198,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
self._state = payload self._state = payload
self.async_write_ha_state() self.async_write_ha_state()
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -211,6 +211,10 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property @property
def state(self): def state(self):
"""Return the state of the device.""" """Return the state of the device."""

View File

@ -164,7 +164,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@ -241,7 +241,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -254,6 +254,10 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@callback @callback
def _value_is_expired(self, *_): def _value_is_expired(self, *_):
"""Triggered when value is expired.""" """Triggered when value is expired."""

View File

@ -95,6 +95,9 @@ class MqttButton(MqttEntity, ButtonEntity):
config.get(CONF_COMMAND_TEMPLATE), entity=self config.get(CONF_COMMAND_TEMPLATE), entity=self
).async_render ).async_render
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
async def _subscribe_topics(self): async def _subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""

View File

@ -90,7 +90,7 @@ class MqttCamera(MqttEntity, Camera):
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@ -99,7 +99,7 @@ class MqttCamera(MqttEntity, Camera):
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
self._last_image = msg.payload self._last_image = msg.payload
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -112,6 +112,10 @@ class MqttCamera(MqttEntity, Camera):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
async def async_camera_image( async def async_camera_image(
self, width: int | None = None, height: int | None = None self, width: int | None = None, height: int | None = None
) -> bytes | None: ) -> bytes | None:

View File

@ -358,11 +358,6 @@ class MqttClimate(MqttEntity, ClimateEntity):
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
async def async_added_to_hass(self):
"""Handle being added to Home Assistant."""
await super().async_added_to_hass()
await self._subscribe_topics()
def _setup_from_config(self, config): def _setup_from_config(self, config):
"""(Re)Setup the entity.""" """(Re)Setup the entity."""
self._topic = {key: config.get(key) for key in TOPIC_KEYS} self._topic = {key: config.get(key) for key in TOPIC_KEYS}
@ -417,7 +412,7 @@ class MqttClimate(MqttEntity, ClimateEntity):
self._command_templates = command_templates self._command_templates = command_templates
async def _subscribe_topics(self): # noqa: C901 def _prepare_subscribe_topics(self): # noqa: C901
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics = {} topics = {}
qos = self._config[CONF_QOS] qos = self._config[CONF_QOS]
@ -615,10 +610,14 @@ class MqttClimate(MqttEntity, ClimateEntity):
add_subscription(topics, CONF_HOLD_STATE_TOPIC, handle_hold_mode_received) add_subscription(topics, CONF_HOLD_STATE_TOPIC, handle_hold_mode_received)
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property @property
def temperature_unit(self): def temperature_unit(self):
"""Return the unit of measurement.""" """Return the unit of measurement."""

View File

@ -335,7 +335,7 @@ class MqttCover(MqttEntity, CoverEntity):
config_attributes=template_config_attributes, config_attributes=template_config_attributes,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics = {} topics = {}
@ -460,10 +460,14 @@ class MqttCover(MqttEntity, CoverEntity):
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property @property
def assumed_state(self): def assumed_state(self):
"""Return true if we do optimistic updates.""" """Return true if we do optimistic updates."""

View File

@ -77,7 +77,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@ -94,7 +94,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
self.async_write_ha_state() self.async_write_ha_state()
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -106,6 +106,10 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property @property
def latitude(self): def latitude(self):
"""Return latitude if provided in extra_state_attributes or None.""" """Return latitude if provided in extra_state_attributes or None."""

View File

@ -351,7 +351,7 @@ class MqttFan(MqttEntity, FanEntity):
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics = {} topics = {}
@ -479,10 +479,14 @@ class MqttFan(MqttEntity, FanEntity):
} }
self._oscillation = False self._oscillation = False
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property @property
def assumed_state(self): def assumed_state(self):
"""Return true if we do optimistic updates.""" """Return true if we do optimistic updates."""

View File

@ -267,7 +267,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics = {} topics = {}
@ -373,10 +373,14 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
} }
self._mode = None self._mode = None
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property @property
def assumed_state(self): def assumed_state(self):
"""Return true if we do optimistic updates.""" """Return true if we do optimistic updates."""

View File

@ -417,12 +417,10 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
"""Return True if the attribute is optimistically updated.""" """Return True if the attribute is optimistically updated."""
return getattr(self, f"_optimistic_{attribute}") return getattr(self, f"_optimistic_{attribute}")
async def _subscribe_topics(self): # noqa: C901 def _prepare_subscribe_topics(self): # noqa: C901
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics = {} topics = {}
last_state = await self.async_get_last_state()
def add_topic(topic, msg_callback): def add_topic(topic, msg_callback):
"""Add a topic.""" """Add a topic."""
if self._topic[topic] is not None: if self._topic[topic] is not None:
@ -433,14 +431,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
def restore_state(attribute, condition_attribute=None):
"""Restore a state attribute."""
if condition_attribute is None:
condition_attribute = attribute
optimistic = self._is_optimistic(condition_attribute)
if optimistic and last_state and last_state.attributes.get(attribute):
setattr(self, f"_{attribute}", last_state.attributes[attribute])
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def state_received(msg): def state_received(msg):
@ -465,8 +455,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
elif self._optimistic and last_state:
self._state = last_state.state == STATE_ON
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@ -485,7 +473,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
add_topic(CONF_BRIGHTNESS_STATE_TOPIC, brightness_received) add_topic(CONF_BRIGHTNESS_STATE_TOPIC, brightness_received)
restore_state(ATTR_BRIGHTNESS)
def _rgbx_received(msg, template, color_mode, convert_color): def _rgbx_received(msg, template, color_mode, convert_color):
"""Handle new MQTT messages for RGBW and RGBWW.""" """Handle new MQTT messages for RGBW and RGBWW."""
@ -520,8 +507,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
add_topic(CONF_RGB_STATE_TOPIC, rgb_received) add_topic(CONF_RGB_STATE_TOPIC, rgb_received)
restore_state(ATTR_RGB_COLOR)
restore_state(ATTR_HS_COLOR, ATTR_RGB_COLOR)
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@ -539,7 +524,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
add_topic(CONF_RGBW_STATE_TOPIC, rgbw_received) add_topic(CONF_RGBW_STATE_TOPIC, rgbw_received)
restore_state(ATTR_RGBW_COLOR)
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@ -557,7 +541,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
add_topic(CONF_RGBWW_STATE_TOPIC, rgbww_received) add_topic(CONF_RGBWW_STATE_TOPIC, rgbww_received)
restore_state(ATTR_RGBWW_COLOR)
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@ -574,7 +557,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
add_topic(CONF_COLOR_MODE_STATE_TOPIC, color_mode_received) add_topic(CONF_COLOR_MODE_STATE_TOPIC, color_mode_received)
restore_state(ATTR_COLOR_MODE)
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@ -593,7 +575,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
add_topic(CONF_COLOR_TEMP_STATE_TOPIC, color_temp_received) add_topic(CONF_COLOR_TEMP_STATE_TOPIC, color_temp_received)
restore_state(ATTR_COLOR_TEMP)
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@ -610,7 +591,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
add_topic(CONF_EFFECT_STATE_TOPIC, effect_received) add_topic(CONF_EFFECT_STATE_TOPIC, effect_received)
restore_state(ATTR_EFFECT)
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@ -630,7 +610,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
_LOGGER.debug("Failed to parse hs state update: '%s'", payload) _LOGGER.debug("Failed to parse hs state update: '%s'", payload)
add_topic(CONF_HS_STATE_TOPIC, hs_received) add_topic(CONF_HS_STATE_TOPIC, hs_received)
restore_state(ATTR_HS_COLOR)
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@ -649,7 +628,6 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
add_topic(CONF_WHITE_VALUE_STATE_TOPIC, white_value_received) add_topic(CONF_WHITE_VALUE_STATE_TOPIC, white_value_received)
restore_state(ATTR_WHITE_VALUE)
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@ -670,13 +648,39 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
add_topic(CONF_XY_STATE_TOPIC, xy_received) add_topic(CONF_XY_STATE_TOPIC, xy_received)
restore_state(ATTR_XY_COLOR)
restore_state(ATTR_HS_COLOR, ATTR_XY_COLOR)
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
def restore_state(attribute, condition_attribute=None):
"""Restore a state attribute."""
if condition_attribute is None:
condition_attribute = attribute
optimistic = self._is_optimistic(condition_attribute)
if optimistic and last_state and last_state.attributes.get(attribute):
setattr(self, f"_{attribute}", last_state.attributes[attribute])
if self._topic[CONF_STATE_TOPIC] is None and self._optimistic and last_state:
self._state = last_state.state == STATE_ON
restore_state(ATTR_BRIGHTNESS)
restore_state(ATTR_RGB_COLOR)
restore_state(ATTR_HS_COLOR, ATTR_RGB_COLOR)
restore_state(ATTR_RGBW_COLOR)
restore_state(ATTR_RGBWW_COLOR)
restore_state(ATTR_COLOR_MODE)
restore_state(ATTR_COLOR_TEMP)
restore_state(ATTR_EFFECT)
restore_state(ATTR_HS_COLOR)
restore_state(ATTR_WHITE_VALUE)
restore_state(ATTR_XY_COLOR)
restore_state(ATTR_HS_COLOR, ATTR_XY_COLOR)
@property @property
def brightness(self): def brightness(self):
"""Return the brightness of this light between 0..255.""" """Return the brightness of this light between 0..255."""

View File

@ -304,9 +304,8 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
except (KeyError, ValueError): except (KeyError, ValueError):
_LOGGER.warning("Invalid or incomplete color value received") _LOGGER.warning("Invalid or incomplete color value received")
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
last_state = await self.async_get_last_state()
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
@ -370,7 +369,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
if self._topic[CONF_STATE_TOPIC] is not None: if self._topic[CONF_STATE_TOPIC] is not None:
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -383,6 +382,11 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
if self._optimistic and last_state: if self._optimistic and last_state:
self._state = last_state.state == STATE_ON self._state = last_state.state == STATE_ON
last_attributes = last_state.attributes last_attributes = last_state.attributes

View File

@ -156,14 +156,12 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
or self._templates[CONF_STATE_TEMPLATE] is None or self._templates[CONF_STATE_TEMPLATE] is None
) )
async def _subscribe_topics(self): # noqa: C901 def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
for tpl in self._templates.values(): for tpl in self._templates.values():
if tpl is not None: if tpl is not None:
tpl = MqttValueTemplate(tpl, entity=self) tpl = MqttValueTemplate(tpl, entity=self)
last_state = await self.async_get_last_state()
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def state_received(msg): def state_received(msg):
@ -246,7 +244,7 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
if self._topics[CONF_STATE_TOPIC] is not None: if self._topics[CONF_STATE_TOPIC] is not None:
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -259,6 +257,11 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
if self._optimistic and last_state: if self._optimistic and last_state:
self._state = last_state.state == STATE_ON self._state = last_state.state == STATE_ON
if last_state.attributes.get(ATTR_BRIGHTNESS): if last_state.attributes.get(ATTR_BRIGHTNESS):

View File

@ -123,7 +123,7 @@ class MqttLock(MqttEntity, LockEntity):
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@ -142,7 +142,7 @@ class MqttLock(MqttEntity, LockEntity):
# Force into optimistic mode. # Force into optimistic mode.
self._optimistic = True self._optimistic = True
else: else:
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -155,6 +155,10 @@ class MqttLock(MqttEntity, LockEntity):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property @property
def is_locked(self): def is_locked(self):
"""Return true if lock is locked.""" """Return true if lock is locked."""

View File

@ -66,7 +66,11 @@ from .discovery import (
set_discovery_hash, set_discovery_hash,
) )
from .models import ReceiveMessage from .models import ReceiveMessage
from .subscription import async_subscribe_topics, async_unsubscribe_topics from .subscription import (
async_prepare_subscribe_topics,
async_subscribe_topics,
async_unsubscribe_topics,
)
from .util import valid_subscribe_topic from .util import valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -249,14 +253,19 @@ class MqttAttributes(Entity):
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Subscribe MQTT events.""" """Subscribe MQTT events."""
await super().async_added_to_hass() await super().async_added_to_hass()
self._attributes_prepare_subscribe_topics()
await self._attributes_subscribe_topics() await self._attributes_subscribe_topics()
def attributes_prepare_discovery_update(self, config: dict):
"""Handle updated discovery message."""
self._attributes_config = config
self._attributes_prepare_subscribe_topics()
async def attributes_discovery_update(self, config: dict): async def attributes_discovery_update(self, config: dict):
"""Handle updated discovery message.""" """Handle updated discovery message."""
self._attributes_config = config
await self._attributes_subscribe_topics() await self._attributes_subscribe_topics()
async def _attributes_subscribe_topics(self): def _attributes_prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
attr_tpl = MqttValueTemplate( attr_tpl = MqttValueTemplate(
self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE), entity=self self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE), entity=self
@ -284,7 +293,7 @@ class MqttAttributes(Entity):
_LOGGER.warning("Erroneous JSON: %s", payload) _LOGGER.warning("Erroneous JSON: %s", payload)
self._attributes = None self._attributes = None
self._attributes_sub_state = await async_subscribe_topics( self._attributes_sub_state = async_prepare_subscribe_topics(
self.hass, self.hass,
self._attributes_sub_state, self._attributes_sub_state,
{ {
@ -297,9 +306,13 @@ class MqttAttributes(Entity):
}, },
) )
async def _attributes_subscribe_topics(self):
"""(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._attributes_sub_state)
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Unsubscribe when removed.""" """Unsubscribe when removed."""
self._attributes_sub_state = await async_unsubscribe_topics( self._attributes_sub_state = async_unsubscribe_topics(
self.hass, self._attributes_sub_state self.hass, self._attributes_sub_state
) )
@ -322,6 +335,7 @@ class MqttAvailability(Entity):
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Subscribe MQTT events.""" """Subscribe MQTT events."""
await super().async_added_to_hass() await super().async_added_to_hass()
self._availability_prepare_subscribe_topics()
await self._availability_subscribe_topics() await self._availability_subscribe_topics()
self.async_on_remove( self.async_on_remove(
async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect) async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect)
@ -332,9 +346,13 @@ class MqttAvailability(Entity):
) )
) )
async def availability_discovery_update(self, config: dict): def availability_prepare_discovery_update(self, config: dict):
"""Handle updated discovery message.""" """Handle updated discovery message."""
self._availability_setup_from_config(config) self._availability_setup_from_config(config)
self._availability_prepare_subscribe_topics()
async def availability_discovery_update(self, config: dict):
"""Handle updated discovery message."""
await self._availability_subscribe_topics() await self._availability_subscribe_topics()
def _availability_setup_from_config(self, config): def _availability_setup_from_config(self, config):
@ -366,7 +384,7 @@ class MqttAvailability(Entity):
self._avail_config = config self._avail_config = config
async def _availability_subscribe_topics(self): def _availability_prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@ -398,12 +416,16 @@ class MqttAvailability(Entity):
for topic in self._avail_topics for topic in self._avail_topics
} }
self._availability_sub_state = await async_subscribe_topics( self._availability_sub_state = async_prepare_subscribe_topics(
self.hass, self.hass,
self._availability_sub_state, self._availability_sub_state,
topics, topics,
) )
async def _availability_subscribe_topics(self):
"""(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._availability_sub_state)
@callback @callback
def async_mqtt_connect(self): def async_mqtt_connect(self):
"""Update state on connection/disconnection to MQTT broker.""" """Update state on connection/disconnection to MQTT broker."""
@ -412,7 +434,7 @@ class MqttAvailability(Entity):
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Unsubscribe when removed.""" """Unsubscribe when removed."""
self._availability_sub_state = await async_unsubscribe_topics( self._availability_sub_state = async_unsubscribe_topics(
self.hass, self._availability_sub_state self.hass, self._availability_sub_state
) )
@ -601,7 +623,7 @@ class MqttEntityDeviceInfo(Entity):
self._device_config = device_config self._device_config = device_config
self._config_entry = config_entry self._config_entry = config_entry
async def device_info_discovery_update(self, config: dict): def device_info_discovery_update(self, config: dict):
"""Handle updated discovery message.""" """Handle updated discovery message."""
self._device_config = config.get(CONF_DEVICE) self._device_config = config.get(CONF_DEVICE)
device_registry = dr.async_get(self.hass) device_registry = dr.async_get(self.hass)
@ -657,6 +679,7 @@ class MqttEntity(
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Subscribe mqtt events.""" """Subscribe mqtt events."""
await super().async_added_to_hass() await super().async_added_to_hass()
self._prepare_subscribe_topics()
await self._subscribe_topics() await self._subscribe_topics()
async def discovery_update(self, discovery_payload): async def discovery_update(self, discovery_payload):
@ -664,15 +687,22 @@ class MqttEntity(
config = self.config_schema()(discovery_payload) config = self.config_schema()(discovery_payload)
self._config = config self._config = config
self._setup_from_config(self._config) self._setup_from_config(self._config)
# Prepare MQTT subscriptions
self.attributes_prepare_discovery_update(config)
self.availability_prepare_discovery_update(config)
self.device_info_discovery_update(config)
self._prepare_subscribe_topics()
# Finalize MQTT subscriptions
await self.attributes_discovery_update(config) await self.attributes_discovery_update(config)
await self.availability_discovery_update(config) await self.availability_discovery_update(config)
await self.device_info_discovery_update(config)
await self._subscribe_topics() await self._subscribe_topics()
self.async_write_ha_state() self.async_write_ha_state()
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Unsubscribe when removed.""" """Unsubscribe when removed."""
self._sub_state = await subscription.async_unsubscribe_topics( self._sub_state = subscription.async_unsubscribe_topics(
self.hass, self._sub_state self.hass, self._sub_state
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
@ -687,6 +717,10 @@ class MqttEntity(
def _setup_from_config(self, config): def _setup_from_config(self, config):
"""(Re)Setup the entity.""" """(Re)Setup the entity."""
@abstractmethod
def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics."""
@abstractmethod @abstractmethod
async def _subscribe_topics(self): async def _subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""

View File

@ -162,7 +162,7 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity):
).async_render_with_possible_json_value, ).async_render_with_possible_json_value,
} }
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@ -200,7 +200,7 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity):
# Force into optimistic mode. # Force into optimistic mode.
self._optimistic = True self._optimistic = True
else: else:
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -213,6 +213,10 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
if self._optimistic and (last_state := await self.async_get_last_state()): if self._optimistic and (last_state := await self.async_get_last_state()):
self._current_number = last_state.state self._current_number = last_state.state

View File

@ -128,7 +128,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
).async_render_with_possible_json_value, ).async_render_with_possible_json_value,
} }
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@ -156,7 +156,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
# Force into optimistic mode. # Force into optimistic mode.
self._optimistic = True self._optimistic = True
else: else:
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -169,6 +169,10 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
if self._optimistic and (last_state := await self.async_get_last_state()): if self._optimistic and (last_state := await self.async_get_last_state()):
self._attr_current_option = last_state.state self._attr_current_option = last_state.state

View File

@ -213,7 +213,7 @@ class MqttSensor(MqttEntity, SensorEntity, RestoreEntity):
self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self self._config.get(CONF_LAST_RESET_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics = {} topics = {}
@ -304,10 +304,14 @@ class MqttSensor(MqttEntity, SensorEntity, RestoreEntity):
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@callback @callback
def _value_is_expired(self, *_): def _value_is_expired(self, *_):
"""Triggered when value is expired.""" """Triggered when value is expired."""

View File

@ -214,7 +214,7 @@ class MqttSiren(MqttEntity, SirenEntity):
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@ -274,7 +274,7 @@ class MqttSiren(MqttEntity, SirenEntity):
# Force into optimistic mode. # Force into optimistic mode.
self._optimistic = True self._optimistic = True
else: else:
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -287,6 +287,10 @@ class MqttSiren(MqttEntity, SirenEntity):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property @property
def assumed_state(self): def assumed_state(self):
"""Return true if we do optimistic updates.""" """Return true if we do optimistic updates."""

View File

@ -1,13 +1,12 @@
"""Helper to handle a set of topics to subscribe to.""" """Helper to handle a set of topics to subscribe to."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable, Coroutine
from typing import Any from typing import Any
import attr import attr
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.loader import bind_hass
from . import debug_info from . import debug_info
from .. import mqtt from .. import mqtt
@ -22,11 +21,12 @@ class EntitySubscription:
hass: HomeAssistant = attr.ib() hass: HomeAssistant = attr.ib()
topic: str = attr.ib() topic: str = attr.ib()
message_callback: MessageCallbackType = attr.ib() message_callback: MessageCallbackType = attr.ib()
subscribe_task: Coroutine | None = attr.ib()
unsubscribe_callback: Callable[[], None] | None = attr.ib() unsubscribe_callback: Callable[[], None] | None = attr.ib()
qos: int = attr.ib(default=0) qos: int = attr.ib(default=0)
encoding: str = attr.ib(default="utf-8") encoding: str = attr.ib(default="utf-8")
async def resubscribe_if_necessary(self, hass, other): def resubscribe_if_necessary(self, hass, other):
"""Re-subscribe to the new topic if necessary.""" """Re-subscribe to the new topic if necessary."""
if not self._should_resubscribe(other): if not self._should_resubscribe(other):
self.unsubscribe_callback = other.unsubscribe_callback self.unsubscribe_callback = other.unsubscribe_callback
@ -46,33 +46,41 @@ class EntitySubscription:
# Prepare debug data # Prepare debug data
debug_info.add_subscription(self.hass, self.message_callback, self.topic) debug_info.add_subscription(self.hass, self.message_callback, self.topic)
self.unsubscribe_callback = await mqtt.async_subscribe( self.subscribe_task = mqtt.async_subscribe(
hass, self.topic, self.message_callback, self.qos, self.encoding hass, self.topic, self.message_callback, self.qos, self.encoding
) )
async def subscribe(self):
"""Subscribe to a topic."""
if not self.subscribe_task:
return
self.unsubscribe_callback = await self.subscribe_task
def _should_resubscribe(self, other): def _should_resubscribe(self, other):
"""Check if we should re-subscribe to the topic using the old state.""" """Check if we should re-subscribe to the topic using the old state."""
if other is None: if other is None:
return True return True
return (self.topic, self.qos, self.encoding) != ( return (self.topic, self.qos, self.encoding,) != (
other.topic, other.topic,
other.qos, other.qos,
other.encoding, other.encoding,
) )
@bind_hass def async_prepare_subscribe_topics(
async def async_subscribe_topics(
hass: HomeAssistant, hass: HomeAssistant,
new_state: dict[str, EntitySubscription] | None, new_state: dict[str, EntitySubscription] | None,
topics: dict[str, Any], topics: dict[str, Any],
) -> dict[str, EntitySubscription]: ) -> dict[str, EntitySubscription]:
"""(Re)Subscribe to a set of MQTT topics. """Prepare (re)subscribe to a set of MQTT topics.
State is kept in sub_state and a dictionary mapping from the subscription State is kept in sub_state and a dictionary mapping from the subscription
key to the subscription state. key to the subscription state.
After this function has been called, async_subscribe_topics must be called to
finalize any new subscriptions.
Please note that the sub state must not be shared between multiple Please note that the sub state must not be shared between multiple
sets of topics. Every call to async_subscribe_topics must always sets of topics. Every call to async_subscribe_topics must always
contain _all_ the topics the subscription state should manage. contain _all_ the topics the subscription state should manage.
@ -88,10 +96,11 @@ async def async_subscribe_topics(
qos=value.get("qos", DEFAULT_QOS), qos=value.get("qos", DEFAULT_QOS),
encoding=value.get("encoding", "utf-8"), encoding=value.get("encoding", "utf-8"),
hass=hass, hass=hass,
subscribe_task=None,
) )
# Get the current subscription state # Get the current subscription state
current = current_subscriptions.pop(key, None) current = current_subscriptions.pop(key, None)
await requested.resubscribe_if_necessary(hass, current) requested.resubscribe_if_necessary(hass, current)
new_state[key] = requested new_state[key] = requested
# Go through all remaining subscriptions and unsubscribe them # Go through all remaining subscriptions and unsubscribe them
@ -106,9 +115,19 @@ async def async_subscribe_topics(
return new_state return new_state
@bind_hass async def async_subscribe_topics(
async def async_unsubscribe_topics( hass: HomeAssistant,
sub_state: dict[str, EntitySubscription] | None,
) -> None:
"""(Re)Subscribe to a set of MQTT topics."""
if sub_state is None:
return
for sub in sub_state.values():
await sub.subscribe()
def async_unsubscribe_topics(
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]: ) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics.""" """Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
return await async_subscribe_topics(hass, sub_state, {}) return async_prepare_subscribe_topics(hass, sub_state, {})

View File

@ -132,7 +132,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@ -151,7 +151,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
# Force into optimistic mode. # Force into optimistic mode.
self._optimistic = True self._optimistic = True
else: else:
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -164,6 +164,10 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
if self._optimistic and (last_state := await self.async_get_last_state()): if self._optimistic and (last_state := await self.async_get_last_state()):
self._state = last_state.state == STATE_ON self._state = last_state.state == STATE_ON

View File

@ -175,7 +175,7 @@ class MQTTTagScanner:
await self.hass.components.tag.async_scan_tag(tag_id, self.device_id) await self.hass.components.tag.async_scan_tag(tag_id, self.device_id)
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -186,6 +186,7 @@ class MQTTTagScanner:
} }
}, },
) )
await subscription.async_subscribe_topics(self.hass, self._sub_state)
async def device_removed(self, event): async def device_removed(self, event):
"""Handle the removal of a device.""" """Handle the removal of a device."""
@ -207,7 +208,7 @@ class MQTTTagScanner:
self._remove_discovery() self._remove_discovery()
mqtt.publish(self.hass, discovery_topic, "", retain=True) mqtt.publish(self.hass, discovery_topic, "", retain=True)
self._sub_state = await subscription.async_unsubscribe_topics( self._sub_state = subscription.async_unsubscribe_topics(
self.hass, self._sub_state self.hass, self._sub_state
) )
if self.device_id: if self.device_id:

View File

@ -240,7 +240,7 @@ class MqttVacuum(MqttEntity, VacuumEntity):
) )
} }
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
for tpl in self._templates.values(): for tpl in self._templates.values():
if tpl is not None: if tpl is not None:
@ -325,7 +325,7 @@ class MqttVacuum(MqttEntity, VacuumEntity):
self.async_write_ha_state() self.async_write_ha_state()
topics_list = {topic for topic in self._state_topics.values() if topic} topics_list = {topic for topic in self._state_topics.values() if topic}
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
@ -339,6 +339,10 @@ class MqttVacuum(MqttEntity, VacuumEntity):
}, },
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property @property
def is_on(self): def is_on(self):
"""Return true if vacuum is on.""" """Return true if vacuum is on."""

View File

@ -197,7 +197,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
) )
} }
async def _subscribe_topics(self): def _prepare_subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics = {} topics = {}
@ -219,10 +219,14 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
self._sub_state = await subscription.async_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
) )
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property @property
def state(self): def state(self):
"""Return state of vacuum.""" """Return state of vacuum."""

View File

@ -19,6 +19,7 @@ import voluptuous as vol
from homeassistant.components import mqtt, websocket_api from homeassistant.components import mqtt, websocket_api
from homeassistant.components.mqtt.subscription import ( from homeassistant.components.mqtt.subscription import (
async_prepare_subscribe_topics,
async_subscribe_topics, async_subscribe_topics,
async_unsubscribe_topics, async_unsubscribe_topics,
) )
@ -62,10 +63,12 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
for topic in topics.values(): for topic in topics.values():
if "msg_callback" in topic and "event_loop_safe" in topic: if "msg_callback" in topic and "event_loop_safe" in topic:
topic["msg_callback"] = callback(topic["msg_callback"]) topic["msg_callback"] = callback(topic["msg_callback"])
return await async_subscribe_topics(hass, sub_state, topics) sub_state = async_prepare_subscribe_topics(hass, sub_state, topics)
await async_subscribe_topics(hass, sub_state)
return sub_state
async def _unsubscribe_topics(sub_state: dict | None) -> dict: async def _unsubscribe_topics(sub_state: dict | None) -> dict:
return await async_unsubscribe_topics(hass, sub_state) return async_unsubscribe_topics(hass, sub_state)
tasmota_mqtt = TasmotaMQTTClient(_publish, _subscribe_topics, _unsubscribe_topics) tasmota_mqtt = TasmotaMQTTClient(_publish, _subscribe_topics, _unsubscribe_topics)

View File

@ -2,6 +2,7 @@
from unittest.mock import ANY from unittest.mock import ANY
from homeassistant.components.mqtt.subscription import ( from homeassistant.components.mqtt.subscription import (
async_prepare_subscribe_topics,
async_subscribe_topics, async_subscribe_topics,
async_unsubscribe_topics, async_unsubscribe_topics,
) )
@ -27,7 +28,7 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog):
calls2.append(args) calls2.append(args)
sub_state = None sub_state = None
sub_state = await async_subscribe_topics( sub_state = async_prepare_subscribe_topics(
hass, hass,
sub_state, sub_state,
{ {
@ -35,6 +36,7 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog):
"test_topic2": {"topic": "test-topic2", "msg_callback": record_calls2}, "test_topic2": {"topic": "test-topic2", "msg_callback": record_calls2},
}, },
) )
await async_subscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1", "test-payload1") async_fire_mqtt_message(hass, "test-topic1", "test-payload1")
assert len(calls1) == 1 assert len(calls1) == 1
@ -48,7 +50,7 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog):
assert calls2[0][0].topic == "test-topic2" assert calls2[0][0].topic == "test-topic2"
assert calls2[0][0].payload == "test-payload2" assert calls2[0][0].payload == "test-payload2"
await async_unsubscribe_topics(hass, sub_state) async_unsubscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1", "test-payload") async_fire_mqtt_message(hass, "test-topic1", "test-payload")
async_fire_mqtt_message(hass, "test-topic2", "test-payload") async_fire_mqtt_message(hass, "test-topic2", "test-payload")
@ -74,7 +76,7 @@ async def test_modify_topics(hass, mqtt_mock, caplog):
calls2.append(args) calls2.append(args)
sub_state = None sub_state = None
sub_state = await async_subscribe_topics( sub_state = async_prepare_subscribe_topics(
hass, hass,
sub_state, sub_state,
{ {
@ -82,6 +84,7 @@ async def test_modify_topics(hass, mqtt_mock, caplog):
"test_topic2": {"topic": "test-topic2", "msg_callback": record_calls2}, "test_topic2": {"topic": "test-topic2", "msg_callback": record_calls2},
}, },
) )
await async_subscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1", "test-payload") async_fire_mqtt_message(hass, "test-topic1", "test-payload")
assert len(calls1) == 1 assert len(calls1) == 1
@ -91,11 +94,12 @@ async def test_modify_topics(hass, mqtt_mock, caplog):
assert len(calls1) == 1 assert len(calls1) == 1
assert len(calls2) == 1 assert len(calls2) == 1
sub_state = await async_subscribe_topics( sub_state = async_prepare_subscribe_topics(
hass, hass,
sub_state, sub_state,
{"test_topic1": {"topic": "test-topic1_1", "msg_callback": record_calls1}}, {"test_topic1": {"topic": "test-topic1_1", "msg_callback": record_calls1}},
) )
await async_subscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1", "test-payload") async_fire_mqtt_message(hass, "test-topic1", "test-payload")
async_fire_mqtt_message(hass, "test-topic2", "test-payload") async_fire_mqtt_message(hass, "test-topic2", "test-payload")
@ -108,7 +112,7 @@ async def test_modify_topics(hass, mqtt_mock, caplog):
assert calls1[1][0].payload == "test-payload" assert calls1[1][0].payload == "test-payload"
assert len(calls2) == 1 assert len(calls2) == 1
await async_unsubscribe_topics(hass, sub_state) async_unsubscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1_1", "test-payload") async_fire_mqtt_message(hass, "test-topic1_1", "test-payload")
async_fire_mqtt_message(hass, "test-topic2", "test-payload") async_fire_mqtt_message(hass, "test-topic2", "test-payload")
@ -126,11 +130,12 @@ async def test_qos_encoding_default(hass, mqtt_mock, caplog):
pass pass
sub_state = None sub_state = None
sub_state = await async_subscribe_topics( sub_state = async_prepare_subscribe_topics(
hass, hass,
sub_state, sub_state,
{"test_topic1": {"topic": "test-topic1", "msg_callback": msg_callback}}, {"test_topic1": {"topic": "test-topic1", "msg_callback": msg_callback}},
) )
await async_subscribe_topics(hass, sub_state)
mqtt_mock.async_subscribe.assert_called_once_with("test-topic1", ANY, 0, "utf-8") mqtt_mock.async_subscribe.assert_called_once_with("test-topic1", ANY, 0, "utf-8")
@ -143,7 +148,7 @@ async def test_qos_encoding_custom(hass, mqtt_mock, caplog):
pass pass
sub_state = None sub_state = None
sub_state = await async_subscribe_topics( sub_state = async_prepare_subscribe_topics(
hass, hass,
sub_state, sub_state,
{ {
@ -155,6 +160,7 @@ async def test_qos_encoding_custom(hass, mqtt_mock, caplog):
} }
}, },
) )
await async_subscribe_topics(hass, sub_state)
mqtt_mock.async_subscribe.assert_called_once_with("test-topic1", ANY, 1, "utf-16") mqtt_mock.async_subscribe.assert_called_once_with("test-topic1", ANY, 1, "utf-16")
@ -169,27 +175,29 @@ async def test_no_change(hass, mqtt_mock, caplog):
calls.append(args) calls.append(args)
sub_state = None sub_state = None
sub_state = await async_subscribe_topics( sub_state = async_prepare_subscribe_topics(
hass, hass,
sub_state, sub_state,
{"test_topic1": {"topic": "test-topic1", "msg_callback": record_calls}}, {"test_topic1": {"topic": "test-topic1", "msg_callback": record_calls}},
) )
await async_subscribe_topics(hass, sub_state)
subscribe_call_count = mqtt_mock.async_subscribe.call_count subscribe_call_count = mqtt_mock.async_subscribe.call_count
async_fire_mqtt_message(hass, "test-topic1", "test-payload") async_fire_mqtt_message(hass, "test-topic1", "test-payload")
assert len(calls) == 1 assert len(calls) == 1
sub_state = await async_subscribe_topics( sub_state = async_prepare_subscribe_topics(
hass, hass,
sub_state, sub_state,
{"test_topic1": {"topic": "test-topic1", "msg_callback": record_calls}}, {"test_topic1": {"topic": "test-topic1", "msg_callback": record_calls}},
) )
await async_subscribe_topics(hass, sub_state)
assert subscribe_call_count == mqtt_mock.async_subscribe.call_count assert subscribe_call_count == mqtt_mock.async_subscribe.call_count
async_fire_mqtt_message(hass, "test-topic1", "test-payload") async_fire_mqtt_message(hass, "test-topic1", "test-payload")
assert len(calls) == 2 assert len(calls) == 2
await async_unsubscribe_topics(hass, sub_state) async_unsubscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, "test-topic1", "test-payload") async_fire_mqtt_message(hass, "test-topic1", "test-payload")
assert len(calls) == 2 assert len(calls) == 2