mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 09:17:10 +00:00
Validate new device identifiers and connections (#120413)
This commit is contained in:
parent
3559755aed
commit
d4e93dd01d
@ -185,6 +185,35 @@ class DeviceInfoError(HomeAssistantError):
|
|||||||
self.domain = domain
|
self.domain = domain
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceCollisionError(HomeAssistantError):
|
||||||
|
"""Raised when a device collision is detected."""
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceIdentifierCollisionError(DeviceCollisionError):
|
||||||
|
"""Raised when a device identifier collision is detected."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, identifiers: set[tuple[str, str]], existing_device: DeviceEntry
|
||||||
|
) -> None:
|
||||||
|
"""Initialize error."""
|
||||||
|
super().__init__(
|
||||||
|
f"Identifiers {identifiers} already registered with {existing_device}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceConnectionCollisionError(DeviceCollisionError):
|
||||||
|
"""Raised when a device connection collision is detected."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, normalized_connections: set[tuple[str, str]], existing_device: DeviceEntry
|
||||||
|
) -> None:
|
||||||
|
"""Initialize error."""
|
||||||
|
super().__init__(
|
||||||
|
f"Connections {normalized_connections} "
|
||||||
|
f"already registered with {existing_device}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_device_info(
|
def _validate_device_info(
|
||||||
config_entry: ConfigEntry,
|
config_entry: ConfigEntry,
|
||||||
device_info: DeviceInfo,
|
device_info: DeviceInfo,
|
||||||
@ -759,6 +788,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||||||
|
|
||||||
device = self.async_update_device(
|
device = self.async_update_device(
|
||||||
device.id,
|
device.id,
|
||||||
|
allow_collisions=True,
|
||||||
add_config_entry_id=config_entry_id,
|
add_config_entry_id=config_entry_id,
|
||||||
configuration_url=configuration_url,
|
configuration_url=configuration_url,
|
||||||
device_info_type=device_info_type,
|
device_info_type=device_info_type,
|
||||||
@ -782,11 +812,14 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||||||
return device
|
return device
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_update_device(
|
def async_update_device( # noqa: C901
|
||||||
self,
|
self,
|
||||||
device_id: str,
|
device_id: str,
|
||||||
*,
|
*,
|
||||||
add_config_entry_id: str | UndefinedType = UNDEFINED,
|
add_config_entry_id: str | UndefinedType = UNDEFINED,
|
||||||
|
# Temporary flag so we don't blow up when collisions are implicitly introduced
|
||||||
|
# by calls to async_get_or_create. Must not be set by integrations.
|
||||||
|
allow_collisions: bool = False,
|
||||||
area_id: str | None | UndefinedType = UNDEFINED,
|
area_id: str | None | UndefinedType = UNDEFINED,
|
||||||
configuration_url: str | URL | None | UndefinedType = UNDEFINED,
|
configuration_url: str | URL | None | UndefinedType = UNDEFINED,
|
||||||
device_info_type: str | UndefinedType = UNDEFINED,
|
device_info_type: str | UndefinedType = UNDEFINED,
|
||||||
@ -894,12 +927,36 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||||||
new_values[attr_name] = old_value | setvalue
|
new_values[attr_name] = old_value | setvalue
|
||||||
old_values[attr_name] = old_value
|
old_values[attr_name] = old_value
|
||||||
|
|
||||||
|
if merge_connections is not UNDEFINED:
|
||||||
|
normalized_connections = self._validate_connections(
|
||||||
|
device_id,
|
||||||
|
merge_connections,
|
||||||
|
allow_collisions,
|
||||||
|
)
|
||||||
|
old_connections = old.connections
|
||||||
|
if not normalized_connections.issubset(old_connections):
|
||||||
|
new_values["connections"] = old_connections | normalized_connections
|
||||||
|
old_values["connections"] = old_connections
|
||||||
|
|
||||||
|
if merge_identifiers is not UNDEFINED:
|
||||||
|
merge_identifiers = self._validate_identifiers(
|
||||||
|
device_id, merge_identifiers, allow_collisions
|
||||||
|
)
|
||||||
|
old_identifiers = old.identifiers
|
||||||
|
if not merge_identifiers.issubset(old_identifiers):
|
||||||
|
new_values["identifiers"] = old_identifiers | merge_identifiers
|
||||||
|
old_values["identifiers"] = old_identifiers
|
||||||
|
|
||||||
if new_connections is not UNDEFINED:
|
if new_connections is not UNDEFINED:
|
||||||
new_values["connections"] = _normalize_connections(new_connections)
|
new_values["connections"] = self._validate_connections(
|
||||||
|
device_id, new_connections, False
|
||||||
|
)
|
||||||
old_values["connections"] = old.connections
|
old_values["connections"] = old.connections
|
||||||
|
|
||||||
if new_identifiers is not UNDEFINED:
|
if new_identifiers is not UNDEFINED:
|
||||||
new_values["identifiers"] = new_identifiers
|
new_values["identifiers"] = self._validate_identifiers(
|
||||||
|
device_id, new_identifiers, False
|
||||||
|
)
|
||||||
old_values["identifiers"] = old.identifiers
|
old_values["identifiers"] = old.identifiers
|
||||||
|
|
||||||
if configuration_url is not UNDEFINED:
|
if configuration_url is not UNDEFINED:
|
||||||
@ -955,6 +1012,53 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _validate_connections(
|
||||||
|
self,
|
||||||
|
device_id: str,
|
||||||
|
connections: set[tuple[str, str]],
|
||||||
|
allow_collisions: bool,
|
||||||
|
) -> set[tuple[str, str]]:
|
||||||
|
"""Normalize and validate connections, raise on collision with other devices."""
|
||||||
|
normalized_connections = _normalize_connections(connections)
|
||||||
|
if allow_collisions:
|
||||||
|
return normalized_connections
|
||||||
|
|
||||||
|
for connection in normalized_connections:
|
||||||
|
# We need to iterate over each connection because if there is a
|
||||||
|
# conflict, the index will only see the last one and we will not
|
||||||
|
# be able to tell which one caused the conflict
|
||||||
|
if (
|
||||||
|
existing_device := self.async_get_device(connections={connection})
|
||||||
|
) and existing_device.id != device_id:
|
||||||
|
raise DeviceConnectionCollisionError(
|
||||||
|
normalized_connections, existing_device
|
||||||
|
)
|
||||||
|
|
||||||
|
return normalized_connections
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _validate_identifiers(
|
||||||
|
self,
|
||||||
|
device_id: str,
|
||||||
|
identifiers: set[tuple[str, str]],
|
||||||
|
allow_collisions: bool,
|
||||||
|
) -> set[tuple[str, str]]:
|
||||||
|
"""Validate identifiers, raise on collision with other devices."""
|
||||||
|
if allow_collisions:
|
||||||
|
return identifiers
|
||||||
|
|
||||||
|
for identifier in identifiers:
|
||||||
|
# We need to iterate over each identifier because if there is a
|
||||||
|
# conflict, the index will only see the last one and we will not
|
||||||
|
# be able to tell which one caused the conflict
|
||||||
|
if (
|
||||||
|
existing_device := self.async_get_device(identifiers={identifier})
|
||||||
|
) and existing_device.id != device_id:
|
||||||
|
raise DeviceIdentifierCollisionError(identifiers, existing_device)
|
||||||
|
|
||||||
|
return identifiers
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove_device(self, device_id: str) -> None:
|
def async_remove_device(self, device_id: str) -> None:
|
||||||
"""Remove a device from the device registry."""
|
"""Remove a device from the device registry."""
|
||||||
|
@ -2630,3 +2630,167 @@ async def test_async_remove_device_thread_safety(
|
|||||||
await hass.async_add_executor_job(
|
await hass.async_add_executor_job(
|
||||||
device_registry.async_remove_device, device.id
|
device_registry.async_remove_device, device.id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_device_registry_connections_collision(
|
||||||
|
hass: HomeAssistant, device_registry: dr.DeviceRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test connection collisions in the device registry."""
|
||||||
|
config_entry = MockConfigEntry()
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
device1 = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
connections={(dr.CONNECTION_NETWORK_MAC, "none")},
|
||||||
|
manufacturer="manufacturer",
|
||||||
|
model="model",
|
||||||
|
)
|
||||||
|
device2 = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
connections={(dr.CONNECTION_NETWORK_MAC, "none")},
|
||||||
|
manufacturer="manufacturer",
|
||||||
|
model="model",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert device1.id == device2.id
|
||||||
|
|
||||||
|
device3 = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
identifiers={("bridgeid", "0123")},
|
||||||
|
manufacturer="manufacturer",
|
||||||
|
model="model",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Attempt to merge connection for device3 with the same
|
||||||
|
# connection that already exists in device1
|
||||||
|
with pytest.raises(
|
||||||
|
HomeAssistantError, match=f"Connections.*already registered.*{device1.id}"
|
||||||
|
):
|
||||||
|
device_registry.async_update_device(
|
||||||
|
device3.id,
|
||||||
|
merge_connections={
|
||||||
|
(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"),
|
||||||
|
(dr.CONNECTION_NETWORK_MAC, "none"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Attempt to add new connections for device3 with the same
|
||||||
|
# connection that already exists in device1
|
||||||
|
with pytest.raises(
|
||||||
|
HomeAssistantError, match=f"Connections.*already registered.*{device1.id}"
|
||||||
|
):
|
||||||
|
device_registry.async_update_device(
|
||||||
|
device3.id,
|
||||||
|
new_connections={
|
||||||
|
(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"),
|
||||||
|
(dr.CONNECTION_NETWORK_MAC, "none"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
device3_refetched = device_registry.async_get(device3.id)
|
||||||
|
assert device3_refetched.connections == set()
|
||||||
|
assert device3_refetched.identifiers == {("bridgeid", "0123")}
|
||||||
|
|
||||||
|
device1_refetched = device_registry.async_get(device1.id)
|
||||||
|
assert device1_refetched.connections == {(dr.CONNECTION_NETWORK_MAC, "none")}
|
||||||
|
assert device1_refetched.identifiers == set()
|
||||||
|
|
||||||
|
device2_refetched = device_registry.async_get(device2.id)
|
||||||
|
assert device2_refetched.connections == {(dr.CONNECTION_NETWORK_MAC, "none")}
|
||||||
|
assert device2_refetched.identifiers == set()
|
||||||
|
|
||||||
|
assert device2_refetched.id == device1_refetched.id
|
||||||
|
assert len(device_registry.devices) == 2
|
||||||
|
|
||||||
|
# Attempt to implicitly merge connection for device3 with the same
|
||||||
|
# connection that already exists in device1
|
||||||
|
device4 = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
identifiers={("bridgeid", "0123")},
|
||||||
|
connections={
|
||||||
|
(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"),
|
||||||
|
(dr.CONNECTION_NETWORK_MAC, "none"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert len(device_registry.devices) == 2
|
||||||
|
assert device4.id in (device1.id, device3.id)
|
||||||
|
|
||||||
|
device3_refetched = device_registry.async_get(device3.id)
|
||||||
|
device1_refetched = device_registry.async_get(device1.id)
|
||||||
|
assert not device1_refetched.connections.isdisjoint(device3_refetched.connections)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_device_registry_identifiers_collision(
|
||||||
|
hass: HomeAssistant, device_registry: dr.DeviceRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Test identifiers collisions in the device registry."""
|
||||||
|
config_entry = MockConfigEntry()
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
device1 = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
identifiers={("bridgeid", "0123")},
|
||||||
|
manufacturer="manufacturer",
|
||||||
|
model="model",
|
||||||
|
)
|
||||||
|
device2 = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
identifiers={("bridgeid", "0123")},
|
||||||
|
manufacturer="manufacturer",
|
||||||
|
model="model",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert device1.id == device2.id
|
||||||
|
|
||||||
|
device3 = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
identifiers={("bridgeid", "4567")},
|
||||||
|
manufacturer="manufacturer",
|
||||||
|
model="model",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Attempt to merge identifiers for device3 with the same
|
||||||
|
# connection that already exists in device1
|
||||||
|
with pytest.raises(
|
||||||
|
HomeAssistantError, match=f"Identifiers.*already registered.*{device1.id}"
|
||||||
|
):
|
||||||
|
device_registry.async_update_device(
|
||||||
|
device3.id, merge_identifiers={("bridgeid", "0123"), ("bridgeid", "8888")}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Attempt to add new identifiers for device3 with the same
|
||||||
|
# connection that already exists in device1
|
||||||
|
with pytest.raises(
|
||||||
|
HomeAssistantError, match=f"Identifiers.*already registered.*{device1.id}"
|
||||||
|
):
|
||||||
|
device_registry.async_update_device(
|
||||||
|
device3.id, new_identifiers={("bridgeid", "0123"), ("bridgeid", "8888")}
|
||||||
|
)
|
||||||
|
|
||||||
|
device3_refetched = device_registry.async_get(device3.id)
|
||||||
|
assert device3_refetched.connections == set()
|
||||||
|
assert device3_refetched.identifiers == {("bridgeid", "4567")}
|
||||||
|
|
||||||
|
device1_refetched = device_registry.async_get(device1.id)
|
||||||
|
assert device1_refetched.connections == set()
|
||||||
|
assert device1_refetched.identifiers == {("bridgeid", "0123")}
|
||||||
|
|
||||||
|
device2_refetched = device_registry.async_get(device2.id)
|
||||||
|
assert device2_refetched.connections == set()
|
||||||
|
assert device2_refetched.identifiers == {("bridgeid", "0123")}
|
||||||
|
|
||||||
|
assert device2_refetched.id == device1_refetched.id
|
||||||
|
assert len(device_registry.devices) == 2
|
||||||
|
|
||||||
|
# Attempt to implicitly merge identifiers for device3 with the same
|
||||||
|
# connection that already exists in device1
|
||||||
|
device4 = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
identifiers={("bridgeid", "4567"), ("bridgeid", "0123")},
|
||||||
|
)
|
||||||
|
assert len(device_registry.devices) == 2
|
||||||
|
assert device4.id in (device1.id, device3.id)
|
||||||
|
|
||||||
|
device3_refetched = device_registry.async_get(device3.id)
|
||||||
|
device1_refetched = device_registry.async_get(device1.id)
|
||||||
|
assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user