Validate new device identifiers and connections (#120413)

This commit is contained in:
Erik Montnemery 2024-06-25 19:17:54 +02:00 committed by GitHub
parent 3559755aed
commit d4e93dd01d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 271 additions and 3 deletions

View File

@ -185,6 +185,35 @@ class DeviceInfoError(HomeAssistantError):
self.domain = domain
class DeviceCollisionError(HomeAssistantError):
"""Raised when a device collision is detected."""
class DeviceIdentifierCollisionError(DeviceCollisionError):
"""Raised when a device identifier collision is detected."""
def __init__(
self, identifiers: set[tuple[str, str]], existing_device: DeviceEntry
) -> None:
"""Initialize error."""
super().__init__(
f"Identifiers {identifiers} already registered with {existing_device}"
)
class DeviceConnectionCollisionError(DeviceCollisionError):
"""Raised when a device connection collision is detected."""
def __init__(
self, normalized_connections: set[tuple[str, str]], existing_device: DeviceEntry
) -> None:
"""Initialize error."""
super().__init__(
f"Connections {normalized_connections} "
f"already registered with {existing_device}"
)
def _validate_device_info(
config_entry: ConfigEntry,
device_info: DeviceInfo,
@ -759,6 +788,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
device = self.async_update_device(
device.id,
allow_collisions=True,
add_config_entry_id=config_entry_id,
configuration_url=configuration_url,
device_info_type=device_info_type,
@ -782,11 +812,14 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
return device
@callback
def async_update_device(
def async_update_device( # noqa: C901
self,
device_id: str,
*,
add_config_entry_id: str | UndefinedType = UNDEFINED,
# Temporary flag so we don't blow up when collisions are implicitly introduced
# by calls to async_get_or_create. Must not be set by integrations.
allow_collisions: bool = False,
area_id: str | None | UndefinedType = UNDEFINED,
configuration_url: str | URL | None | UndefinedType = UNDEFINED,
device_info_type: str | UndefinedType = UNDEFINED,
@ -894,12 +927,36 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
new_values[attr_name] = old_value | setvalue
old_values[attr_name] = old_value
if merge_connections is not UNDEFINED:
normalized_connections = self._validate_connections(
device_id,
merge_connections,
allow_collisions,
)
old_connections = old.connections
if not normalized_connections.issubset(old_connections):
new_values["connections"] = old_connections | normalized_connections
old_values["connections"] = old_connections
if merge_identifiers is not UNDEFINED:
merge_identifiers = self._validate_identifiers(
device_id, merge_identifiers, allow_collisions
)
old_identifiers = old.identifiers
if not merge_identifiers.issubset(old_identifiers):
new_values["identifiers"] = old_identifiers | merge_identifiers
old_values["identifiers"] = old_identifiers
if new_connections is not UNDEFINED:
new_values["connections"] = _normalize_connections(new_connections)
new_values["connections"] = self._validate_connections(
device_id, new_connections, False
)
old_values["connections"] = old.connections
if new_identifiers is not UNDEFINED:
new_values["identifiers"] = new_identifiers
new_values["identifiers"] = self._validate_identifiers(
device_id, new_identifiers, False
)
old_values["identifiers"] = old.identifiers
if configuration_url is not UNDEFINED:
@ -955,6 +1012,53 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
return new
@callback
def _validate_connections(
self,
device_id: str,
connections: set[tuple[str, str]],
allow_collisions: bool,
) -> set[tuple[str, str]]:
"""Normalize and validate connections, raise on collision with other devices."""
normalized_connections = _normalize_connections(connections)
if allow_collisions:
return normalized_connections
for connection in normalized_connections:
# We need to iterate over each connection because if there is a
# conflict, the index will only see the last one and we will not
# be able to tell which one caused the conflict
if (
existing_device := self.async_get_device(connections={connection})
) and existing_device.id != device_id:
raise DeviceConnectionCollisionError(
normalized_connections, existing_device
)
return normalized_connections
@callback
def _validate_identifiers(
self,
device_id: str,
identifiers: set[tuple[str, str]],
allow_collisions: bool,
) -> set[tuple[str, str]]:
"""Validate identifiers, raise on collision with other devices."""
if allow_collisions:
return identifiers
for identifier in identifiers:
# We need to iterate over each identifier because if there is a
# conflict, the index will only see the last one and we will not
# be able to tell which one caused the conflict
if (
existing_device := self.async_get_device(identifiers={identifier})
) and existing_device.id != device_id:
raise DeviceIdentifierCollisionError(identifiers, existing_device)
return identifiers
@callback
def async_remove_device(self, device_id: str) -> None:
"""Remove a device from the device registry."""

View File

@ -2630,3 +2630,167 @@ async def test_async_remove_device_thread_safety(
await hass.async_add_executor_job(
device_registry.async_remove_device, device.id
)
async def test_device_registry_connections_collision(
hass: HomeAssistant, device_registry: dr.DeviceRegistry
) -> None:
"""Test connection collisions in the device registry."""
config_entry = MockConfigEntry()
config_entry.add_to_hass(hass)
device1 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "none")},
manufacturer="manufacturer",
model="model",
)
device2 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "none")},
manufacturer="manufacturer",
model="model",
)
assert device1.id == device2.id
device3 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("bridgeid", "0123")},
manufacturer="manufacturer",
model="model",
)
# Attempt to merge connection for device3 with the same
# connection that already exists in device1
with pytest.raises(
HomeAssistantError, match=f"Connections.*already registered.*{device1.id}"
):
device_registry.async_update_device(
device3.id,
merge_connections={
(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"),
(dr.CONNECTION_NETWORK_MAC, "none"),
},
)
# Attempt to add new connections for device3 with the same
# connection that already exists in device1
with pytest.raises(
HomeAssistantError, match=f"Connections.*already registered.*{device1.id}"
):
device_registry.async_update_device(
device3.id,
new_connections={
(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"),
(dr.CONNECTION_NETWORK_MAC, "none"),
},
)
device3_refetched = device_registry.async_get(device3.id)
assert device3_refetched.connections == set()
assert device3_refetched.identifiers == {("bridgeid", "0123")}
device1_refetched = device_registry.async_get(device1.id)
assert device1_refetched.connections == {(dr.CONNECTION_NETWORK_MAC, "none")}
assert device1_refetched.identifiers == set()
device2_refetched = device_registry.async_get(device2.id)
assert device2_refetched.connections == {(dr.CONNECTION_NETWORK_MAC, "none")}
assert device2_refetched.identifiers == set()
assert device2_refetched.id == device1_refetched.id
assert len(device_registry.devices) == 2
# Attempt to implicitly merge connection for device3 with the same
# connection that already exists in device1
device4 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("bridgeid", "0123")},
connections={
(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"),
(dr.CONNECTION_NETWORK_MAC, "none"),
},
)
assert len(device_registry.devices) == 2
assert device4.id in (device1.id, device3.id)
device3_refetched = device_registry.async_get(device3.id)
device1_refetched = device_registry.async_get(device1.id)
assert not device1_refetched.connections.isdisjoint(device3_refetched.connections)
async def test_device_registry_identifiers_collision(
hass: HomeAssistant, device_registry: dr.DeviceRegistry
) -> None:
"""Test identifiers collisions in the device registry."""
config_entry = MockConfigEntry()
config_entry.add_to_hass(hass)
device1 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("bridgeid", "0123")},
manufacturer="manufacturer",
model="model",
)
device2 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("bridgeid", "0123")},
manufacturer="manufacturer",
model="model",
)
assert device1.id == device2.id
device3 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("bridgeid", "4567")},
manufacturer="manufacturer",
model="model",
)
# Attempt to merge identifiers for device3 with the same
# connection that already exists in device1
with pytest.raises(
HomeAssistantError, match=f"Identifiers.*already registered.*{device1.id}"
):
device_registry.async_update_device(
device3.id, merge_identifiers={("bridgeid", "0123"), ("bridgeid", "8888")}
)
# Attempt to add new identifiers for device3 with the same
# connection that already exists in device1
with pytest.raises(
HomeAssistantError, match=f"Identifiers.*already registered.*{device1.id}"
):
device_registry.async_update_device(
device3.id, new_identifiers={("bridgeid", "0123"), ("bridgeid", "8888")}
)
device3_refetched = device_registry.async_get(device3.id)
assert device3_refetched.connections == set()
assert device3_refetched.identifiers == {("bridgeid", "4567")}
device1_refetched = device_registry.async_get(device1.id)
assert device1_refetched.connections == set()
assert device1_refetched.identifiers == {("bridgeid", "0123")}
device2_refetched = device_registry.async_get(device2.id)
assert device2_refetched.connections == set()
assert device2_refetched.identifiers == {("bridgeid", "0123")}
assert device2_refetched.id == device1_refetched.id
assert len(device_registry.devices) == 2
# Attempt to implicitly merge identifiers for device3 with the same
# connection that already exists in device1
device4 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("bridgeid", "4567"), ("bridgeid", "0123")},
)
assert len(device_registry.devices) == 2
assert device4.id in (device1.id, device3.id)
device3_refetched = device_registry.async_get(device3.id)
device1_refetched = device_registry.async_get(device1.id)
assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers)