diff --git a/homeassistant/components/fan/mqtt.py b/homeassistant/components/fan/mqtt.py index 1ff04cd913a..505a6e90720 100644 --- a/homeassistant/components/fan/mqtt.py +++ b/homeassistant/components/fan/mqtt.py @@ -5,7 +5,6 @@ For more details about this platform, please refer to the documentation https://home-assistant.io/components/fan.mqtt/ """ import logging -from typing import Optional import voluptuous as vol @@ -18,7 +17,7 @@ from homeassistant.components.mqtt import ( ATTR_DISCOVERY_HASH, CONF_AVAILABILITY_TOPIC, CONF_STATE_TOPIC, CONF_COMMAND_TOPIC, CONF_PAYLOAD_AVAILABLE, CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS, CONF_RETAIN, MqttAvailability, MqttDiscoveryUpdate, - MqttEntityDeviceInfo) + MqttEntityDeviceInfo, subscription) import homeassistant.helpers.config_validation as cv from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.typing import HomeAssistantType, ConfigType @@ -107,40 +106,7 @@ async def _async_setup_entity(hass, config, async_add_entities, discovery_hash=None): """Set up the MQTT fan.""" async_add_entities([MqttFan( - config.get(CONF_NAME), - { - key: config.get(key) for key in ( - CONF_STATE_TOPIC, - CONF_COMMAND_TOPIC, - CONF_SPEED_STATE_TOPIC, - CONF_SPEED_COMMAND_TOPIC, - CONF_OSCILLATION_STATE_TOPIC, - CONF_OSCILLATION_COMMAND_TOPIC, - ) - }, - { - CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE), - ATTR_SPEED: config.get(CONF_SPEED_VALUE_TEMPLATE), - OSCILLATION: config.get(CONF_OSCILLATION_VALUE_TEMPLATE) - }, - config.get(CONF_QOS), - config.get(CONF_RETAIN), - { - STATE_ON: config.get(CONF_PAYLOAD_ON), - STATE_OFF: config.get(CONF_PAYLOAD_OFF), - OSCILLATE_ON_PAYLOAD: config.get(CONF_PAYLOAD_OSCILLATION_ON), - OSCILLATE_OFF_PAYLOAD: config.get(CONF_PAYLOAD_OSCILLATION_OFF), - SPEED_LOW: config.get(CONF_PAYLOAD_LOW_SPEED), - SPEED_MEDIUM: config.get(CONF_PAYLOAD_MEDIUM_SPEED), - SPEED_HIGH: config.get(CONF_PAYLOAD_HIGH_SPEED), - }, - config.get(CONF_SPEED_LIST), - config.get(CONF_OPTIMISTIC), - config.get(CONF_AVAILABILITY_TOPIC), - config.get(CONF_PAYLOAD_AVAILABLE), - config.get(CONF_PAYLOAD_NOT_AVAILABLE), - config.get(CONF_UNIQUE_ID), - config.get(CONF_DEVICE), + config, discovery_hash, )]) @@ -149,43 +115,102 @@ class MqttFan(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, FanEntity): """A MQTT fan component.""" - def __init__(self, name, topic, templates, qos, retain, payload, - speed_list, optimistic, availability_topic, payload_available, - payload_not_available, unique_id: Optional[str], - device_config: Optional[ConfigType], discovery_hash): + def __init__(self, config, discovery_hash): """Initialize the MQTT fan.""" - MqttAvailability.__init__(self, availability_topic, qos, - payload_available, payload_not_available) - MqttDiscoveryUpdate.__init__(self, discovery_hash) - MqttEntityDeviceInfo.__init__(self, device_config) - self._name = name - self._topic = topic - self._qos = qos - self._retain = retain - self._payload = payload - self._templates = templates - self._speed_list = speed_list - self._optimistic = optimistic or topic[CONF_STATE_TOPIC] is None - self._optimistic_oscillation = ( - optimistic or topic[CONF_OSCILLATION_STATE_TOPIC] is None) - self._optimistic_speed = ( - optimistic or topic[CONF_SPEED_STATE_TOPIC] is None) self._state = False self._speed = None self._oscillation = None self._supported_features = 0 - self._supported_features |= (topic[CONF_OSCILLATION_STATE_TOPIC] - is not None and SUPPORT_OSCILLATE) - self._supported_features |= (topic[CONF_SPEED_STATE_TOPIC] - is not None and SUPPORT_SET_SPEED) - self._unique_id = unique_id - self._discovery_hash = discovery_hash + self._sub_state = None + + self._name = None + self._topic = None + self._qos = None + self._retain = None + self._payload = None + self._templates = None + self._speed_list = None + self._optimistic = None + self._optimistic_oscillation = None + self._optimistic_speed = None + self._unique_id = None + + # Load config + self._setup_from_config(config) + + availability_topic = config.get(CONF_AVAILABILITY_TOPIC) + payload_available = config.get(CONF_PAYLOAD_AVAILABLE) + payload_not_available = config.get(CONF_PAYLOAD_NOT_AVAILABLE) + device_config = config.get(CONF_DEVICE) + + MqttAvailability.__init__(self, availability_topic, self._qos, + payload_available, payload_not_available) + MqttDiscoveryUpdate.__init__(self, discovery_hash, + self.discovery_update) + MqttEntityDeviceInfo.__init__(self, device_config) async def async_added_to_hass(self): """Subscribe to MQTT events.""" await MqttAvailability.async_added_to_hass(self) await MqttDiscoveryUpdate.async_added_to_hass(self) + await self._subscribe_topics() + async def discovery_update(self, discovery_payload): + """Handle updated discovery message.""" + config = PLATFORM_SCHEMA(discovery_payload) + self._setup_from_config(config) + await self.availability_discovery_update(config) + await self._subscribe_topics() + self.async_schedule_update_ha_state() + + def _setup_from_config(self, config): + """(Re)Setup the entity.""" + self._name = config.get(CONF_NAME) + self._topic = { + key: config.get(key) for key in ( + CONF_STATE_TOPIC, + CONF_COMMAND_TOPIC, + CONF_SPEED_STATE_TOPIC, + CONF_SPEED_COMMAND_TOPIC, + CONF_OSCILLATION_STATE_TOPIC, + CONF_OSCILLATION_COMMAND_TOPIC, + ) + } + self._templates = { + CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE), + ATTR_SPEED: config.get(CONF_SPEED_VALUE_TEMPLATE), + OSCILLATION: config.get(CONF_OSCILLATION_VALUE_TEMPLATE) + } + self._qos = config.get(CONF_QOS) + self._retain = config.get(CONF_RETAIN) + self._payload = { + STATE_ON: config.get(CONF_PAYLOAD_ON), + STATE_OFF: config.get(CONF_PAYLOAD_OFF), + OSCILLATE_ON_PAYLOAD: config.get(CONF_PAYLOAD_OSCILLATION_ON), + OSCILLATE_OFF_PAYLOAD: config.get(CONF_PAYLOAD_OSCILLATION_OFF), + SPEED_LOW: config.get(CONF_PAYLOAD_LOW_SPEED), + SPEED_MEDIUM: config.get(CONF_PAYLOAD_MEDIUM_SPEED), + SPEED_HIGH: config.get(CONF_PAYLOAD_HIGH_SPEED), + } + self._speed_list = config.get(CONF_SPEED_LIST) + optimistic = config.get(CONF_OPTIMISTIC) + self._optimistic = optimistic or self._topic[CONF_STATE_TOPIC] is None + self._optimistic_oscillation = ( + optimistic or self._topic[CONF_OSCILLATION_STATE_TOPIC] is None) + self._optimistic_speed = ( + optimistic or self._topic[CONF_SPEED_STATE_TOPIC] is None) + + self._supported_features = 0 + self._supported_features |= (self._topic[CONF_OSCILLATION_STATE_TOPIC] + is not None and SUPPORT_OSCILLATE) + self._supported_features |= (self._topic[CONF_SPEED_STATE_TOPIC] + is not None and SUPPORT_SET_SPEED) + + self._unique_id = config.get(CONF_UNIQUE_ID) + + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + topics = {} templates = {} for key, tpl in list(self._templates.items()): if tpl is None: @@ -205,9 +230,10 @@ class MqttFan(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, self.async_schedule_update_ha_state() if self._topic[CONF_STATE_TOPIC] is not None: - await mqtt.async_subscribe( - self.hass, self._topic[CONF_STATE_TOPIC], state_received, - self._qos) + topics[CONF_STATE_TOPIC] = { + 'topic': self._topic[CONF_STATE_TOPIC], + 'msg_callback': state_received, + 'qos': self._qos} @callback def speed_received(topic, payload, qos): @@ -222,9 +248,10 @@ class MqttFan(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, self.async_schedule_update_ha_state() if self._topic[CONF_SPEED_STATE_TOPIC] is not None: - await mqtt.async_subscribe( - self.hass, self._topic[CONF_SPEED_STATE_TOPIC], speed_received, - self._qos) + topics[CONF_SPEED_STATE_TOPIC] = { + 'topic': self._topic[CONF_SPEED_STATE_TOPIC], + 'msg_callback': speed_received, + 'qos': self._qos} self._speed = SPEED_OFF @callback @@ -238,11 +265,21 @@ class MqttFan(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, self.async_schedule_update_ha_state() if self._topic[CONF_OSCILLATION_STATE_TOPIC] is not None: - await mqtt.async_subscribe( - self.hass, self._topic[CONF_OSCILLATION_STATE_TOPIC], - oscillation_received, self._qos) + topics[CONF_OSCILLATION_STATE_TOPIC] = { + 'topic': self._topic[CONF_OSCILLATION_STATE_TOPIC], + 'msg_callback': oscillation_received, + 'qos': self._qos} self._oscillation = False + self._sub_state = await subscription.async_subscribe_topics( + self.hass, self._sub_state, + topics) + + async def async_will_remove_from_hass(self): + """Unsubscribe when removed.""" + await subscription.async_unsubscribe_topics(self.hass, self._sub_state) + await MqttAvailability.async_will_remove_from_hass(self) + @property def should_poll(self): """No polling needed for a MQTT fan.""" diff --git a/tests/components/fan/test_mqtt.py b/tests/components/fan/test_mqtt.py index a3f76058c76..a3e8b0e9f32 100644 --- a/tests/components/fan/test_mqtt.py +++ b/tests/components/fan/test_mqtt.py @@ -130,6 +130,38 @@ async def test_discovery_removal_fan(hass, mqtt_mock, caplog): assert state is None +async def test_discovery_update_fan(hass, mqtt_mock, caplog): + """Test removal of discovered fan.""" + entry = MockConfigEntry(domain='mqtt') + await async_start(hass, 'homeassistant', {}, entry) + data1 = ( + '{ "name": "Beer",' + ' "command_topic": "test_topic" }' + ) + data2 = ( + '{ "name": "Milk",' + ' "command_topic": "test_topic" }' + ) + async_fire_mqtt_message(hass, 'homeassistant/fan/bla/config', + data1) + await hass.async_block_till_done() + + state = hass.states.get('fan.beer') + assert state is not None + assert state.name == 'Beer' + + async_fire_mqtt_message(hass, 'homeassistant/fan/bla/config', + data2) + await hass.async_block_till_done() + await hass.async_block_till_done() + + state = hass.states.get('fan.beer') + assert state is not None + assert state.name == 'Milk' + state = hass.states.get('fan.milk') + assert state is None + + async def test_unique_id(hass): """Test unique_id option only creates one fan per id.""" await async_mock_mqtt_component(hass)