From d4e93dd01dc076fcb3a54cb7228537ade02b532a Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 25 Jun 2024 19:17:54 +0200 Subject: [PATCH] Validate new device identifiers and connections (#120413) --- homeassistant/helpers/device_registry.py | 110 ++++++++++++++- tests/helpers/test_device_registry.py | 164 +++++++++++++++++++++++ 2 files changed, 271 insertions(+), 3 deletions(-) diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 2a90d885d70..36249733f71 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -185,6 +185,35 @@ class DeviceInfoError(HomeAssistantError): self.domain = domain +class DeviceCollisionError(HomeAssistantError): + """Raised when a device collision is detected.""" + + +class DeviceIdentifierCollisionError(DeviceCollisionError): + """Raised when a device identifier collision is detected.""" + + def __init__( + self, identifiers: set[tuple[str, str]], existing_device: DeviceEntry + ) -> None: + """Initialize error.""" + super().__init__( + f"Identifiers {identifiers} already registered with {existing_device}" + ) + + +class DeviceConnectionCollisionError(DeviceCollisionError): + """Raised when a device connection collision is detected.""" + + def __init__( + self, normalized_connections: set[tuple[str, str]], existing_device: DeviceEntry + ) -> None: + """Initialize error.""" + super().__init__( + f"Connections {normalized_connections} " + f"already registered with {existing_device}" + ) + + def _validate_device_info( config_entry: ConfigEntry, device_info: DeviceInfo, @@ -759,6 +788,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): device = self.async_update_device( device.id, + allow_collisions=True, add_config_entry_id=config_entry_id, configuration_url=configuration_url, device_info_type=device_info_type, @@ -782,11 +812,14 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): return device @callback - def async_update_device( + def async_update_device( # noqa: C901 self, device_id: str, *, add_config_entry_id: str | UndefinedType = UNDEFINED, + # Temporary flag so we don't blow up when collisions are implicitly introduced + # by calls to async_get_or_create. Must not be set by integrations. + allow_collisions: bool = False, area_id: str | None | UndefinedType = UNDEFINED, configuration_url: str | URL | None | UndefinedType = UNDEFINED, device_info_type: str | UndefinedType = UNDEFINED, @@ -894,12 +927,36 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): new_values[attr_name] = old_value | setvalue old_values[attr_name] = old_value + if merge_connections is not UNDEFINED: + normalized_connections = self._validate_connections( + device_id, + merge_connections, + allow_collisions, + ) + old_connections = old.connections + if not normalized_connections.issubset(old_connections): + new_values["connections"] = old_connections | normalized_connections + old_values["connections"] = old_connections + + if merge_identifiers is not UNDEFINED: + merge_identifiers = self._validate_identifiers( + device_id, merge_identifiers, allow_collisions + ) + old_identifiers = old.identifiers + if not merge_identifiers.issubset(old_identifiers): + new_values["identifiers"] = old_identifiers | merge_identifiers + old_values["identifiers"] = old_identifiers + if new_connections is not UNDEFINED: - new_values["connections"] = _normalize_connections(new_connections) + new_values["connections"] = self._validate_connections( + device_id, new_connections, False + ) old_values["connections"] = old.connections if new_identifiers is not UNDEFINED: - new_values["identifiers"] = new_identifiers + new_values["identifiers"] = self._validate_identifiers( + device_id, new_identifiers, False + ) old_values["identifiers"] = old.identifiers if configuration_url is not UNDEFINED: @@ -955,6 +1012,53 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): return new + @callback + def _validate_connections( + self, + device_id: str, + connections: set[tuple[str, str]], + allow_collisions: bool, + ) -> set[tuple[str, str]]: + """Normalize and validate connections, raise on collision with other devices.""" + normalized_connections = _normalize_connections(connections) + if allow_collisions: + return normalized_connections + + for connection in normalized_connections: + # We need to iterate over each connection because if there is a + # conflict, the index will only see the last one and we will not + # be able to tell which one caused the conflict + if ( + existing_device := self.async_get_device(connections={connection}) + ) and existing_device.id != device_id: + raise DeviceConnectionCollisionError( + normalized_connections, existing_device + ) + + return normalized_connections + + @callback + def _validate_identifiers( + self, + device_id: str, + identifiers: set[tuple[str, str]], + allow_collisions: bool, + ) -> set[tuple[str, str]]: + """Validate identifiers, raise on collision with other devices.""" + if allow_collisions: + return identifiers + + for identifier in identifiers: + # We need to iterate over each identifier because if there is a + # conflict, the index will only see the last one and we will not + # be able to tell which one caused the conflict + if ( + existing_device := self.async_get_device(identifiers={identifier}) + ) and existing_device.id != device_id: + raise DeviceIdentifierCollisionError(identifiers, existing_device) + + return identifiers + @callback def async_remove_device(self, device_id: str) -> None: """Remove a device from the device registry.""" diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index b141e29f678..f8f10baad08 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -2630,3 +2630,167 @@ async def test_async_remove_device_thread_safety( await hass.async_add_executor_job( device_registry.async_remove_device, device.id ) + + +async def test_device_registry_connections_collision( + hass: HomeAssistant, device_registry: dr.DeviceRegistry +) -> None: + """Test connection collisions in the device registry.""" + config_entry = MockConfigEntry() + config_entry.add_to_hass(hass) + + device1 = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "none")}, + manufacturer="manufacturer", + model="model", + ) + device2 = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "none")}, + manufacturer="manufacturer", + model="model", + ) + + assert device1.id == device2.id + + device3 = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("bridgeid", "0123")}, + manufacturer="manufacturer", + model="model", + ) + + # Attempt to merge connection for device3 with the same + # connection that already exists in device1 + with pytest.raises( + HomeAssistantError, match=f"Connections.*already registered.*{device1.id}" + ): + device_registry.async_update_device( + device3.id, + merge_connections={ + (dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"), + (dr.CONNECTION_NETWORK_MAC, "none"), + }, + ) + + # Attempt to add new connections for device3 with the same + # connection that already exists in device1 + with pytest.raises( + HomeAssistantError, match=f"Connections.*already registered.*{device1.id}" + ): + device_registry.async_update_device( + device3.id, + new_connections={ + (dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"), + (dr.CONNECTION_NETWORK_MAC, "none"), + }, + ) + + device3_refetched = device_registry.async_get(device3.id) + assert device3_refetched.connections == set() + assert device3_refetched.identifiers == {("bridgeid", "0123")} + + device1_refetched = device_registry.async_get(device1.id) + assert device1_refetched.connections == {(dr.CONNECTION_NETWORK_MAC, "none")} + assert device1_refetched.identifiers == set() + + device2_refetched = device_registry.async_get(device2.id) + assert device2_refetched.connections == {(dr.CONNECTION_NETWORK_MAC, "none")} + assert device2_refetched.identifiers == set() + + assert device2_refetched.id == device1_refetched.id + assert len(device_registry.devices) == 2 + + # Attempt to implicitly merge connection for device3 with the same + # connection that already exists in device1 + device4 = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("bridgeid", "0123")}, + connections={ + (dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"), + (dr.CONNECTION_NETWORK_MAC, "none"), + }, + ) + assert len(device_registry.devices) == 2 + assert device4.id in (device1.id, device3.id) + + device3_refetched = device_registry.async_get(device3.id) + device1_refetched = device_registry.async_get(device1.id) + assert not device1_refetched.connections.isdisjoint(device3_refetched.connections) + + +async def test_device_registry_identifiers_collision( + hass: HomeAssistant, device_registry: dr.DeviceRegistry +) -> None: + """Test identifiers collisions in the device registry.""" + config_entry = MockConfigEntry() + config_entry.add_to_hass(hass) + + device1 = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("bridgeid", "0123")}, + manufacturer="manufacturer", + model="model", + ) + device2 = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("bridgeid", "0123")}, + manufacturer="manufacturer", + model="model", + ) + + assert device1.id == device2.id + + device3 = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("bridgeid", "4567")}, + manufacturer="manufacturer", + model="model", + ) + + # Attempt to merge identifiers for device3 with the same + # connection that already exists in device1 + with pytest.raises( + HomeAssistantError, match=f"Identifiers.*already registered.*{device1.id}" + ): + device_registry.async_update_device( + device3.id, merge_identifiers={("bridgeid", "0123"), ("bridgeid", "8888")} + ) + + # Attempt to add new identifiers for device3 with the same + # connection that already exists in device1 + with pytest.raises( + HomeAssistantError, match=f"Identifiers.*already registered.*{device1.id}" + ): + device_registry.async_update_device( + device3.id, new_identifiers={("bridgeid", "0123"), ("bridgeid", "8888")} + ) + + device3_refetched = device_registry.async_get(device3.id) + assert device3_refetched.connections == set() + assert device3_refetched.identifiers == {("bridgeid", "4567")} + + device1_refetched = device_registry.async_get(device1.id) + assert device1_refetched.connections == set() + assert device1_refetched.identifiers == {("bridgeid", "0123")} + + device2_refetched = device_registry.async_get(device2.id) + assert device2_refetched.connections == set() + assert device2_refetched.identifiers == {("bridgeid", "0123")} + + assert device2_refetched.id == device1_refetched.id + assert len(device_registry.devices) == 2 + + # Attempt to implicitly merge identifiers for device3 with the same + # connection that already exists in device1 + device4 = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + identifiers={("bridgeid", "4567"), ("bridgeid", "0123")}, + ) + assert len(device_registry.devices) == 2 + assert device4.id in (device1.id, device3.id) + + device3_refetched = device_registry.async_get(device3.id) + device1_refetched = device_registry.async_get(device1.id) + assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers)