diff --git a/homeassistant/components/mqtt/subscription.py b/homeassistant/components/mqtt/subscription.py index 8be8d311d9b..26101f32f89 100644 --- a/homeassistant/components/mqtt/subscription.py +++ b/homeassistant/components/mqtt/subscription.py @@ -5,45 +5,90 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/components/mqtt/ """ import logging +from typing import Any, Callable, Dict, Optional + +import attr from homeassistant.components import mqtt -from homeassistant.components.mqtt import DEFAULT_QOS -from homeassistant.loader import bind_hass +from homeassistant.components.mqtt import DEFAULT_QOS, MessageCallbackType from homeassistant.helpers.typing import HomeAssistantType +from homeassistant.loader import bind_hass _LOGGER = logging.getLogger(__name__) +@attr.s(slots=True) +class EntitySubscription: + """Class to hold data about an active entity topic subscription.""" + + topic = attr.ib(type=str) + message_callback = attr.ib(type=MessageCallbackType) + unsubscribe_callback = attr.ib(type=Optional[Callable[[], None]]) + qos = attr.ib(type=int, default=0) + encoding = attr.ib(type=str, default='utf-8') + + async def resubscribe_if_necessary(self, hass, other): + """Re-subscribe to the new topic if necessary.""" + if not self._should_resubscribe(other): + return + + if other is not None and other.unsubscribe_callback is not None: + other.unsubscribe_callback() + + if self.topic is None: + # We were asked to remove the subscription or not to create it + return + + self.unsubscribe_callback = await mqtt.async_subscribe( + hass, self.topic, self.message_callback, + self.qos, self.encoding + ) + + def _should_resubscribe(self, other): + """Check if we should re-subscribe to the topic using the old state.""" + if other is None: + return True + + return (self.topic, self.qos, self.encoding) != \ + (other.topic, other.qos, other.encoding) + + @bind_hass -async def async_subscribe_topics(hass: HomeAssistantType, sub_state: dict, - topics: dict): +async def async_subscribe_topics(hass: HomeAssistantType, + new_state: Optional[Dict[str, + EntitySubscription]], + topics: Dict[str, Any]): """(Re)Subscribe to a set of MQTT topics. - State is kept in sub_state. + State is kept in sub_state and a dictionary mapping from the subscription + key to the subscription state. + + Please note that the sub state must not be shared between multiple + sets of topics. Every call to async_subscribe_topics must always + contain _all_ the topics the subscription state should manage. """ - cur_state = sub_state if sub_state is not None else {} - sub_state = {} - for key in topics: - topic = topics[key].get('topic', None) - msg_callback = topics[key].get('msg_callback', None) - qos = topics[key].get('qos', DEFAULT_QOS) - encoding = topics[key].get('encoding', 'utf-8') - topic = (topic, msg_callback, qos, encoding) - (cur_topic, unsub) = cur_state.pop( - key, ((None, None, None, None), None)) + current_subscriptions = new_state if new_state is not None else {} + new_state = {} + for key, value in topics.items(): + # Extract the new requested subscription + requested = EntitySubscription( + topic=value.get('topic', None), + message_callback=value.get('msg_callback', None), + unsubscribe_callback=None, + qos=value.get('qos', DEFAULT_QOS), + encoding=value.get('encoding', 'utf-8'), + ) + # Get the current subscription state + current = current_subscriptions.pop(key, None) + await requested.resubscribe_if_necessary(hass, current) + new_state[key] = requested - if topic != cur_topic and topic[0] is not None: - if unsub is not None: - unsub() - unsub = await mqtt.async_subscribe( - hass, topic[0], topic[1], topic[2], topic[3]) - sub_state[key] = (topic, unsub) + # Go through all remaining subscriptions and unsubscribe them + for remaining in current_subscriptions.values(): + if remaining.unsubscribe_callback is not None: + remaining.unsubscribe_callback() - for key, (topic, unsub) in list(cur_state.items()): - if unsub is not None: - unsub() - - return sub_state + return new_state @bind_hass