From 8d06baf0a57cf6e80b08e5eed2a41f375a6904cd Mon Sep 17 00:00:00 2001 From: Erik Date: Wed, 26 Jun 2024 13:08:24 +0200 Subject: [PATCH] Merge devices on connection or identifier collision --- homeassistant/helpers/device_registry.py | 75 ++++++++++++++++++++++-- tests/helpers/test_device_registry.py | 50 +++++++++++++++- 2 files changed, 116 insertions(+), 9 deletions(-) diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index cfafa63ec3a..4f1096e633a 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -15,6 +15,7 @@ from yarl import URL from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import ( + DOMAIN, Event, HomeAssistant, ReleaseChannel, @@ -75,6 +76,7 @@ class DeviceEntryDisabler(StrEnum): """What disabled a device entry.""" CONFIG_ENTRY = "config_entry" + DUPLICATE = "duplicate" INTEGRATION = "integration" USER = "user" @@ -522,12 +524,36 @@ class DeviceRegistryItems[_EntryTypeT: (DeviceEntry, DeletedDeviceEntry)]( for identifier in old_entry.identifiers: del self._identifiers[identifier] + def get_entries( + self, + identifiers: set[tuple[str, str]] | None, + connections: set[tuple[str, str]] | None, + ) -> set[str]: + """Get all matching entry ids from identifiers or connections.""" + entries = set() + if identifiers: + entries = { + self._identifiers[identifier].id + for identifier in identifiers + if identifier in self._identifiers + } + if not connections: + return entries + return entries | { + self._connections[connection].id + for connection in _normalize_connections(connections) + if connection in self._connections + } + def get_entry( self, identifiers: set[tuple[str, str]] | None, connections: set[tuple[str, str]] | None, ) -> _EntryTypeT | None: - """Get entry from identifiers or connections.""" + """Get the first matching entry from identifiers or connections. + + Identifiers are tried first, then connections. + """ if identifiers: for identifier in identifiers: if identifier in self._identifiers: @@ -754,9 +780,11 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): else: connections = _normalize_connections(connections) - device = self.async_get_device(identifiers=identifiers, connections=connections) + device_ids = self.devices.get_entries(identifiers, connections) - if device is None: + if len(device_ids) > 1: + device = self._merge_devices(device_ids) + elif not device_ids: deleted_device = self._async_get_deleted_device(identifiers, connections) if deleted_device is None: device = DeviceEntry(is_new=True) @@ -769,6 +797,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): # If creating a new device, default to the config entry name if device_info_type == "primary" and (not name or name is UNDEFINED): name = config_entry.title + else: + device = self.devices[next(iter(device_ids))] if default_manufacturer is not UNDEFINED and device.manufacturer is None: manufacturer = default_manufacturer @@ -796,7 +826,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): ) entry_type = DeviceEntryType(entry_type) - device = self.async_update_device( + updated_device = self.async_update_device( device.id, allow_collisions=True, add_config_entry_id=config_entry_id, @@ -819,8 +849,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): # This is safe because _async_update_device will always return a device # in this use case. - assert device - return device + assert updated_device + return updated_device @callback def async_update_device( # noqa: C901 @@ -1110,6 +1140,39 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): ) self.async_schedule_save() + @callback + def _merge_devices(self, device_ids: set[str]) -> DeviceEntry: + # Pick a device to be the main device. For now, we just pick the first + # device in the set + main_device = self.devices[next(iter(device_ids))] + + merged_connections = set() + merged_identifiers = set() + + # Disable other devices, and clear their connections and identifiers + for device_id in device_ids: + device = self.devices[device_id] + merged_connections |= device.connections + merged_identifiers |= device.identifiers + + if device.id == main_device.id: + continue + + self.async_update_device( + device.id, + disabled_by=DeviceEntryDisabler.DUPLICATE, + new_connections=set(), + new_identifiers={(DOMAIN, device.id)}, + ) + + self.async_update_device( + main_device.id, + new_connections=merged_connections, + new_identifiers=merged_identifiers, + ) + + return main_device + async def async_load(self) -> None: """Load the device registry.""" async_setup_cleanup(self.hass, self) diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index fa57cc7557e..b374fca5463 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -2501,7 +2501,14 @@ def test_all() -> None: help_test_all(dr) -@pytest.mark.parametrize(("enum"), list(dr.DeviceEntryDisabler)) +@pytest.mark.parametrize( + ("enum"), + [ + enum + for enum in dr.DeviceEntryDisabler + if enum != dr.DeviceEntryDisabler.DUPLICATE + ], +) def test_deprecated_constants( caplog: pytest.LogCaptureFixture, enum: dr.DeviceEntryDisabler, @@ -2903,7 +2910,27 @@ async def test_device_registry_connections_collision( device3_refetched = device_registry.async_get(device3.id) device1_refetched = device_registry.async_get(device1.id) - assert not device1_refetched.connections.isdisjoint(device3_refetched.connections) + + # One of the devices should now: + # - Be disabled + # - Have all its connections removed + # - Have a single identifier + if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE: + main_device = device3_refetched + duplicate_device = device1_refetched + else: + main_device = device1_refetched + duplicate_device = device3_refetched + + assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE + assert main_device.disabled_by is None + assert duplicate_device.connections == set() + assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)} + assert main_device.connections == { + (dr.CONNECTION_NETWORK_MAC, "ee:ee:ee:ee:ee:ee"), + (dr.CONNECTION_NETWORK_MAC, "none"), + } + assert main_device.identifiers == {("bridgeid", "0123")} async def test_device_registry_identifiers_collision( @@ -2979,7 +3006,24 @@ async def test_device_registry_identifiers_collision( device3_refetched = device_registry.async_get(device3.id) device1_refetched = device_registry.async_get(device1.id) - assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers) + + # One of the devices should now: + # - Be disabled + # - Have all its connections removed + # - Have a single identifier + if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE: + main_device = device3_refetched + duplicate_device = device1_refetched + else: + main_device = device1_refetched + duplicate_device = device3_refetched + + assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE + assert main_device.disabled_by is None + assert duplicate_device.connections == set() + assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)} + assert main_device.connections == set() + assert main_device.identifiers == {("bridgeid", "0123"), ("bridgeid", "4567")} async def test_primary_config_entry(