From d40a830b892c3e0ead5a4531102bfcb96203bd47 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Sun, 13 Feb 2022 17:23:30 +0100 Subject: [PATCH] Remove entities when config entry is removed from device (#66385) * Remove entities when config entry is removed from device * Update tests/helpers/test_entity_registry.py Co-authored-by: Martin Hjelmare * Don't remove entities not connected to a config entry * Update homeassistant/helpers/entity_registry.py Co-authored-by: Martin Hjelmare Co-authored-by: Franck Nijhof Co-authored-by: Martin Hjelmare --- homeassistant/helpers/entity_registry.py | 27 +++++- tests/helpers/test_entity_registry.py | 103 +++++++++++++++++++++++ 2 files changed, 126 insertions(+), 4 deletions(-) diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 36ac5cc3dde..4d4fce6e685 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -446,13 +446,31 @@ class EntityRegistry: return if event.data["action"] != "update": + # Ignore "create" action return device_registry = dr.async_get(self.hass) device = device_registry.async_get(event.data["device_id"]) - # The device may be deleted already if the event handling is late - if not device or not device.disabled: + # The device may be deleted already if the event handling is late, do nothing + # in that case. Entities will be removed when we get the "remove" event. + if not device: + return + + # Remove entities which belong to config entries no longer associated with the + # device + entities = async_entries_for_device( + self, event.data["device_id"], include_disabled_entities=True + ) + for entity in entities: + if ( + entity.config_entry_id is not None + and entity.config_entry_id not in device.config_entries + ): + self.async_remove(entity.entity_id) + + # Re-enable disabled entities if the device is no longer disabled + if not device.disabled: entities = async_entries_for_device( self, event.data["device_id"], include_disabled_entities=True ) @@ -462,11 +480,12 @@ class EntityRegistry: self.async_update_entity(entity.entity_id, disabled_by=None) return + # Ignore device disabled by config entry, this is handled by + # async_config_entry_disabled if device.disabled_by is dr.DeviceEntryDisabler.CONFIG_ENTRY: - # Handled by async_config_entry_disabled return - # Fetch entities which are not already disabled + # Fetch entities which are not already disabled and disable them entities = async_entries_for_device(self, event.data["device_id"]) for entity in entities: self.async_update_entity( diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 714ac037e2a..78c99640bd0 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -810,6 +810,109 @@ async def test_remove_device_removes_entities(hass, registry): assert not registry.async_is_registered(entry.entity_id) +async def test_remove_config_entry_from_device_removes_entities(hass, registry): + """Test that we remove entities tied to a device when config entry is removed.""" + device_registry = mock_device_registry(hass) + config_entry_1 = MockConfigEntry(domain="hue") + config_entry_2 = MockConfigEntry(domain="device_tracker") + + # Create device with two config entries + device_registry.async_get_or_create( + config_entry_id=config_entry_1.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry_2.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + assert device_entry.config_entries == { + config_entry_1.entry_id, + config_entry_2.entry_id, + } + + # Create one entity for each config entry + entry_1 = registry.async_get_or_create( + "light", + "hue", + "5678", + config_entry=config_entry_1, + device_id=device_entry.id, + ) + + entry_2 = registry.async_get_or_create( + "sensor", + "device_tracker", + "6789", + config_entry=config_entry_2, + device_id=device_entry.id, + ) + + assert registry.async_is_registered(entry_1.entity_id) + assert registry.async_is_registered(entry_2.entity_id) + + # Remove the first config entry from the device, the entity associated with it + # should be removed + device_registry.async_update_device( + device_entry.id, remove_config_entry_id=config_entry_1.entry_id + ) + await hass.async_block_till_done() + + assert device_registry.async_get(device_entry.id) + assert not registry.async_is_registered(entry_1.entity_id) + assert registry.async_is_registered(entry_2.entity_id) + + # Remove the second config entry from the device, the entity associated with it + # (and the device itself) should be removed + device_registry.async_update_device( + device_entry.id, remove_config_entry_id=config_entry_2.entry_id + ) + await hass.async_block_till_done() + + assert not device_registry.async_get(device_entry.id) + assert not registry.async_is_registered(entry_1.entity_id) + assert not registry.async_is_registered(entry_2.entity_id) + + +async def test_remove_config_entry_from_device_removes_entities_2(hass, registry): + """Test that we don't remove entities with no config entry when device is modified.""" + device_registry = mock_device_registry(hass) + config_entry_1 = MockConfigEntry(domain="hue") + config_entry_2 = MockConfigEntry(domain="device_tracker") + + # Create device with two config entries + device_registry.async_get_or_create( + config_entry_id=config_entry_1.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry_2.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + assert device_entry.config_entries == { + config_entry_1.entry_id, + config_entry_2.entry_id, + } + + # Create one entity for each config entry + entry_1 = registry.async_get_or_create( + "light", + "hue", + "5678", + device_id=device_entry.id, + ) + + assert registry.async_is_registered(entry_1.entity_id) + + # Remove the first config entry from the device + device_registry.async_update_device( + device_entry.id, remove_config_entry_id=config_entry_1.entry_id + ) + await hass.async_block_till_done() + + assert device_registry.async_get(device_entry.id) + assert registry.async_is_registered(entry_1.entity_id) + + async def test_update_device_race(hass, registry): """Test race when a device is created, updated and removed.""" device_registry = mock_device_registry(hass)