From 51b8ffc69d2744f171e2132573bafabd0f8494c2 Mon Sep 17 00:00:00 2001 From: Franck Nijhof Date: Mon, 18 Mar 2024 21:19:27 +0100 Subject: [PATCH] Add WebSocket support for handling labels on device registry (#113758) --- .../components/config/device_registry.py | 5 +++ homeassistant/helpers/device_registry.py | 1 + .../components/config/test_device_registry.py | 43 +++++++++++++++++++ 3 files changed, 49 insertions(+) diff --git a/homeassistant/components/config/device_registry.py b/homeassistant/components/config/device_registry.py index d7a9dc1a66d..f2b0035d060 100644 --- a/homeassistant/components/config/device_registry.py +++ b/homeassistant/components/config/device_registry.py @@ -69,6 +69,7 @@ def websocket_list_devices( # We only allow setting disabled_by user via API. # No Enum support like this in voluptuous, use .value vol.Optional("disabled_by"): vol.Any(DeviceEntryDisabler.USER.value, None), + vol.Optional("labels"): [str], vol.Optional("name_by_user"): vol.Any(str, None), } ) @@ -87,6 +88,10 @@ def websocket_update_device( if msg.get("disabled_by") is not None: msg["disabled_by"] = DeviceEntryDisabler(msg["disabled_by"]) + if "labels" in msg: + # Convert labels to a set + msg["labels"] = set(msg["labels"]) + entry = cast(DeviceEntry, registry.async_update_device(**msg)) connection.send_message(websocket_api.result_message(msg_id, entry.dict_repr)) diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index ba4c393a2d9..2f421034919 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -276,6 +276,7 @@ class DeviceEntry: "hw_version": self.hw_version, "id": self.id, "identifiers": list(self.identifiers), + "labels": list(self.labels), "manufacturer": self.manufacturer, "model": self.model, "name_by_user": self.name_by_user, diff --git a/tests/components/config/test_device_registry.py b/tests/components/config/test_device_registry.py index 19c30a43858..bfb1ebdb191 100644 --- a/tests/components/config/test_device_registry.py +++ b/tests/components/config/test_device_registry.py @@ -1,6 +1,7 @@ """Test device_registry API.""" import pytest +from pytest_unordered import unordered from homeassistant.components.config import device_registry from homeassistant.core import HomeAssistant @@ -64,6 +65,7 @@ async def test_list_devices( "entry_type": None, "hw_version": None, "identifiers": [["bridgeid", "0123"]], + "labels": [], "manufacturer": "manufacturer", "model": "model", "name_by_user": None, @@ -81,6 +83,7 @@ async def test_list_devices( "entry_type": dr.DeviceEntryType.SERVICE, "hw_version": None, "identifiers": [["bridgeid", "1234"]], + "labels": [], "manufacturer": "manufacturer", "model": "model", "name_by_user": None, @@ -111,6 +114,7 @@ async def test_list_devices( "hw_version": None, "id": device1.id, "identifiers": [["bridgeid", "0123"]], + "labels": [], "manufacturer": "manufacturer", "model": "model", "name_by_user": None, @@ -180,6 +184,45 @@ async def test_update_device( assert isinstance(device.disabled_by, (dr.DeviceEntryDisabler, type(None))) +async def test_update_device_labels( + hass: HomeAssistant, + client: MockHAClientWebSocket, + device_registry: dr.DeviceRegistry, +) -> None: + """Test update entry labels.""" + entry = MockConfigEntry(title=None) + entry.add_to_hass(hass) + device = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections={("ethernet", "12:34:56:78:90:AB:CD:EF")}, + identifiers={("bridgeid", "0123")}, + manufacturer="manufacturer", + model="model", + ) + + assert not device.labels + + await client.send_json_auto_id( + { + "type": "config/device_registry/update", + "device_id": device.id, + "labels": ["label1", "label2"], + } + ) + + msg = await client.receive_json() + await hass.async_block_till_done() + assert len(device_registry.devices) == 1 + + device = device_registry.async_get_device( + identifiers={("bridgeid", "0123")}, + connections={("ethernet", "12:34:56:78:90:AB:CD:EF")}, + ) + + assert msg["result"]["labels"] == unordered(["label1", "label2"]) + assert device.labels == {"label1", "label2"} + + async def test_remove_config_entry_from_device( hass: HomeAssistant, hass_ws_client: WebSocketGenerator,