Clean up colliding deleted devices when updating non-deleted devices (#135592)

* Fix Schrödinger's devices

* Address feedback

* Add comment with broader context
This commit is contained in:
Artur Pragacz 2025-02-07 14:44:44 +01:00 committed by GitHub
parent e340f5af8d
commit 6d6961ae6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 72 additions and 3 deletions

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from collections.abc import Mapping from collections.abc import Iterable, Mapping
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from functools import lru_cache from functools import lru_cache
@ -561,6 +561,21 @@ class DeviceRegistryItems[_EntryTypeT: (DeviceEntry, DeletedDeviceEntry)](
return self._connections[connection] return self._connections[connection]
return None return None
def get_entries(
self,
identifiers: set[tuple[str, str]] | None,
connections: set[tuple[str, str]] | None,
) -> Iterable[_EntryTypeT]:
"""Get entries from identifiers or connections."""
if identifiers:
for identifier in identifiers:
if identifier in self._identifiers:
yield self._identifiers[identifier]
if connections:
for connection in _normalize_connections(connections):
if connection in self._connections:
yield self._connections[connection]
class ActiveDeviceRegistryItems(DeviceRegistryItems[DeviceEntry]): class ActiveDeviceRegistryItems(DeviceRegistryItems[DeviceEntry]):
"""Container for active (non-deleted) device registry entries.""" """Container for active (non-deleted) device registry entries."""
@ -667,6 +682,14 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
"""Check if device is deleted.""" """Check if device is deleted."""
return self.deleted_devices.get_entry(identifiers, connections) return self.deleted_devices.get_entry(identifiers, connections)
def _async_get_deleted_devices(
self,
identifiers: set[tuple[str, str]] | None = None,
connections: set[tuple[str, str]] | None = None,
) -> Iterable[DeletedDeviceEntry]:
"""List devices that are deleted."""
return self.deleted_devices.get_entries(identifiers, connections)
def _substitute_name_placeholders( def _substitute_name_placeholders(
self, self,
domain: str, domain: str,
@ -958,6 +981,9 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
new_values["config_entries"] = config_entries new_values["config_entries"] = config_entries
old_values["config_entries"] = old.config_entries old_values["config_entries"] = old.config_entries
added_connections: set[tuple[str, str]] | None = None
added_identifiers: set[tuple[str, str]] | None = None
if merge_connections is not UNDEFINED: if merge_connections is not UNDEFINED:
normalized_connections = self._validate_connections( normalized_connections = self._validate_connections(
device_id, device_id,
@ -966,6 +992,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
) )
old_connections = old.connections old_connections = old.connections
if not normalized_connections.issubset(old_connections): if not normalized_connections.issubset(old_connections):
added_connections = normalized_connections
new_values["connections"] = old_connections | normalized_connections new_values["connections"] = old_connections | normalized_connections
old_values["connections"] = old_connections old_values["connections"] = old_connections
@ -975,17 +1002,18 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
) )
old_identifiers = old.identifiers old_identifiers = old.identifiers
if not merge_identifiers.issubset(old_identifiers): if not merge_identifiers.issubset(old_identifiers):
added_identifiers = merge_identifiers
new_values["identifiers"] = old_identifiers | merge_identifiers new_values["identifiers"] = old_identifiers | merge_identifiers
old_values["identifiers"] = old_identifiers old_values["identifiers"] = old_identifiers
if new_connections is not UNDEFINED: if new_connections is not UNDEFINED:
new_values["connections"] = self._validate_connections( added_connections = new_values["connections"] = self._validate_connections(
device_id, new_connections, False device_id, new_connections, False
) )
old_values["connections"] = old.connections old_values["connections"] = old.connections
if new_identifiers is not UNDEFINED: if new_identifiers is not UNDEFINED:
new_values["identifiers"] = self._validate_identifiers( added_identifiers = new_values["identifiers"] = self._validate_identifiers(
device_id, new_identifiers, False device_id, new_identifiers, False
) )
old_values["identifiers"] = old.identifiers old_values["identifiers"] = old.identifiers
@ -1028,6 +1056,14 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
new = attr.evolve(old, **new_values) new = attr.evolve(old, **new_values)
self.devices[device_id] = new self.devices[device_id] = new
# NOTE: Once we solve the broader issue of duplicated devices, we might
# want to revisit it. Instead of simply removing the duplicated deleted device,
# we might want to merge the information from it into the non-deleted device.
for deleted_device in self._async_get_deleted_devices(
added_identifiers, added_connections
):
del self.deleted_devices[deleted_device.id]
# If its only run time attributes (suggested_area) # If its only run time attributes (suggested_area)
# that do not get saved we do not want to write # that do not get saved we do not want to write
# to disk or fire an event as we would end up # to disk or fire an event as we would end up

View File

@ -3378,6 +3378,39 @@ async def test_device_registry_identifiers_collision(
assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers) assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers)
async def test_device_registry_deleted_device_collision(
hass: HomeAssistant, device_registry: dr.DeviceRegistry
) -> None:
"""Test update collisions with deleted devices in the device registry."""
config_entry = MockConfigEntry()
config_entry.add_to_hass(hass)
device1 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE")},
manufacturer="manufacturer",
model="model",
)
assert len(device_registry.deleted_devices) == 0
device_registry.async_remove_device(device1.id)
assert len(device_registry.deleted_devices) == 1
device2 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("bridgeid", "0123")},
manufacturer="manufacturer",
model="model",
)
assert len(device_registry.deleted_devices) == 1
device_registry.async_update_device(
device2.id,
merge_connections={(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE")},
)
assert len(device_registry.deleted_devices) == 0
async def test_primary_config_entry( async def test_primary_config_entry(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,