From ba6d1976dff8df2aa32726ff2acbf0ba61e5c550 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Fri, 18 Feb 2022 13:45:25 +0100 Subject: [PATCH] Improve MQTT device removal (#66766) * Improve MQTT device removal * Update homeassistant/components/mqtt/mixins.py Co-authored-by: Martin Hjelmare * Adjust tests * Improve test coverage Co-authored-by: Martin Hjelmare --- homeassistant/components/mqtt/__init__.py | 21 +- .../components/mqtt/device_automation.py | 14 +- .../components/mqtt/device_trigger.py | 6 +- homeassistant/components/mqtt/mixins.py | 40 ++- homeassistant/components/mqtt/tag.py | 30 ++- .../mqtt/test_device_tracker_discovery.py | 19 +- tests/components/mqtt/test_device_trigger.py | 33 ++- tests/components/mqtt/test_discovery.py | 232 +++++++++++++++++- tests/components/mqtt/test_tag.py | 87 +++++-- 9 files changed, 427 insertions(+), 55 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index b97a0bc8770..23a1fcc579e 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -56,6 +56,7 @@ from homeassistant.helpers import ( event, template, ) +from homeassistant.helpers.device_registry import DeviceEntry from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send from homeassistant.helpers.entity import Entity from homeassistant.helpers.frame import report @@ -1198,8 +1199,8 @@ def websocket_mqtt_info(hass, connection, msg): @websocket_api.websocket_command( {vol.Required("type"): "mqtt/device/remove", vol.Required("device_id"): str} ) -@callback -def websocket_remove_device(hass, connection, msg): +@websocket_api.async_response +async def websocket_remove_device(hass, connection, msg): """Delete device.""" device_id = msg["device_id"] device_registry = dr.async_get(hass) @@ -1214,7 +1215,10 @@ def websocket_remove_device(hass, connection, msg): config_entry = hass.config_entries.async_get_entry(config_entry) # Only delete the device if it belongs to an MQTT device entry if config_entry.domain == DOMAIN: - device_registry.async_remove_device(device_id) + await async_remove_config_entry_device(hass, config_entry, device) + device_registry.async_update_device( + device_id, remove_config_entry_id=config_entry.entry_id + ) connection.send_message(websocket_api.result_message(msg["id"])) return @@ -1292,3 +1296,14 @@ def async_subscribe_connection_status( def is_connected(hass: HomeAssistant) -> bool: """Return if MQTT client is connected.""" return hass.data[DATA_MQTT].connected + + +async def async_remove_config_entry_device( + hass: HomeAssistant, config_entry: ConfigEntry, device_entry: DeviceEntry +) -> bool: + """Remove MQTT config entry from a device.""" + # pylint: disable-next=import-outside-toplevel + from . import device_automation + + await device_automation.async_removed_from_device(hass, device_entry.id) + return True diff --git a/homeassistant/components/mqtt/device_automation.py b/homeassistant/components/mqtt/device_automation.py index 50d6a6e4d19..cafbd66b098 100644 --- a/homeassistant/components/mqtt/device_automation.py +++ b/homeassistant/components/mqtt/device_automation.py @@ -3,8 +3,6 @@ import functools import voluptuous as vol -from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED - from . import device_trigger from .. import mqtt from .mixins import async_setup_entry_helper @@ -23,15 +21,8 @@ PLATFORM_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend( async def async_setup_entry(hass, config_entry): """Set up MQTT device automation dynamically through MQTT discovery.""" - async def async_device_removed(event): - """Handle the removal of a device.""" - if event.data["action"] != "remove": - return - await device_trigger.async_device_removed(hass, event.data["device_id"]) - setup = functools.partial(_async_setup_automation, hass, config_entry=config_entry) await async_setup_entry_helper(hass, "device_automation", setup, PLATFORM_SCHEMA) - hass.bus.async_listen(EVENT_DEVICE_REGISTRY_UPDATED, async_device_removed) async def _async_setup_automation(hass, config, config_entry, discovery_data): @@ -40,3 +31,8 @@ async def _async_setup_automation(hass, config, config_entry, discovery_data): await device_trigger.async_setup_trigger( hass, config, config_entry, discovery_data ) + + +async def async_removed_from_device(hass, device_id): + """Handle Mqtt removed from a device.""" + await device_trigger.async_removed_from_device(hass, device_id) diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py index f621021e124..71c0a9f9364 100644 --- a/homeassistant/components/mqtt/device_trigger.py +++ b/homeassistant/components/mqtt/device_trigger.py @@ -222,7 +222,7 @@ async def async_setup_trigger(hass, config, config_entry, discovery_data): device_trigger.detach_trigger() clear_discovery_hash(hass, discovery_hash) remove_signal() - await cleanup_device_registry(hass, device.id) + await cleanup_device_registry(hass, device.id, config_entry.entry_id) else: # Non-empty payload: Update trigger _LOGGER.info("Updating trigger: %s", discovery_hash) @@ -275,8 +275,8 @@ async def async_setup_trigger(hass, config, config_entry, discovery_data): async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) -async def async_device_removed(hass: HomeAssistant, device_id: str): - """Handle the removal of a device.""" +async def async_removed_from_device(hass: HomeAssistant, device_id: str): + """Handle Mqtt removed from a device.""" triggers = await async_get_triggers(hass, device_id) for trig in triggers: device_trigger = hass.data[DEVICE_TRIGGERS].pop(trig[CONF_DISCOVERY_ID]) diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index fb25fa1e1b6..6f881a70690 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -25,7 +25,7 @@ from homeassistant.const import ( CONF_UNIQUE_ID, CONF_VALUE_TEMPLATE, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import Event, HomeAssistant, callback from homeassistant.helpers import ( config_validation as cv, device_registry as dr, @@ -496,7 +496,7 @@ class MqttAvailability(Entity): return self._available_latest -async def cleanup_device_registry(hass, device_id): +async def cleanup_device_registry(hass, device_id, config_entry_id): """Remove device registry entry if there are no remaining entities or triggers.""" # Local import to avoid circular dependencies # pylint: disable-next=import-outside-toplevel @@ -512,7 +512,9 @@ async def cleanup_device_registry(hass, device_id): and not await device_trigger.async_get_triggers(hass, device_id) and not tag.async_has_tags(hass, device_id) ): - device_registry.async_remove_device(device_id) + device_registry.async_update_device( + device_id, remove_config_entry_id=config_entry_id + ) class MqttDiscoveryUpdate(Entity): @@ -542,7 +544,9 @@ class MqttDiscoveryUpdate(Entity): entity_registry = er.async_get(self.hass) if entity_entry := entity_registry.async_get(self.entity_id): entity_registry.async_remove(self.entity_id) - await cleanup_device_registry(self.hass, entity_entry.device_id) + await cleanup_device_registry( + self.hass, entity_entry.device_id, entity_entry.config_entry_id + ) else: await self.async_remove(force_remove=True) @@ -817,3 +821,31 @@ class MqttEntity( def unique_id(self): """Return a unique ID.""" return self._unique_id + + +@callback +def async_removed_from_device( + hass: HomeAssistant, event: Event, mqtt_device_id: str, config_entry_id: str +) -> bool: + """Check if the passed event indicates MQTT was removed from a device.""" + device_id = event.data["device_id"] + if event.data["action"] not in ("remove", "update"): + return False + + if device_id != mqtt_device_id: + return False + + if event.data["action"] == "update": + if "config_entries" not in event.data["changes"]: + return False + device_registry = dr.async_get(hass) + device_entry = device_registry.async_get(mqtt_device_id) + if not device_entry: + # The device is already removed, do cleanup when we get "remove" event + return False + entry_id = config_entry_id + if entry_id in device_entry.config_entries: + # Not removed from device + return False + + return True diff --git a/homeassistant/components/mqtt/tag.py b/homeassistant/components/mqtt/tag.py index 186f11534b9..4f6f380e47d 100644 --- a/homeassistant/components/mqtt/tag.py +++ b/homeassistant/components/mqtt/tag.py @@ -27,6 +27,7 @@ from .mixins import ( CONF_CONNECTIONS, CONF_IDENTIFIERS, MQTT_ENTITY_DEVICE_INFO_SCHEMA, + async_removed_from_device, async_setup_entry_helper, cleanup_device_registry, device_info_from_config, @@ -126,9 +127,11 @@ class MQTTTagScanner: if not payload: # Empty payload: Remove tag scanner _LOGGER.info("Removing tag scanner: %s", discovery_hash) - await self.tear_down() + self.tear_down() if self.device_id: - await cleanup_device_registry(self.hass, self.device_id) + await cleanup_device_registry( + self.hass, self.device_id, self._config_entry.entry_id + ) else: # Non-empty payload: Update tag scanner _LOGGER.info("Updating tag scanner: %s", discovery_hash) @@ -155,7 +158,7 @@ class MQTTTagScanner: await self.subscribe_topics() if self.device_id: self._remove_device_updated = self.hass.bus.async_listen( - EVENT_DEVICE_REGISTRY_UPDATED, self.device_removed + EVENT_DEVICE_REGISTRY_UPDATED, self.device_updated ) self._remove_discovery = async_dispatcher_connect( self.hass, @@ -189,26 +192,31 @@ class MQTTTagScanner: ) await subscription.async_subscribe_topics(self.hass, self._sub_state) - async def device_removed(self, event): - """Handle the removal of a device.""" - device_id = event.data["device_id"] - if event.data["action"] != "remove" or device_id != self.device_id: + async def device_updated(self, event): + """Handle the update or removal of a device.""" + if not async_removed_from_device( + self.hass, event, self.device_id, self._config_entry.entry_id + ): return - await self.tear_down() + # Stop subscribing to discovery updates to not trigger when we clear the + # discovery topic + self.tear_down() - async def tear_down(self): + # Clear the discovery topic so the entity is not rediscovered after a restart + discovery_topic = self.discovery_data[ATTR_DISCOVERY_TOPIC] + mqtt.publish(self.hass, discovery_topic, "", retain=True) + + def tear_down(self): """Cleanup tag scanner.""" discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] - discovery_topic = self.discovery_data[ATTR_DISCOVERY_TOPIC] clear_discovery_hash(self.hass, discovery_hash) if self.device_id: self._remove_device_updated() self._remove_discovery() - mqtt.publish(self.hass, discovery_topic, "", retain=True) self._sub_state = subscription.async_unsubscribe_topics( self.hass, self._sub_state ) diff --git a/tests/components/mqtt/test_device_tracker_discovery.py b/tests/components/mqtt/test_device_tracker_discovery.py index 4020c2beaeb..3b83581b86a 100644 --- a/tests/components/mqtt/test_device_tracker_discovery.py +++ b/tests/components/mqtt/test_device_tracker_discovery.py @@ -5,6 +5,7 @@ import pytest from homeassistant.components import device_tracker from homeassistant.components.mqtt.discovery import ALREADY_DISCOVERED from homeassistant.const import STATE_HOME, STATE_NOT_HOME, STATE_UNKNOWN +from homeassistant.setup import async_setup_component from .test_common import help_test_setting_blocked_attribute_via_mqtt_json_message @@ -183,8 +184,13 @@ async def test_device_tracker_discovery_update(hass, mqtt_mock, caplog): assert state.name == "Cider" -async def test_cleanup_device_tracker(hass, device_reg, entity_reg, mqtt_mock): +async def test_cleanup_device_tracker( + hass, hass_ws_client, device_reg, entity_reg, mqtt_mock +): """Test discvered device is cleaned up when removed from registry.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + async_fire_mqtt_message( hass, "homeassistant/device_tracker/bla/config", @@ -203,7 +209,16 @@ async def test_cleanup_device_tracker(hass, device_reg, entity_reg, mqtt_mock): state = hass.states.get("device_tracker.mqtt_unique") assert state is not None - device_reg.async_remove_device(device_entry.id) + # Remove MQTT from the device + await ws_client.send_json( + { + "id": 6, + "type": "mqtt/device/remove", + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + assert response["success"] await hass.async_block_till_done() await hass.async_block_till_done() diff --git a/tests/components/mqtt/test_device_trigger.py b/tests/components/mqtt/test_device_trigger.py index 972b0678ed2..8a3719f1707 100644 --- a/tests/components/mqtt/test_device_trigger.py +++ b/tests/components/mqtt/test_device_trigger.py @@ -646,9 +646,12 @@ async def test_not_fires_on_mqtt_message_after_remove_by_mqtt( async def test_not_fires_on_mqtt_message_after_remove_from_registry( - hass, device_reg, calls, mqtt_mock + hass, hass_ws_client, device_reg, calls, mqtt_mock ): """Test triggers not firing after removal.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + data1 = ( '{ "automation_type":"trigger",' ' "device":{"identifiers":["0AFFD2"]},' @@ -688,8 +691,16 @@ async def test_not_fires_on_mqtt_message_after_remove_from_registry( await hass.async_block_till_done() assert len(calls) == 1 - # Remove the device - device_reg.async_remove_device(device_entry.id) + # Remove MQTT from the device + await ws_client.send_json( + { + "id": 6, + "type": "mqtt/device/remove", + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + assert response["success"] await hass.async_block_till_done() async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") @@ -967,8 +978,11 @@ async def test_entity_device_info_update(hass, mqtt_mock): assert device.name == "Milk" -async def test_cleanup_trigger(hass, device_reg, entity_reg, mqtt_mock): +async def test_cleanup_trigger(hass, hass_ws_client, device_reg, entity_reg, mqtt_mock): """Test trigger discovery topic is cleaned when device is removed from registry.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + config = { "automation_type": "trigger", "topic": "test-topic", @@ -990,7 +1004,16 @@ async def test_cleanup_trigger(hass, device_reg, entity_reg, mqtt_mock): ) assert triggers[0]["type"] == "foo" - device_reg.async_remove_device(device_entry.id) + # Remove MQTT from the device + await ws_client.send_json( + { + "id": 6, + "type": "mqtt/device/remove", + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + assert response["success"] await hass.async_block_till_done() await hass.async_block_till_done() diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 5d94f349c58..463f3d03fff 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -1,7 +1,8 @@ """The tests for the MQTT discovery.""" +import json from pathlib import Path import re -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, call, patch import pytest @@ -19,8 +20,10 @@ from homeassistant.const import ( STATE_UNKNOWN, ) import homeassistant.core as ha +from homeassistant.setup import async_setup_component from tests.common import ( + MockConfigEntry, async_fire_mqtt_message, mock_device_registry, mock_entity_platform, @@ -565,8 +568,11 @@ async def test_duplicate_removal(hass, mqtt_mock, caplog): assert "Component has already been discovered: binary_sensor bla" not in caplog.text -async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock): - """Test discvered device is cleaned up when removed from registry.""" +async def test_cleanup_device(hass, hass_ws_client, device_reg, entity_reg, mqtt_mock): + """Test discvered device is cleaned up when entry removed from device.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + data = ( '{ "device":{"identifiers":["0AFFD2"]},' ' "state_topic": "foobar/sensor",' @@ -585,7 +591,16 @@ async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock): state = hass.states.get("sensor.mqtt_sensor") assert state is not None - device_reg.async_remove_device(device_entry.id) + # Remove MQTT from the device + await ws_client.send_json( + { + "id": 6, + "type": "mqtt/device/remove", + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + assert response["success"] await hass.async_block_till_done() await hass.async_block_till_done() @@ -606,6 +621,215 @@ async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock): ) +async def test_cleanup_device_mqtt(hass, device_reg, entity_reg, mqtt_mock): + """Test discvered device is cleaned up when removed through MQTT.""" + data = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "state_topic": "foobar/sensor",' + ' "unique_id": "unique" }' + ) + + async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data) + await hass.async_block_till_done() + + # Verify device and registry entries are created + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}) + assert device_entry is not None + entity_entry = entity_reg.async_get("sensor.mqtt_sensor") + assert entity_entry is not None + + state = hass.states.get("sensor.mqtt_sensor") + assert state is not None + + async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", "") + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Verify device and registry entries are cleared + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}) + assert device_entry is None + entity_entry = entity_reg.async_get("sensor.mqtt_sensor") + assert entity_entry is None + + # Verify state is removed + state = hass.states.get("sensor.mqtt_sensor") + assert state is None + await hass.async_block_till_done() + + # Verify retained discovery topics have not been cleared again + mqtt_mock.async_publish.assert_not_called() + + +async def test_cleanup_device_multiple_config_entries( + hass, hass_ws_client, device_reg, entity_reg, mqtt_mock +): + """Test discovered device is cleaned up when entry removed from device.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={("mac", "12:34:56:AB:CD:EF")}, + ) + + mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] + + sensor_config = { + "device": {"connections": [["mac", "12:34:56:AB:CD:EF"]]}, + "state_topic": "foobar/sensor", + "unique_id": "unique", + } + tag_config = { + "device": {"connections": [["mac", "12:34:56:AB:CD:EF"]]}, + "topic": "test-topic", + } + trigger_config = { + "automation_type": "trigger", + "topic": "test-topic", + "type": "foo", + "subtype": "bar", + "device": {"connections": [["mac", "12:34:56:AB:CD:EF"]]}, + } + + sensor_data = json.dumps(sensor_config) + tag_data = json.dumps(tag_config) + trigger_data = json.dumps(trigger_config) + async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", sensor_data) + async_fire_mqtt_message(hass, "homeassistant/tag/bla/config", tag_data) + async_fire_mqtt_message( + hass, "homeassistant/device_automation/bla/config", trigger_data + ) + await hass.async_block_till_done() + + # Verify device and registry entries are created + device_entry = device_reg.async_get_device(set(), {("mac", "12:34:56:AB:CD:EF")}) + assert device_entry is not None + assert device_entry.config_entries == { + mqtt_config_entry.entry_id, + config_entry.entry_id, + } + entity_entry = entity_reg.async_get("sensor.mqtt_sensor") + assert entity_entry is not None + + state = hass.states.get("sensor.mqtt_sensor") + assert state is not None + + # Remove MQTT from the device + await ws_client.send_json( + { + "id": 6, + "type": "mqtt/device/remove", + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + assert response["success"] + + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Verify device is still there but entity is cleared + device_entry = device_reg.async_get_device(set(), {("mac", "12:34:56:AB:CD:EF")}) + assert device_entry is not None + entity_entry = entity_reg.async_get("sensor.mqtt_sensor") + assert device_entry.config_entries == {config_entry.entry_id} + assert entity_entry is None + + # Verify state is removed + state = hass.states.get("sensor.mqtt_sensor") + assert state is None + await hass.async_block_till_done() + + # Verify retained discovery topic has been cleared + mqtt_mock.async_publish.assert_has_calls( + [ + call("homeassistant/sensor/bla/config", "", 0, True), + call("homeassistant/tag/bla/config", "", 0, True), + call("homeassistant/device_automation/bla/config", "", 0, True), + ], + any_order=True, + ) + + +async def test_cleanup_device_multiple_config_entries_mqtt( + hass, device_reg, entity_reg, mqtt_mock +): + """Test discovered device is cleaned up when removed through MQTT.""" + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={("mac", "12:34:56:AB:CD:EF")}, + ) + + mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] + + sensor_config = { + "device": {"connections": [["mac", "12:34:56:AB:CD:EF"]]}, + "state_topic": "foobar/sensor", + "unique_id": "unique", + } + tag_config = { + "device": {"connections": [["mac", "12:34:56:AB:CD:EF"]]}, + "topic": "test-topic", + } + trigger_config = { + "automation_type": "trigger", + "topic": "test-topic", + "type": "foo", + "subtype": "bar", + "device": {"connections": [["mac", "12:34:56:AB:CD:EF"]]}, + } + + sensor_data = json.dumps(sensor_config) + tag_data = json.dumps(tag_config) + trigger_data = json.dumps(trigger_config) + async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", sensor_data) + async_fire_mqtt_message(hass, "homeassistant/tag/bla/config", tag_data) + async_fire_mqtt_message( + hass, "homeassistant/device_automation/bla/config", trigger_data + ) + await hass.async_block_till_done() + + # Verify device and registry entries are created + device_entry = device_reg.async_get_device(set(), {("mac", "12:34:56:AB:CD:EF")}) + assert device_entry is not None + assert device_entry.config_entries == { + mqtt_config_entry.entry_id, + config_entry.entry_id, + } + entity_entry = entity_reg.async_get("sensor.mqtt_sensor") + assert entity_entry is not None + + state = hass.states.get("sensor.mqtt_sensor") + assert state is not None + + # Send MQTT messages to remove + async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", "") + async_fire_mqtt_message(hass, "homeassistant/tag/bla/config", "") + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", "") + + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Verify device is still there but entity is cleared + device_entry = device_reg.async_get_device(set(), {("mac", "12:34:56:AB:CD:EF")}) + assert device_entry is not None + entity_entry = entity_reg.async_get("sensor.mqtt_sensor") + assert device_entry.config_entries == {config_entry.entry_id} + assert entity_entry is None + + # Verify state is removed + state = hass.states.get("sensor.mqtt_sensor") + assert state is None + await hass.async_block_till_done() + + # Verify retained discovery topics have not been cleared again + mqtt_mock.async_publish.assert_not_called() + + async def test_discovery_expansion(hass, mqtt_mock, caplog): """Test expansion of abbreviated discovery payload.""" data = ( diff --git a/tests/components/mqtt/test_tag.py b/tests/components/mqtt/test_tag.py index e1f3de83a0d..7d3b4f2e1b2 100644 --- a/tests/components/mqtt/test_tag.py +++ b/tests/components/mqtt/test_tag.py @@ -7,8 +7,10 @@ import pytest from homeassistant.components.device_automation import DeviceAutomationType from homeassistant.helpers import device_registry as dr +from homeassistant.setup import async_setup_component from tests.common import ( + MockConfigEntry, async_fire_mqtt_message, async_get_device_automations, mock_device_registry, @@ -355,11 +357,15 @@ async def test_not_fires_on_mqtt_message_after_remove_by_mqtt_without_device( async def test_not_fires_on_mqtt_message_after_remove_from_registry( hass, + hass_ws_client, device_reg, mqtt_mock, tag_mock, ): """Test tag scanning after removal.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + config = copy.deepcopy(DEFAULT_CONFIG_DEVICE) async_fire_mqtt_message(hass, "homeassistant/tag/bla1/config", json.dumps(config)) @@ -371,9 +377,16 @@ async def test_not_fires_on_mqtt_message_after_remove_from_registry( await hass.async_block_till_done() tag_mock.assert_called_once_with(ANY, DEFAULT_TAG_ID, device_entry.id) - # Remove the device - device_reg.async_remove_device(device_entry.id) - await hass.async_block_till_done() + # Remove MQTT from the device + await ws_client.send_json( + { + "id": 6, + "type": "mqtt/device/remove", + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + assert response["success"] tag_mock.reset_mock() async_fire_mqtt_message(hass, "foobar/tag_scanned", DEFAULT_TAG_SCAN) @@ -473,32 +486,78 @@ async def test_entity_device_info_update(hass, mqtt_mock): assert device.name == "Milk" -async def test_cleanup_tag(hass, device_reg, entity_reg, mqtt_mock): +async def test_cleanup_tag(hass, hass_ws_client, device_reg, entity_reg, mqtt_mock): """Test tag discovery topic is cleaned when device is removed from registry.""" - config = { + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + + mqtt_entry = hass.config_entries.async_entries("mqtt")[0] + + config_entry = MockConfigEntry(domain="test") + config_entry.add_to_hass(hass) + + device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections=set(), + identifiers={("mqtt", "helloworld")}, + ) + + config1 = { "topic": "test-topic", "device": {"identifiers": ["helloworld"]}, } + config2 = { + "topic": "test-topic", + "device": {"identifiers": ["hejhopp"]}, + } - data = json.dumps(config) - async_fire_mqtt_message(hass, "homeassistant/tag/bla/config", data) + data1 = json.dumps(config1) + data2 = json.dumps(config2) + async_fire_mqtt_message(hass, "homeassistant/tag/bla1/config", data1) + await hass.async_block_till_done() + async_fire_mqtt_message(hass, "homeassistant/tag/bla2/config", data2) await hass.async_block_till_done() - # Verify device registry entry is created - device_entry = device_reg.async_get_device({("mqtt", "helloworld")}) - assert device_entry is not None + # Verify device registry entries are created + device_entry1 = device_reg.async_get_device({("mqtt", "helloworld")}) + assert device_entry1 is not None + assert device_entry1.config_entries == {config_entry.entry_id, mqtt_entry.entry_id} + device_entry2 = device_reg.async_get_device({("mqtt", "hejhopp")}) + assert device_entry2 is not None - device_reg.async_remove_device(device_entry.id) + # Remove other config entry from the device + device_reg.async_update_device( + device_entry1.id, remove_config_entry_id=config_entry.entry_id + ) + device_entry1 = device_reg.async_get_device({("mqtt", "helloworld")}) + assert device_entry1 is not None + assert device_entry1.config_entries == {mqtt_entry.entry_id} + device_entry2 = device_reg.async_get_device({("mqtt", "hejhopp")}) + assert device_entry2 is not None + mqtt_mock.async_publish.assert_not_called() + + # Remove MQTT from the device + await ws_client.send_json( + { + "id": 6, + "type": "mqtt/device/remove", + "device_id": device_entry1.id, + } + ) + response = await ws_client.receive_json() + assert response["success"] await hass.async_block_till_done() await hass.async_block_till_done() # Verify device registry entry is cleared - device_entry = device_reg.async_get_device({("mqtt", "helloworld")}) - assert device_entry is None + device_entry1 = device_reg.async_get_device({("mqtt", "helloworld")}) + assert device_entry1 is None + device_entry2 = device_reg.async_get_device({("mqtt", "hejhopp")}) + assert device_entry2 is not None # Verify retained discovery topic has been cleared mqtt_mock.async_publish.assert_called_once_with( - "homeassistant/tag/bla/config", "", 0, True + "homeassistant/tag/bla1/config", "", 0, True )