mirror of
https://github.com/home-assistant/core.git
synced 2025-08-02 18:18:21 +00:00
Merge devices on connection or identifier collision
This commit is contained in:
parent
b07453dca4
commit
8d06baf0a5
@ -15,6 +15,7 @@ from yarl import URL
|
|||||||
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
|
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
|
||||||
from homeassistant.core import (
|
from homeassistant.core import (
|
||||||
|
DOMAIN,
|
||||||
Event,
|
Event,
|
||||||
HomeAssistant,
|
HomeAssistant,
|
||||||
ReleaseChannel,
|
ReleaseChannel,
|
||||||
@ -75,6 +76,7 @@ class DeviceEntryDisabler(StrEnum):
|
|||||||
"""What disabled a device entry."""
|
"""What disabled a device entry."""
|
||||||
|
|
||||||
CONFIG_ENTRY = "config_entry"
|
CONFIG_ENTRY = "config_entry"
|
||||||
|
DUPLICATE = "duplicate"
|
||||||
INTEGRATION = "integration"
|
INTEGRATION = "integration"
|
||||||
USER = "user"
|
USER = "user"
|
||||||
|
|
||||||
@ -522,12 +524,36 @@ class DeviceRegistryItems[_EntryTypeT: (DeviceEntry, DeletedDeviceEntry)](
|
|||||||
for identifier in old_entry.identifiers:
|
for identifier in old_entry.identifiers:
|
||||||
del self._identifiers[identifier]
|
del self._identifiers[identifier]
|
||||||
|
|
||||||
|
def get_entries(
|
||||||
|
self,
|
||||||
|
identifiers: set[tuple[str, str]] | None,
|
||||||
|
connections: set[tuple[str, str]] | None,
|
||||||
|
) -> set[str]:
|
||||||
|
"""Get all matching entry ids from identifiers or connections."""
|
||||||
|
entries = set()
|
||||||
|
if identifiers:
|
||||||
|
entries = {
|
||||||
|
self._identifiers[identifier].id
|
||||||
|
for identifier in identifiers
|
||||||
|
if identifier in self._identifiers
|
||||||
|
}
|
||||||
|
if not connections:
|
||||||
|
return entries
|
||||||
|
return entries | {
|
||||||
|
self._connections[connection].id
|
||||||
|
for connection in _normalize_connections(connections)
|
||||||
|
if connection in self._connections
|
||||||
|
}
|
||||||
|
|
||||||
def get_entry(
|
def get_entry(
|
||||||
self,
|
self,
|
||||||
identifiers: set[tuple[str, str]] | None,
|
identifiers: set[tuple[str, str]] | None,
|
||||||
connections: set[tuple[str, str]] | None,
|
connections: set[tuple[str, str]] | None,
|
||||||
) -> _EntryTypeT | None:
|
) -> _EntryTypeT | None:
|
||||||
"""Get entry from identifiers or connections."""
|
"""Get the first matching entry from identifiers or connections.
|
||||||
|
|
||||||
|
Identifiers are tried first, then connections.
|
||||||
|
"""
|
||||||
if identifiers:
|
if identifiers:
|
||||||
for identifier in identifiers:
|
for identifier in identifiers:
|
||||||
if identifier in self._identifiers:
|
if identifier in self._identifiers:
|
||||||
@ -754,9 +780,11 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||||||
else:
|
else:
|
||||||
connections = _normalize_connections(connections)
|
connections = _normalize_connections(connections)
|
||||||
|
|
||||||
device = self.async_get_device(identifiers=identifiers, connections=connections)
|
device_ids = self.devices.get_entries(identifiers, connections)
|
||||||
|
|
||||||
if device is None:
|
if len(device_ids) > 1:
|
||||||
|
device = self._merge_devices(device_ids)
|
||||||
|
elif not device_ids:
|
||||||
deleted_device = self._async_get_deleted_device(identifiers, connections)
|
deleted_device = self._async_get_deleted_device(identifiers, connections)
|
||||||
if deleted_device is None:
|
if deleted_device is None:
|
||||||
device = DeviceEntry(is_new=True)
|
device = DeviceEntry(is_new=True)
|
||||||
@ -769,6 +797,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||||||
# If creating a new device, default to the config entry name
|
# If creating a new device, default to the config entry name
|
||||||
if device_info_type == "primary" and (not name or name is UNDEFINED):
|
if device_info_type == "primary" and (not name or name is UNDEFINED):
|
||||||
name = config_entry.title
|
name = config_entry.title
|
||||||
|
else:
|
||||||
|
device = self.devices[next(iter(device_ids))]
|
||||||
|
|
||||||
if default_manufacturer is not UNDEFINED and device.manufacturer is None:
|
if default_manufacturer is not UNDEFINED and device.manufacturer is None:
|
||||||
manufacturer = default_manufacturer
|
manufacturer = default_manufacturer
|
||||||
@ -796,7 +826,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||||||
)
|
)
|
||||||
entry_type = DeviceEntryType(entry_type)
|
entry_type = DeviceEntryType(entry_type)
|
||||||
|
|
||||||
device = self.async_update_device(
|
updated_device = self.async_update_device(
|
||||||
device.id,
|
device.id,
|
||||||
allow_collisions=True,
|
allow_collisions=True,
|
||||||
add_config_entry_id=config_entry_id,
|
add_config_entry_id=config_entry_id,
|
||||||
@ -819,8 +849,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||||||
|
|
||||||
# This is safe because _async_update_device will always return a device
|
# This is safe because _async_update_device will always return a device
|
||||||
# in this use case.
|
# in this use case.
|
||||||
assert device
|
assert updated_device
|
||||||
return device
|
return updated_device
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_update_device( # noqa: C901
|
def async_update_device( # noqa: C901
|
||||||
@ -1110,6 +1140,39 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||||||
)
|
)
|
||||||
self.async_schedule_save()
|
self.async_schedule_save()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _merge_devices(self, device_ids: set[str]) -> DeviceEntry:
|
||||||
|
# Pick a device to be the main device. For now, we just pick the first
|
||||||
|
# device in the set
|
||||||
|
main_device = self.devices[next(iter(device_ids))]
|
||||||
|
|
||||||
|
merged_connections = set()
|
||||||
|
merged_identifiers = set()
|
||||||
|
|
||||||
|
# Disable other devices, and clear their connections and identifiers
|
||||||
|
for device_id in device_ids:
|
||||||
|
device = self.devices[device_id]
|
||||||
|
merged_connections |= device.connections
|
||||||
|
merged_identifiers |= device.identifiers
|
||||||
|
|
||||||
|
if device.id == main_device.id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.async_update_device(
|
||||||
|
device.id,
|
||||||
|
disabled_by=DeviceEntryDisabler.DUPLICATE,
|
||||||
|
new_connections=set(),
|
||||||
|
new_identifiers={(DOMAIN, device.id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.async_update_device(
|
||||||
|
main_device.id,
|
||||||
|
new_connections=merged_connections,
|
||||||
|
new_identifiers=merged_identifiers,
|
||||||
|
)
|
||||||
|
|
||||||
|
return main_device
|
||||||
|
|
||||||
async def async_load(self) -> None:
|
async def async_load(self) -> None:
|
||||||
"""Load the device registry."""
|
"""Load the device registry."""
|
||||||
async_setup_cleanup(self.hass, self)
|
async_setup_cleanup(self.hass, self)
|
||||||
|
@ -2501,7 +2501,14 @@ def test_all() -> None:
|
|||||||
help_test_all(dr)
|
help_test_all(dr)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("enum"), list(dr.DeviceEntryDisabler))
|
@pytest.mark.parametrize(
|
||||||
|
("enum"),
|
||||||
|
[
|
||||||
|
enum
|
||||||
|
for enum in dr.DeviceEntryDisabler
|
||||||
|
if enum != dr.DeviceEntryDisabler.DUPLICATE
|
||||||
|
],
|
||||||
|
)
|
||||||
def test_deprecated_constants(
|
def test_deprecated_constants(
|
||||||
caplog: pytest.LogCaptureFixture,
|
caplog: pytest.LogCaptureFixture,
|
||||||
enum: dr.DeviceEntryDisabler,
|
enum: dr.DeviceEntryDisabler,
|
||||||
@ -2903,7 +2910,27 @@ async def test_device_registry_connections_collision(
|
|||||||
|
|
||||||
device3_refetched = device_registry.async_get(device3.id)
|
device3_refetched = device_registry.async_get(device3.id)
|
||||||
device1_refetched = device_registry.async_get(device1.id)
|
device1_refetched = device_registry.async_get(device1.id)
|
||||||
assert not device1_refetched.connections.isdisjoint(device3_refetched.connections)
|
|
||||||
|
# One of the devices should now:
|
||||||
|
# - Be disabled
|
||||||
|
# - Have all its connections removed
|
||||||
|
# - Have a single identifier
|
||||||
|
if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE:
|
||||||
|
main_device = device3_refetched
|
||||||
|
duplicate_device = device1_refetched
|
||||||
|
else:
|
||||||
|
main_device = device1_refetched
|
||||||
|
duplicate_device = device3_refetched
|
||||||
|
|
||||||
|
assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE
|
||||||
|
assert main_device.disabled_by is None
|
||||||
|
assert duplicate_device.connections == set()
|
||||||
|
assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)}
|
||||||
|
assert main_device.connections == {
|
||||||
|
(dr.CONNECTION_NETWORK_MAC, "ee:ee:ee:ee:ee:ee"),
|
||||||
|
(dr.CONNECTION_NETWORK_MAC, "none"),
|
||||||
|
}
|
||||||
|
assert main_device.identifiers == {("bridgeid", "0123")}
|
||||||
|
|
||||||
|
|
||||||
async def test_device_registry_identifiers_collision(
|
async def test_device_registry_identifiers_collision(
|
||||||
@ -2979,7 +3006,24 @@ async def test_device_registry_identifiers_collision(
|
|||||||
|
|
||||||
device3_refetched = device_registry.async_get(device3.id)
|
device3_refetched = device_registry.async_get(device3.id)
|
||||||
device1_refetched = device_registry.async_get(device1.id)
|
device1_refetched = device_registry.async_get(device1.id)
|
||||||
assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers)
|
|
||||||
|
# One of the devices should now:
|
||||||
|
# - Be disabled
|
||||||
|
# - Have all its connections removed
|
||||||
|
# - Have a single identifier
|
||||||
|
if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE:
|
||||||
|
main_device = device3_refetched
|
||||||
|
duplicate_device = device1_refetched
|
||||||
|
else:
|
||||||
|
main_device = device1_refetched
|
||||||
|
duplicate_device = device3_refetched
|
||||||
|
|
||||||
|
assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE
|
||||||
|
assert main_device.disabled_by is None
|
||||||
|
assert duplicate_device.connections == set()
|
||||||
|
assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)}
|
||||||
|
assert main_device.connections == set()
|
||||||
|
assert main_device.identifiers == {("bridgeid", "0123"), ("bridgeid", "4567")}
|
||||||
|
|
||||||
|
|
||||||
async def test_primary_config_entry(
|
async def test_primary_config_entry(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user