mirror of
https://github.com/home-assistant/core.git
synced 2025-08-01 17:48:26 +00:00
Merge devices on connection or identifier collision
This commit is contained in:
parent
b07453dca4
commit
8d06baf0a5
@ -15,6 +15,7 @@ from yarl import URL
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import (
|
||||
DOMAIN,
|
||||
Event,
|
||||
HomeAssistant,
|
||||
ReleaseChannel,
|
||||
@ -75,6 +76,7 @@ class DeviceEntryDisabler(StrEnum):
|
||||
"""What disabled a device entry."""
|
||||
|
||||
CONFIG_ENTRY = "config_entry"
|
||||
DUPLICATE = "duplicate"
|
||||
INTEGRATION = "integration"
|
||||
USER = "user"
|
||||
|
||||
@ -522,12 +524,36 @@ class DeviceRegistryItems[_EntryTypeT: (DeviceEntry, DeletedDeviceEntry)](
|
||||
for identifier in old_entry.identifiers:
|
||||
del self._identifiers[identifier]
|
||||
|
||||
def get_entries(
|
||||
self,
|
||||
identifiers: set[tuple[str, str]] | None,
|
||||
connections: set[tuple[str, str]] | None,
|
||||
) -> set[str]:
|
||||
"""Get all matching entry ids from identifiers or connections."""
|
||||
entries = set()
|
||||
if identifiers:
|
||||
entries = {
|
||||
self._identifiers[identifier].id
|
||||
for identifier in identifiers
|
||||
if identifier in self._identifiers
|
||||
}
|
||||
if not connections:
|
||||
return entries
|
||||
return entries | {
|
||||
self._connections[connection].id
|
||||
for connection in _normalize_connections(connections)
|
||||
if connection in self._connections
|
||||
}
|
||||
|
||||
def get_entry(
|
||||
self,
|
||||
identifiers: set[tuple[str, str]] | None,
|
||||
connections: set[tuple[str, str]] | None,
|
||||
) -> _EntryTypeT | None:
|
||||
"""Get entry from identifiers or connections."""
|
||||
"""Get the first matching entry from identifiers or connections.
|
||||
|
||||
Identifiers are tried first, then connections.
|
||||
"""
|
||||
if identifiers:
|
||||
for identifier in identifiers:
|
||||
if identifier in self._identifiers:
|
||||
@ -754,9 +780,11 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||
else:
|
||||
connections = _normalize_connections(connections)
|
||||
|
||||
device = self.async_get_device(identifiers=identifiers, connections=connections)
|
||||
device_ids = self.devices.get_entries(identifiers, connections)
|
||||
|
||||
if device is None:
|
||||
if len(device_ids) > 1:
|
||||
device = self._merge_devices(device_ids)
|
||||
elif not device_ids:
|
||||
deleted_device = self._async_get_deleted_device(identifiers, connections)
|
||||
if deleted_device is None:
|
||||
device = DeviceEntry(is_new=True)
|
||||
@ -769,6 +797,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||
# If creating a new device, default to the config entry name
|
||||
if device_info_type == "primary" and (not name or name is UNDEFINED):
|
||||
name = config_entry.title
|
||||
else:
|
||||
device = self.devices[next(iter(device_ids))]
|
||||
|
||||
if default_manufacturer is not UNDEFINED and device.manufacturer is None:
|
||||
manufacturer = default_manufacturer
|
||||
@ -796,7 +826,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||
)
|
||||
entry_type = DeviceEntryType(entry_type)
|
||||
|
||||
device = self.async_update_device(
|
||||
updated_device = self.async_update_device(
|
||||
device.id,
|
||||
allow_collisions=True,
|
||||
add_config_entry_id=config_entry_id,
|
||||
@ -819,8 +849,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||
|
||||
# This is safe because _async_update_device will always return a device
|
||||
# in this use case.
|
||||
assert device
|
||||
return device
|
||||
assert updated_device
|
||||
return updated_device
|
||||
|
||||
@callback
|
||||
def async_update_device( # noqa: C901
|
||||
@ -1110,6 +1140,39 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||
)
|
||||
self.async_schedule_save()
|
||||
|
||||
@callback
|
||||
def _merge_devices(self, device_ids: set[str]) -> DeviceEntry:
|
||||
# Pick a device to be the main device. For now, we just pick the first
|
||||
# device in the set
|
||||
main_device = self.devices[next(iter(device_ids))]
|
||||
|
||||
merged_connections = set()
|
||||
merged_identifiers = set()
|
||||
|
||||
# Disable other devices, and clear their connections and identifiers
|
||||
for device_id in device_ids:
|
||||
device = self.devices[device_id]
|
||||
merged_connections |= device.connections
|
||||
merged_identifiers |= device.identifiers
|
||||
|
||||
if device.id == main_device.id:
|
||||
continue
|
||||
|
||||
self.async_update_device(
|
||||
device.id,
|
||||
disabled_by=DeviceEntryDisabler.DUPLICATE,
|
||||
new_connections=set(),
|
||||
new_identifiers={(DOMAIN, device.id)},
|
||||
)
|
||||
|
||||
self.async_update_device(
|
||||
main_device.id,
|
||||
new_connections=merged_connections,
|
||||
new_identifiers=merged_identifiers,
|
||||
)
|
||||
|
||||
return main_device
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load the device registry."""
|
||||
async_setup_cleanup(self.hass, self)
|
||||
|
@ -2501,7 +2501,14 @@ def test_all() -> None:
|
||||
help_test_all(dr)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("enum"), list(dr.DeviceEntryDisabler))
|
||||
@pytest.mark.parametrize(
|
||||
("enum"),
|
||||
[
|
||||
enum
|
||||
for enum in dr.DeviceEntryDisabler
|
||||
if enum != dr.DeviceEntryDisabler.DUPLICATE
|
||||
],
|
||||
)
|
||||
def test_deprecated_constants(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
enum: dr.DeviceEntryDisabler,
|
||||
@ -2903,7 +2910,27 @@ async def test_device_registry_connections_collision(
|
||||
|
||||
device3_refetched = device_registry.async_get(device3.id)
|
||||
device1_refetched = device_registry.async_get(device1.id)
|
||||
assert not device1_refetched.connections.isdisjoint(device3_refetched.connections)
|
||||
|
||||
# One of the devices should now:
|
||||
# - Be disabled
|
||||
# - Have all its connections removed
|
||||
# - Have a single identifier
|
||||
if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE:
|
||||
main_device = device3_refetched
|
||||
duplicate_device = device1_refetched
|
||||
else:
|
||||
main_device = device1_refetched
|
||||
duplicate_device = device3_refetched
|
||||
|
||||
assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE
|
||||
assert main_device.disabled_by is None
|
||||
assert duplicate_device.connections == set()
|
||||
assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)}
|
||||
assert main_device.connections == {
|
||||
(dr.CONNECTION_NETWORK_MAC, "ee:ee:ee:ee:ee:ee"),
|
||||
(dr.CONNECTION_NETWORK_MAC, "none"),
|
||||
}
|
||||
assert main_device.identifiers == {("bridgeid", "0123")}
|
||||
|
||||
|
||||
async def test_device_registry_identifiers_collision(
|
||||
@ -2979,7 +3006,24 @@ async def test_device_registry_identifiers_collision(
|
||||
|
||||
device3_refetched = device_registry.async_get(device3.id)
|
||||
device1_refetched = device_registry.async_get(device1.id)
|
||||
assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers)
|
||||
|
||||
# One of the devices should now:
|
||||
# - Be disabled
|
||||
# - Have all its connections removed
|
||||
# - Have a single identifier
|
||||
if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE:
|
||||
main_device = device3_refetched
|
||||
duplicate_device = device1_refetched
|
||||
else:
|
||||
main_device = device1_refetched
|
||||
duplicate_device = device3_refetched
|
||||
|
||||
assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE
|
||||
assert main_device.disabled_by is None
|
||||
assert duplicate_device.connections == set()
|
||||
assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)}
|
||||
assert main_device.connections == set()
|
||||
assert main_device.identifiers == {("bridgeid", "0123"), ("bridgeid", "4567")}
|
||||
|
||||
|
||||
async def test_primary_config_entry(
|
||||
|
Loading…
x
Reference in New Issue
Block a user