Add ability to replace connections in DeviceRegistry (#118555)

* Add ability to replace connections in DeviceRegistry

* Add more tests

* Improve coverage

* Apply suggestion

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
epenet 2024-05-31 21:31:44 +02:00 committed by GitHub
parent bae96e7d36
commit 41e852a01b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 117 additions and 1 deletions

View File

@ -798,6 +798,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
model: str | None | UndefinedType = UNDEFINED,
name_by_user: str | None | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
new_connections: set[tuple[str, str]] | UndefinedType = UNDEFINED,
new_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED,
remove_config_entry_id: str | UndefinedType = UNDEFINED,
serial_number: str | None | UndefinedType = UNDEFINED,
@ -813,6 +814,9 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
config_entries = old.config_entries
if merge_connections is not UNDEFINED and new_connections is not UNDEFINED:
raise HomeAssistantError("Cannot define both merge_connections and new_connections")
if merge_identifiers is not UNDEFINED and new_identifiers is not UNDEFINED:
raise HomeAssistantError
@ -873,6 +877,10 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
new_values[attr_name] = old_value | setvalue
old_values[attr_name] = old_value
if new_connections is not UNDEFINED:
new_values["connections"] = _normalize_connections(new_connections)
old_values["connections"] = old.connections
if new_identifiers is not UNDEFINED:
new_values["identifiers"] = new_identifiers
old_values["identifiers"] = old.identifiers

View File

@ -1257,6 +1257,7 @@ async def test_update(
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
identifiers={("hue", "456"), ("bla", "123")},
)
new_connections = {(dr.CONNECTION_NETWORK_MAC, "65:43:21:FE:DC:BA")}
new_identifiers = {("hue", "654"), ("bla", "321")}
assert not entry.area_id
assert not entry.labels
@ -1275,6 +1276,7 @@ async def test_update(
model="Test Model",
name_by_user="Test Friendly Name",
name="name",
new_connections=new_connections,
new_identifiers=new_identifiers,
serial_number="serial_no",
suggested_area="suggested_area",
@ -1288,7 +1290,7 @@ async def test_update(
area_id="12345A",
config_entries={mock_config_entry.entry_id},
configuration_url="https://example.com/config",
connections={("mac", "12:34:56:ab:cd:ef")},
connections={("mac", "65:43:21:fe:dc:ba")},
disabled_by=dr.DeviceEntryDisabler.USER,
entry_type=dr.DeviceEntryType.SERVICE,
hw_version="hw_version",
@ -1319,6 +1321,12 @@ async def test_update(
device_registry.async_get_device(
connections={(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}
)
is None
)
assert (
device_registry.async_get_device(
connections={(dr.CONNECTION_NETWORK_MAC, "65:43:21:FE:DC:BA")}
)
== updated_entry
)
@ -1336,6 +1344,7 @@ async def test_update(
"device_id": entry.id,
"changes": {
"area_id": None,
"connections": {("mac", "12:34:56:ab:cd:ef")},
"configuration_url": None,
"disabled_by": None,
"entry_type": None,
@ -1352,6 +1361,105 @@ async def test_update(
"via_device_id": None,
},
}
with pytest.raises(HomeAssistantError):
device_registry.async_update_device(
entry.id,
merge_connections=new_connections,
new_connections=new_connections,
)
with pytest.raises(HomeAssistantError):
device_registry.async_update_device(
entry.id,
merge_identifiers=new_identifiers,
new_identifiers=new_identifiers,
)
@pytest.mark.parametrize(
("initial_connections", "new_connections", "updated_connections"),
[
( # No connection -> single connection
None,
{(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
{(dr.CONNECTION_NETWORK_MAC, "12:34:56:ab:cd:ef")},
),
( # No connection -> double connection
None,
{
(dr.CONNECTION_NETWORK_MAC, "65:43:21:FE:DC:BA"),
(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF"),
},
{
(dr.CONNECTION_NETWORK_MAC, "65:43:21:fe:dc:ba"),
(dr.CONNECTION_NETWORK_MAC, "12:34:56:ab:cd:ef"),
},
),
( # single connection -> no connection
{(dr.CONNECTION_NETWORK_MAC, "65:43:21:FE:DC:BA")},
set(),
set(),
),
( # single connection -> single connection
{(dr.CONNECTION_NETWORK_MAC, "65:43:21:FE:DC:BA")},
{(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
{(dr.CONNECTION_NETWORK_MAC, "12:34:56:ab:cd:ef")},
),
( # single connection -> double connection
{(dr.CONNECTION_NETWORK_MAC, "65:43:21:FE:DC:BA")},
{
(dr.CONNECTION_NETWORK_MAC, "65:43:21:FE:DC:BA"),
(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF"),
},
{
(dr.CONNECTION_NETWORK_MAC, "65:43:21:fe:dc:ba"),
(dr.CONNECTION_NETWORK_MAC, "12:34:56:ab:cd:ef"),
},
),
( # Double connection -> None
{
(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF"),
(dr.CONNECTION_NETWORK_MAC, "65:43:21:FE:DC:BA"),
},
set(),
set(),
),
( # Double connection -> single connection
{
(dr.CONNECTION_NETWORK_MAC, "65:43:21:FE:DC:BA"),
(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF"),
},
{(dr.CONNECTION_NETWORK_MAC, "65:43:21:FE:DC:BA")},
{(dr.CONNECTION_NETWORK_MAC, "65:43:21:fe:dc:ba")},
),
],
)
async def test_update_connection(
device_registry: dr.DeviceRegistry,
mock_config_entry: MockConfigEntry,
initial_connections: set[tuple[str, str]] | None,
new_connections: set[tuple[str, str]] | None,
updated_connections: set[tuple[str, str]] | None,
) -> None:
"""Verify that we can update some attributes of a device."""
entry = device_registry.async_get_or_create(
config_entry_id=mock_config_entry.entry_id,
connections=initial_connections,
identifiers={("hue", "456"), ("bla", "123")},
)
with patch.object(device_registry, "async_schedule_save") as mock_save:
updated_entry = device_registry.async_update_device(
entry.id,
new_connections=new_connections,
)
assert mock_save.call_count == 1
assert updated_entry != entry
assert updated_entry.connections == updated_connections
assert (
device_registry.async_get_device(identifiers={("bla", "123")}) == updated_entry
)
async def test_update_remove_config_entries(