mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 09:17:10 +00:00
Validate new device identifiers and connections (#120413)
This commit is contained in:
parent
3559755aed
commit
d4e93dd01d
@ -185,6 +185,35 @@ class DeviceInfoError(HomeAssistantError):
|
||||
self.domain = domain
|
||||
|
||||
|
||||
class DeviceCollisionError(HomeAssistantError):
|
||||
"""Raised when a device collision is detected."""
|
||||
|
||||
|
||||
class DeviceIdentifierCollisionError(DeviceCollisionError):
|
||||
"""Raised when a device identifier collision is detected."""
|
||||
|
||||
def __init__(
|
||||
self, identifiers: set[tuple[str, str]], existing_device: DeviceEntry
|
||||
) -> None:
|
||||
"""Initialize error."""
|
||||
super().__init__(
|
||||
f"Identifiers {identifiers} already registered with {existing_device}"
|
||||
)
|
||||
|
||||
|
||||
class DeviceConnectionCollisionError(DeviceCollisionError):
|
||||
"""Raised when a device connection collision is detected."""
|
||||
|
||||
def __init__(
|
||||
self, normalized_connections: set[tuple[str, str]], existing_device: DeviceEntry
|
||||
) -> None:
|
||||
"""Initialize error."""
|
||||
super().__init__(
|
||||
f"Connections {normalized_connections} "
|
||||
f"already registered with {existing_device}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_device_info(
|
||||
config_entry: ConfigEntry,
|
||||
device_info: DeviceInfo,
|
||||
@ -759,6 +788,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||
|
||||
device = self.async_update_device(
|
||||
device.id,
|
||||
allow_collisions=True,
|
||||
add_config_entry_id=config_entry_id,
|
||||
configuration_url=configuration_url,
|
||||
device_info_type=device_info_type,
|
||||
@ -782,11 +812,14 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||
return device
|
||||
|
||||
@callback
|
||||
def async_update_device(
|
||||
def async_update_device( # noqa: C901
|
||||
self,
|
||||
device_id: str,
|
||||
*,
|
||||
add_config_entry_id: str | UndefinedType = UNDEFINED,
|
||||
# Temporary flag so we don't blow up when collisions are implicitly introduced
|
||||
# by calls to async_get_or_create. Must not be set by integrations.
|
||||
allow_collisions: bool = False,
|
||||
area_id: str | None | UndefinedType = UNDEFINED,
|
||||
configuration_url: str | URL | None | UndefinedType = UNDEFINED,
|
||||
device_info_type: str | UndefinedType = UNDEFINED,
|
||||
@ -894,12 +927,36 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||
new_values[attr_name] = old_value | setvalue
|
||||
old_values[attr_name] = old_value
|
||||
|
||||
if merge_connections is not UNDEFINED:
|
||||
normalized_connections = self._validate_connections(
|
||||
device_id,
|
||||
merge_connections,
|
||||
allow_collisions,
|
||||
)
|
||||
old_connections = old.connections
|
||||
if not normalized_connections.issubset(old_connections):
|
||||
new_values["connections"] = old_connections | normalized_connections
|
||||
old_values["connections"] = old_connections
|
||||
|
||||
if merge_identifiers is not UNDEFINED:
|
||||
merge_identifiers = self._validate_identifiers(
|
||||
device_id, merge_identifiers, allow_collisions
|
||||
)
|
||||
old_identifiers = old.identifiers
|
||||
if not merge_identifiers.issubset(old_identifiers):
|
||||
new_values["identifiers"] = old_identifiers | merge_identifiers
|
||||
old_values["identifiers"] = old_identifiers
|
||||
|
||||
if new_connections is not UNDEFINED:
|
||||
new_values["connections"] = _normalize_connections(new_connections)
|
||||
new_values["connections"] = self._validate_connections(
|
||||
device_id, new_connections, False
|
||||
)
|
||||
old_values["connections"] = old.connections
|
||||
|
||||
if new_identifiers is not UNDEFINED:
|
||||
new_values["identifiers"] = new_identifiers
|
||||
new_values["identifiers"] = self._validate_identifiers(
|
||||
device_id, new_identifiers, False
|
||||
)
|
||||
old_values["identifiers"] = old.identifiers
|
||||
|
||||
if configuration_url is not UNDEFINED:
|
||||
@ -955,6 +1012,53 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
||||
|
||||
return new
|
||||
|
||||
@callback
|
||||
def _validate_connections(
|
||||
self,
|
||||
device_id: str,
|
||||
connections: set[tuple[str, str]],
|
||||
allow_collisions: bool,
|
||||
) -> set[tuple[str, str]]:
|
||||
"""Normalize and validate connections, raise on collision with other devices."""
|
||||
normalized_connections = _normalize_connections(connections)
|
||||
if allow_collisions:
|
||||
return normalized_connections
|
||||
|
||||
for connection in normalized_connections:
|
||||
# We need to iterate over each connection because if there is a
|
||||
# conflict, the index will only see the last one and we will not
|
||||
# be able to tell which one caused the conflict
|
||||
if (
|
||||
existing_device := self.async_get_device(connections={connection})
|
||||
) and existing_device.id != device_id:
|
||||
raise DeviceConnectionCollisionError(
|
||||
normalized_connections, existing_device
|
||||
)
|
||||
|
||||
return normalized_connections
|
||||
|
||||
@callback
|
||||
def _validate_identifiers(
|
||||
self,
|
||||
device_id: str,
|
||||
identifiers: set[tuple[str, str]],
|
||||
allow_collisions: bool,
|
||||
) -> set[tuple[str, str]]:
|
||||
"""Validate identifiers, raise on collision with other devices."""
|
||||
if allow_collisions:
|
||||
return identifiers
|
||||
|
||||
for identifier in identifiers:
|
||||
# We need to iterate over each identifier because if there is a
|
||||
# conflict, the index will only see the last one and we will not
|
||||
# be able to tell which one caused the conflict
|
||||
if (
|
||||
existing_device := self.async_get_device(identifiers={identifier})
|
||||
) and existing_device.id != device_id:
|
||||
raise DeviceIdentifierCollisionError(identifiers, existing_device)
|
||||
|
||||
return identifiers
|
||||
|
||||
@callback
|
||||
def async_remove_device(self, device_id: str) -> None:
|
||||
"""Remove a device from the device registry."""
|
||||
|
@ -2630,3 +2630,167 @@ async def test_async_remove_device_thread_safety(
|
||||
await hass.async_add_executor_job(
|
||||
device_registry.async_remove_device, device.id
|
||||
)
|
||||
|
||||
|
||||
async def test_device_registry_connections_collision(
|
||||
hass: HomeAssistant, device_registry: dr.DeviceRegistry
|
||||
) -> None:
|
||||
"""Test connection collisions in the device registry."""
|
||||
config_entry = MockConfigEntry()
|
||||
config_entry.add_to_hass(hass)
|
||||
|
||||
device1 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, "none")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
)
|
||||
device2 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, "none")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
)
|
||||
|
||||
assert device1.id == device2.id
|
||||
|
||||
device3 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={("bridgeid", "0123")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
)
|
||||
|
||||
# Attempt to merge connection for device3 with the same
|
||||
# connection that already exists in device1
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match=f"Connections.*already registered.*{device1.id}"
|
||||
):
|
||||
device_registry.async_update_device(
|
||||
device3.id,
|
||||
merge_connections={
|
||||
(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"),
|
||||
(dr.CONNECTION_NETWORK_MAC, "none"),
|
||||
},
|
||||
)
|
||||
|
||||
# Attempt to add new connections for device3 with the same
|
||||
# connection that already exists in device1
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match=f"Connections.*already registered.*{device1.id}"
|
||||
):
|
||||
device_registry.async_update_device(
|
||||
device3.id,
|
||||
new_connections={
|
||||
(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"),
|
||||
(dr.CONNECTION_NETWORK_MAC, "none"),
|
||||
},
|
||||
)
|
||||
|
||||
device3_refetched = device_registry.async_get(device3.id)
|
||||
assert device3_refetched.connections == set()
|
||||
assert device3_refetched.identifiers == {("bridgeid", "0123")}
|
||||
|
||||
device1_refetched = device_registry.async_get(device1.id)
|
||||
assert device1_refetched.connections == {(dr.CONNECTION_NETWORK_MAC, "none")}
|
||||
assert device1_refetched.identifiers == set()
|
||||
|
||||
device2_refetched = device_registry.async_get(device2.id)
|
||||
assert device2_refetched.connections == {(dr.CONNECTION_NETWORK_MAC, "none")}
|
||||
assert device2_refetched.identifiers == set()
|
||||
|
||||
assert device2_refetched.id == device1_refetched.id
|
||||
assert len(device_registry.devices) == 2
|
||||
|
||||
# Attempt to implicitly merge connection for device3 with the same
|
||||
# connection that already exists in device1
|
||||
device4 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={("bridgeid", "0123")},
|
||||
connections={
|
||||
(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"),
|
||||
(dr.CONNECTION_NETWORK_MAC, "none"),
|
||||
},
|
||||
)
|
||||
assert len(device_registry.devices) == 2
|
||||
assert device4.id in (device1.id, device3.id)
|
||||
|
||||
device3_refetched = device_registry.async_get(device3.id)
|
||||
device1_refetched = device_registry.async_get(device1.id)
|
||||
assert not device1_refetched.connections.isdisjoint(device3_refetched.connections)
|
||||
|
||||
|
||||
async def test_device_registry_identifiers_collision(
|
||||
hass: HomeAssistant, device_registry: dr.DeviceRegistry
|
||||
) -> None:
|
||||
"""Test identifiers collisions in the device registry."""
|
||||
config_entry = MockConfigEntry()
|
||||
config_entry.add_to_hass(hass)
|
||||
|
||||
device1 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={("bridgeid", "0123")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
)
|
||||
device2 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={("bridgeid", "0123")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
)
|
||||
|
||||
assert device1.id == device2.id
|
||||
|
||||
device3 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={("bridgeid", "4567")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
)
|
||||
|
||||
# Attempt to merge identifiers for device3 with the same
|
||||
# connection that already exists in device1
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match=f"Identifiers.*already registered.*{device1.id}"
|
||||
):
|
||||
device_registry.async_update_device(
|
||||
device3.id, merge_identifiers={("bridgeid", "0123"), ("bridgeid", "8888")}
|
||||
)
|
||||
|
||||
# Attempt to add new identifiers for device3 with the same
|
||||
# connection that already exists in device1
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match=f"Identifiers.*already registered.*{device1.id}"
|
||||
):
|
||||
device_registry.async_update_device(
|
||||
device3.id, new_identifiers={("bridgeid", "0123"), ("bridgeid", "8888")}
|
||||
)
|
||||
|
||||
device3_refetched = device_registry.async_get(device3.id)
|
||||
assert device3_refetched.connections == set()
|
||||
assert device3_refetched.identifiers == {("bridgeid", "4567")}
|
||||
|
||||
device1_refetched = device_registry.async_get(device1.id)
|
||||
assert device1_refetched.connections == set()
|
||||
assert device1_refetched.identifiers == {("bridgeid", "0123")}
|
||||
|
||||
device2_refetched = device_registry.async_get(device2.id)
|
||||
assert device2_refetched.connections == set()
|
||||
assert device2_refetched.identifiers == {("bridgeid", "0123")}
|
||||
|
||||
assert device2_refetched.id == device1_refetched.id
|
||||
assert len(device_registry.devices) == 2
|
||||
|
||||
# Attempt to implicitly merge identifiers for device3 with the same
|
||||
# connection that already exists in device1
|
||||
device4 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={("bridgeid", "4567"), ("bridgeid", "0123")},
|
||||
)
|
||||
assert len(device_registry.devices) == 2
|
||||
assert device4.id in (device1.id, device3.id)
|
||||
|
||||
device3_refetched = device_registry.async_get(device3.id)
|
||||
device1_refetched = device_registry.async_get(device1.id)
|
||||
assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers)
|
||||
|
Loading…
x
Reference in New Issue
Block a user