From 6d6961ae6ef2ec0624ee4388da37e54803f817df Mon Sep 17 00:00:00 2001 From: Artur Pragacz <49985303+arturpragacz@users.noreply.github.com> Date: Fri, 7 Feb 2025 14:44:44 +0100 Subject: [PATCH] Clean up colliding deleted devices when updating non-deleted devices (#135592) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix Schrödinger's devices * Address feedback * Add comment with broader context --- homeassistant/helpers/device_registry.py | 42 ++++++++++++++++++++++-- tests/helpers/test_device_registry.py | 33 +++++++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 975b4a2aec9..92101dd0e21 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from datetime import datetime from enum import StrEnum from functools import lru_cache @@ -561,6 +561,21 @@ class DeviceRegistryItems[_EntryTypeT: (DeviceEntry, DeletedDeviceEntry)]( return self._connections[connection] return None + def get_entries( + self, + identifiers: set[tuple[str, str]] | None, + connections: set[tuple[str, str]] | None, + ) -> Iterable[_EntryTypeT]: + """Get entries from identifiers or connections.""" + if identifiers: + for identifier in identifiers: + if identifier in self._identifiers: + yield self._identifiers[identifier] + if connections: + for connection in _normalize_connections(connections): + if connection in self._connections: + yield self._connections[connection] + class ActiveDeviceRegistryItems(DeviceRegistryItems[DeviceEntry]): """Container for active (non-deleted) device registry entries.""" @@ -667,6 +682,14 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): """Check if device is deleted.""" return self.deleted_devices.get_entry(identifiers, connections) + def _async_get_deleted_devices( + self, + identifiers: set[tuple[str, str]] | None = None, + connections: set[tuple[str, str]] | None = None, + ) -> Iterable[DeletedDeviceEntry]: + """List devices that are deleted.""" + return self.deleted_devices.get_entries(identifiers, connections) + def _substitute_name_placeholders( self, domain: str, @@ -958,6 +981,9 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): new_values["config_entries"] = config_entries old_values["config_entries"] = old.config_entries + added_connections: set[tuple[str, str]] | None = None + added_identifiers: set[tuple[str, str]] | None = None + if merge_connections is not UNDEFINED: normalized_connections = self._validate_connections( device_id, @@ -966,6 +992,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): ) old_connections = old.connections if not normalized_connections.issubset(old_connections): + added_connections = normalized_connections new_values["connections"] = old_connections | normalized_connections old_values["connections"] = old_connections @@ -975,17 +1002,18 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): ) old_identifiers = old.identifiers if not merge_identifiers.issubset(old_identifiers): + added_identifiers = merge_identifiers new_values["identifiers"] = old_identifiers | merge_identifiers old_values["identifiers"] = old_identifiers if new_connections is not UNDEFINED: - new_values["connections"] = self._validate_connections( + added_connections = new_values["connections"] = self._validate_connections( device_id, new_connections, False ) old_values["connections"] = old.connections if new_identifiers is not UNDEFINED: - new_values["identifiers"] = self._validate_identifiers( + added_identifiers = new_values["identifiers"] = self._validate_identifiers( device_id, new_identifiers, False ) old_values["identifiers"] = old.identifiers @@ -1028,6 +1056,14 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): new = attr.evolve(old, **new_values) self.devices[device_id] = new + # NOTE: Once we solve the broader issue of duplicated devices, we might + # want to revisit it. Instead of simply removing the duplicated deleted device, + # we might want to merge the information from it into the non-deleted device. + for deleted_device in self._async_get_deleted_devices( + added_identifiers, added_connections + ): + del self.deleted_devices[deleted_device.id] + # If its only run time attributes (suggested_area) # that do not get saved we do not want to write # to disk or fire an event as we would end up diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index 08b984a0477..be4ace87894 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -3378,6 +3378,39 @@ async def test_device_registry_identifiers_collision( assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers) +async def test_device_registry_deleted_device_collision( + hass: HomeAssistant, device_registry: dr.DeviceRegistry +) -> None: + """Test update collisions with deleted devices in the device registry.""" + config_entry = MockConfigEntry() + config_entry.add_to_hass(hass) + + device1 = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE")}, + manufacturer="manufacturer", + model="model", + ) + assert len(device_registry.deleted_devices) == 0 + + device_registry.async_remove_device(device1.id) + assert len(device_registry.deleted_devices) == 1 + + device2 = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("bridgeid", "0123")}, + manufacturer="manufacturer", + model="model", + ) + assert len(device_registry.deleted_devices) == 1 + + device_registry.async_update_device( + device2.id, + merge_connections={(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE")}, + ) + assert len(device_registry.deleted_devices) == 0 + + async def test_primary_config_entry( hass: HomeAssistant, device_registry: dr.DeviceRegistry,