Add WebSocket support for handling labels on device registry (#113758)

This commit is contained in:
Franck Nijhof 2024-03-18 21:19:27 +01:00 committed by GitHub
parent f73f93913f
commit 51b8ffc69d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 0 deletions

View File

@ -69,6 +69,7 @@ def websocket_list_devices(
# We only allow setting disabled_by user via API. # We only allow setting disabled_by user via API.
# No Enum support like this in voluptuous, use .value # No Enum support like this in voluptuous, use .value
vol.Optional("disabled_by"): vol.Any(DeviceEntryDisabler.USER.value, None), vol.Optional("disabled_by"): vol.Any(DeviceEntryDisabler.USER.value, None),
vol.Optional("labels"): [str],
vol.Optional("name_by_user"): vol.Any(str, None), vol.Optional("name_by_user"): vol.Any(str, None),
} }
) )
@ -87,6 +88,10 @@ def websocket_update_device(
if msg.get("disabled_by") is not None: if msg.get("disabled_by") is not None:
msg["disabled_by"] = DeviceEntryDisabler(msg["disabled_by"]) msg["disabled_by"] = DeviceEntryDisabler(msg["disabled_by"])
if "labels" in msg:
# Convert labels to a set
msg["labels"] = set(msg["labels"])
entry = cast(DeviceEntry, registry.async_update_device(**msg)) entry = cast(DeviceEntry, registry.async_update_device(**msg))
connection.send_message(websocket_api.result_message(msg_id, entry.dict_repr)) connection.send_message(websocket_api.result_message(msg_id, entry.dict_repr))

View File

@ -276,6 +276,7 @@ class DeviceEntry:
"hw_version": self.hw_version, "hw_version": self.hw_version,
"id": self.id, "id": self.id,
"identifiers": list(self.identifiers), "identifiers": list(self.identifiers),
"labels": list(self.labels),
"manufacturer": self.manufacturer, "manufacturer": self.manufacturer,
"model": self.model, "model": self.model,
"name_by_user": self.name_by_user, "name_by_user": self.name_by_user,

View File

@ -1,6 +1,7 @@
"""Test device_registry API.""" """Test device_registry API."""
import pytest import pytest
from pytest_unordered import unordered
from homeassistant.components.config import device_registry from homeassistant.components.config import device_registry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -64,6 +65,7 @@ async def test_list_devices(
"entry_type": None, "entry_type": None,
"hw_version": None, "hw_version": None,
"identifiers": [["bridgeid", "0123"]], "identifiers": [["bridgeid", "0123"]],
"labels": [],
"manufacturer": "manufacturer", "manufacturer": "manufacturer",
"model": "model", "model": "model",
"name_by_user": None, "name_by_user": None,
@ -81,6 +83,7 @@ async def test_list_devices(
"entry_type": dr.DeviceEntryType.SERVICE, "entry_type": dr.DeviceEntryType.SERVICE,
"hw_version": None, "hw_version": None,
"identifiers": [["bridgeid", "1234"]], "identifiers": [["bridgeid", "1234"]],
"labels": [],
"manufacturer": "manufacturer", "manufacturer": "manufacturer",
"model": "model", "model": "model",
"name_by_user": None, "name_by_user": None,
@ -111,6 +114,7 @@ async def test_list_devices(
"hw_version": None, "hw_version": None,
"id": device1.id, "id": device1.id,
"identifiers": [["bridgeid", "0123"]], "identifiers": [["bridgeid", "0123"]],
"labels": [],
"manufacturer": "manufacturer", "manufacturer": "manufacturer",
"model": "model", "model": "model",
"name_by_user": None, "name_by_user": None,
@ -180,6 +184,45 @@ async def test_update_device(
assert isinstance(device.disabled_by, (dr.DeviceEntryDisabler, type(None))) assert isinstance(device.disabled_by, (dr.DeviceEntryDisabler, type(None)))
async def test_update_device_labels(
hass: HomeAssistant,
client: MockHAClientWebSocket,
device_registry: dr.DeviceRegistry,
) -> None:
"""Test update entry labels."""
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("ethernet", "12:34:56:78:90:AB:CD:EF")},
identifiers={("bridgeid", "0123")},
manufacturer="manufacturer",
model="model",
)
assert not device.labels
await client.send_json_auto_id(
{
"type": "config/device_registry/update",
"device_id": device.id,
"labels": ["label1", "label2"],
}
)
msg = await client.receive_json()
await hass.async_block_till_done()
assert len(device_registry.devices) == 1
device = device_registry.async_get_device(
identifiers={("bridgeid", "0123")},
connections={("ethernet", "12:34:56:78:90:AB:CD:EF")},
)
assert msg["result"]["labels"] == unordered(["label1", "label2"])
assert device.labels == {"label1", "label2"}
async def test_remove_config_entry_from_device( async def test_remove_config_entry_from_device(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,