diff --git a/homeassistant/components/lock/mqtt.py b/homeassistant/components/lock/mqtt.py index 5574c7e4e59..53bfe6ff7a1 100644 --- a/homeassistant/components/lock/mqtt.py +++ b/homeassistant/components/lock/mqtt.py @@ -14,7 +14,7 @@ from homeassistant.components.mqtt import ( ATTR_DISCOVERY_HASH, CONF_AVAILABILITY_TOPIC, CONF_COMMAND_TOPIC, CONF_PAYLOAD_AVAILABLE, CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, MqttAvailability, MqttDiscoveryUpdate, - MqttEntityDeviceInfo) + MqttEntityDeviceInfo, subscription) from homeassistant.const import ( CONF_DEVICE, CONF_NAME, CONF_OPTIMISTIC, CONF_VALUE_TEMPLATE) from homeassistant.components import mqtt, lock @@ -51,7 +51,7 @@ PLATFORM_SCHEMA = mqtt.MQTT_RW_PLATFORM_SCHEMA.extend({ async def async_setup_platform(hass: HomeAssistantType, config: ConfigType, async_add_entities, discovery_info=None): """Set up MQTT lock panel through configuration.yaml.""" - await _async_setup_entity(hass, config, async_add_entities) + await _async_setup_entity(config, async_add_entities) async def async_setup_entry(hass, config_entry, async_add_entities): @@ -61,7 +61,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): try: discovery_hash = discovery_payload[ATTR_DISCOVERY_HASH] config = PLATFORM_SCHEMA(discovery_payload) - await _async_setup_entity(hass, config, async_add_entities, + await _async_setup_entity(config, async_add_entities, discovery_hash) except Exception: if discovery_hash: @@ -73,81 +73,83 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async_discover) -async def _async_setup_entity(hass, config, async_add_entities, +async def _async_setup_entity(config, async_add_entities, discovery_hash=None): """Set up the MQTT Lock platform.""" - value_template = config.get(CONF_VALUE_TEMPLATE) - if value_template is not None: - value_template.hass = hass - - async_add_entities([MqttLock( - config.get(CONF_NAME), - config.get(CONF_STATE_TOPIC), - config.get(CONF_COMMAND_TOPIC), - config.get(CONF_QOS), - config.get(CONF_RETAIN), - config.get(CONF_PAYLOAD_LOCK), - config.get(CONF_PAYLOAD_UNLOCK), - config.get(CONF_OPTIMISTIC), - value_template, - config.get(CONF_AVAILABILITY_TOPIC), - config.get(CONF_PAYLOAD_AVAILABLE), - config.get(CONF_PAYLOAD_NOT_AVAILABLE), - config.get(CONF_UNIQUE_ID), - config.get(CONF_DEVICE), - discovery_hash, - )]) + async_add_entities([MqttLock(config, discovery_hash)]) class MqttLock(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, LockDevice): """Representation of a lock that can be toggled using MQTT.""" - def __init__(self, name, state_topic, command_topic, qos, retain, - payload_lock, payload_unlock, optimistic, value_template, - availability_topic, payload_available, payload_not_available, - unique_id, device_config, discovery_hash): + def __init__(self, config, discovery_hash): """Initialize the lock.""" + self._config = config + self._unique_id = config.get(CONF_UNIQUE_ID) + self._state = False + self._sub_state = None + self._optimistic = False + + availability_topic = config.get(CONF_AVAILABILITY_TOPIC) + payload_available = config[CONF_PAYLOAD_AVAILABLE] + payload_not_available = config[CONF_PAYLOAD_NOT_AVAILABLE] + qos = config[CONF_QOS] + device_config = config.get(CONF_DEVICE) + MqttAvailability.__init__(self, availability_topic, qos, payload_available, payload_not_available) - MqttDiscoveryUpdate.__init__(self, discovery_hash) + MqttDiscoveryUpdate.__init__(self, discovery_hash, + self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config) - self._state = False - self._name = name - self._state_topic = state_topic - self._command_topic = command_topic - self._qos = qos - self._retain = retain - self._payload_lock = payload_lock - self._payload_unlock = payload_unlock - self._optimistic = optimistic - self._template = value_template - self._discovery_hash = discovery_hash - self._unique_id = unique_id async def async_added_to_hass(self): """Subscribe to MQTT events.""" await super().async_added_to_hass() + await self._subscribe_topics() + + async def discovery_update(self, discovery_payload): + """Handle updated discovery message.""" + config = PLATFORM_SCHEMA(discovery_payload) + self._config = config + await self.availability_discovery_update(config) + await self._subscribe_topics() + self.async_schedule_update_ha_state() + + async def _subscribe_topics(self): + """(Re)Subscribe to topics.""" + value_template = self._config.get(CONF_VALUE_TEMPLATE) + if value_template is not None: + value_template.hass = self.hass @callback def message_received(topic, payload, qos): """Handle new MQTT messages.""" - if self._template is not None: - payload = self._template.async_render_with_possible_json_value( + if value_template is not None: + payload = value_template.async_render_with_possible_json_value( payload) - if payload == self._payload_lock: + if payload == self._config[CONF_PAYLOAD_LOCK]: self._state = True - elif payload == self._payload_unlock: + elif payload == self._config[CONF_PAYLOAD_UNLOCK]: self._state = False self.async_schedule_update_ha_state() - if self._state_topic is None: + if self._config.get(CONF_STATE_TOPIC) is None: # Force into optimistic mode. self._optimistic = True else: - await mqtt.async_subscribe( - self.hass, self._state_topic, message_received, self._qos) + self._sub_state = await subscription.async_subscribe_topics( + self.hass, self._sub_state, + {'state_topic': {'topic': self._config.get(CONF_STATE_TOPIC), + 'msg_callback': message_received, + 'qos': self._config[CONF_QOS]}}) + + async def async_will_remove_from_hass(self): + """Unsubscribe when removed.""" + self._sub_state = await subscription.async_unsubscribe_topics( + self.hass, self._sub_state) + await MqttAvailability.async_will_remove_from_hass(self) @property def should_poll(self): @@ -157,7 +159,7 @@ class MqttLock(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, @property def name(self): """Return the name of the lock.""" - return self._name + return self._config[CONF_NAME] @property def unique_id(self): @@ -180,8 +182,10 @@ class MqttLock(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, This method is a coroutine. """ mqtt.async_publish( - self.hass, self._command_topic, self._payload_lock, self._qos, - self._retain) + self.hass, self._config[CONF_COMMAND_TOPIC], + self._config[CONF_PAYLOAD_LOCK], + self._config[CONF_QOS], + self._config[CONF_RETAIN]) if self._optimistic: # Optimistically assume that switch has changed state. self._state = True @@ -193,8 +197,10 @@ class MqttLock(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, This method is a coroutine. """ mqtt.async_publish( - self.hass, self._command_topic, self._payload_unlock, self._qos, - self._retain) + self.hass, self._config[CONF_COMMAND_TOPIC], + self._config[CONF_PAYLOAD_UNLOCK], + self._config[CONF_QOS], + self._config[CONF_RETAIN]) if self._optimistic: # Optimistically assume that switch has changed state. self._state = False diff --git a/tests/components/lock/test_mqtt.py b/tests/components/lock/test_mqtt.py index f3b7c45d38a..83ae806d295 100644 --- a/tests/components/lock/test_mqtt.py +++ b/tests/components/lock/test_mqtt.py @@ -1,5 +1,6 @@ """The tests for the MQTT lock platform.""" import json +from unittest.mock import ANY from homeassistant.setup import async_setup_component from homeassistant.const import ( @@ -8,7 +9,8 @@ from homeassistant.components import lock, mqtt from homeassistant.components.mqtt.discovery import async_start from tests.common import ( - async_fire_mqtt_message, async_mock_mqtt_component, MockConfigEntry) + async_fire_mqtt_message, async_mock_mqtt_component, MockConfigEntry, + mock_registry) async def test_controlling_state_via_topic(hass, mqtt_mock): @@ -214,6 +216,40 @@ async def test_discovery_broken(hass, mqtt_mock, caplog): assert state is None +async def test_discovery_update_lock(hass, mqtt_mock, caplog): + """Test update of discovered lock.""" + entry = MockConfigEntry(domain=mqtt.DOMAIN) + await async_start(hass, 'homeassistant', {}, entry) + data1 = ( + '{ "name": "Beer",' + ' "state_topic": "test_topic",' + ' "command_topic": "command_topic",' + ' "availability_topic": "availability_topic1" }' + ) + data2 = ( + '{ "name": "Milk",' + ' "state_topic": "test_topic2",' + ' "command_topic": "command_topic",' + ' "availability_topic": "availability_topic2" }' + ) + async_fire_mqtt_message(hass, 'homeassistant/lock/bla/config', + data1) + await hass.async_block_till_done() + state = hass.states.get('lock.beer') + assert state is not None + assert state.name == 'Beer' + async_fire_mqtt_message(hass, 'homeassistant/lock/bla/config', + data2) + await hass.async_block_till_done() + await hass.async_block_till_done() + state = hass.states.get('lock.beer') + assert state is not None + assert state.name == 'Milk' + + state = hass.states.get('lock.milk') + assert state is None + + async def test_entity_device_info_with_identifier(hass, mqtt_mock): """Test MQTT lock device registry integration.""" entry = MockConfigEntry(domain=mqtt.DOMAIN) @@ -251,3 +287,39 @@ async def test_entity_device_info_with_identifier(hass, mqtt_mock): assert device.name == 'Beer' assert device.model == 'Glass' assert device.sw_version == '0.1-beta' + + +async def test_entity_id_update(hass, mqtt_mock): + """Test MQTT subscriptions are managed when entity_id is updated.""" + registry = mock_registry(hass, {}) + mock_mqtt = await async_mock_mqtt_component(hass) + assert await async_setup_component(hass, lock.DOMAIN, { + lock.DOMAIN: [{ + 'platform': 'mqtt', + 'name': 'beer', + 'state_topic': 'test-topic', + 'command_topic': 'test-topic', + 'availability_topic': 'avty-topic', + 'unique_id': 'TOTALLY_UNIQUE' + }] + }) + + state = hass.states.get('lock.beer') + assert state is not None + assert mock_mqtt.async_subscribe.call_count == 2 + mock_mqtt.async_subscribe.assert_any_call('test-topic', ANY, 0, 'utf-8') + mock_mqtt.async_subscribe.assert_any_call('avty-topic', ANY, 0, 'utf-8') + mock_mqtt.async_subscribe.reset_mock() + + registry.async_update_entity('lock.beer', new_entity_id='lock.milk') + await hass.async_block_till_done() + await hass.async_block_till_done() + + state = hass.states.get('lock.beer') + assert state is None + + state = hass.states.get('lock.milk') + assert state is not None + assert mock_mqtt.async_subscribe.call_count == 2 + mock_mqtt.async_subscribe.assert_any_call('test-topic', ANY, 0, 'utf-8') + mock_mqtt.async_subscribe.assert_any_call('avty-topic', ANY, 0, 'utf-8')