Entity registry api update disable (#26015)

* Clean up entity registry WS commands

* Allow updating disabled_by in entity registry

* Allow changing disabled_by via API

* Update tests/components/config/test_entity_registry.py

Co-Authored-By: Robert Svensson <Kane610@users.noreply.github.com>
This commit is contained in:
Paulus Schoutsen 2019-08-16 16:22:45 -07:00 committed by GitHub
parent d4046cb6e4
commit eba6caf8a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 53 deletions

View File

@ -11,51 +11,18 @@ from homeassistant.components.websocket_api.decorators import (
)
from homeassistant.helpers import config_validation as cv
WS_TYPE_LIST = "config/entity_registry/list"
SCHEMA_WS_LIST = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{vol.Required("type"): WS_TYPE_LIST}
)
WS_TYPE_GET = "config/entity_registry/get"
SCHEMA_WS_GET = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{vol.Required("type"): WS_TYPE_GET, vol.Required("entity_id"): cv.entity_id}
)
WS_TYPE_UPDATE = "config/entity_registry/update"
SCHEMA_WS_UPDATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): WS_TYPE_UPDATE,
vol.Required("entity_id"): cv.entity_id,
# If passed in, we update value. Passing None will remove old value.
vol.Optional("name"): vol.Any(str, None),
vol.Optional("new_entity_id"): str,
}
)
WS_TYPE_REMOVE = "config/entity_registry/remove"
SCHEMA_WS_REMOVE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{vol.Required("type"): WS_TYPE_REMOVE, vol.Required("entity_id"): cv.entity_id}
)
async def async_setup(hass):
"""Enable the Entity Registry views."""
hass.components.websocket_api.async_register_command(
WS_TYPE_LIST, websocket_list_entities, SCHEMA_WS_LIST
)
hass.components.websocket_api.async_register_command(
WS_TYPE_GET, websocket_get_entity, SCHEMA_WS_GET
)
hass.components.websocket_api.async_register_command(
WS_TYPE_UPDATE, websocket_update_entity, SCHEMA_WS_UPDATE
)
hass.components.websocket_api.async_register_command(
WS_TYPE_REMOVE, websocket_remove_entity, SCHEMA_WS_REMOVE
)
hass.components.websocket_api.async_register_command(websocket_list_entities)
hass.components.websocket_api.async_register_command(websocket_get_entity)
hass.components.websocket_api.async_register_command(websocket_update_entity)
hass.components.websocket_api.async_register_command(websocket_remove_entity)
return True
@async_response
@websocket_api.websocket_command({vol.Required("type"): "config/entity_registry/list"})
async def websocket_list_entities(hass, connection, msg):
"""Handle list registry entries command.
@ -70,6 +37,12 @@ async def websocket_list_entities(hass, connection, msg):
@async_response
@websocket_api.websocket_command(
{
vol.Required("type"): "config/entity_registry/get",
vol.Required("entity_id"): cv.entity_id,
}
)
async def websocket_get_entity(hass, connection, msg):
"""Handle get entity registry entry command.
@ -89,6 +62,17 @@ async def websocket_get_entity(hass, connection, msg):
@require_admin
@async_response
@websocket_api.websocket_command(
{
vol.Required("type"): "config/entity_registry/update",
vol.Required("entity_id"): cv.entity_id,
# If passed in, we update value. Passing None will remove old value.
vol.Optional("name"): vol.Any(str, None),
vol.Optional("new_entity_id"): str,
# We only allow setting disabled_by user via API.
vol.Optional("disabled_by"): vol.Any("user", None),
}
)
async def websocket_update_entity(hass, connection, msg):
"""Handle update entity websocket command.
@ -107,6 +91,9 @@ async def websocket_update_entity(hass, connection, msg):
if "name" in msg:
changes["name"] = msg["name"]
if "disabled_by" in msg:
changes["disabled_by"] = msg["disabled_by"]
if "new_entity_id" in msg and msg["new_entity_id"] != msg["entity_id"]:
changes["new_entity_id"] = msg["new_entity_id"]
if hass.states.get(msg["new_entity_id"]) is not None:
@ -132,6 +119,12 @@ async def websocket_update_entity(hass, connection, msg):
@require_admin
@async_response
@websocket_api.websocket_command(
{
vol.Required("type"): "config/entity_registry/remove",
vol.Required("entity_id"): cv.entity_id,
}
)
async def websocket_remove_entity(hass, connection, msg):
"""Handle remove entity websocket command.

View File

@ -201,7 +201,13 @@ class EntityRegistry:
@callback
def async_update_entity(
self, entity_id, *, name=_UNDEF, new_entity_id=_UNDEF, new_unique_id=_UNDEF
self,
entity_id,
*,
name=_UNDEF,
new_entity_id=_UNDEF,
new_unique_id=_UNDEF,
disabled_by=_UNDEF,
):
"""Update properties of an entity."""
return self._async_update_entity(
@ -209,6 +215,7 @@ class EntityRegistry:
name=name,
new_entity_id=new_entity_id,
new_unique_id=new_unique_id,
disabled_by=disabled_by,
)
@callback
@ -221,20 +228,21 @@ class EntityRegistry:
new_entity_id=_UNDEF,
device_id=_UNDEF,
new_unique_id=_UNDEF,
disabled_by=_UNDEF,
):
"""Private facing update properties method."""
old = self.entities[entity_id]
changes = {}
if name is not _UNDEF and name != old.name:
changes["name"] = name
if config_entry_id is not _UNDEF and config_entry_id != old.config_entry_id:
changes["config_entry_id"] = config_entry_id
if device_id is not _UNDEF and device_id != old.device_id:
changes["device_id"] = device_id
for attr_name, value in (
("name", name),
("config_entry_id", config_entry_id),
("device_id", device_id),
("disabled_by", disabled_by),
):
if value is not _UNDEF and value != getattr(old, attr_name):
changes[attr_name] = value
if new_entity_id is not _UNDEF and new_entity_id != old.entity_id:
if self.async_is_registered(new_entity_id):

View File

@ -105,9 +105,9 @@ async def test_get_entity(hass, client):
}
async def test_update_entity_name(hass, client):
"""Test updating entity name."""
mock_registry(
async def test_update_entity(hass, client):
"""Test updating entity."""
registry = mock_registry(
hass,
{
"test_domain.world": RegistryEntry(
@ -133,6 +133,32 @@ async def test_update_entity_name(hass, client):
"type": "config/entity_registry/update",
"entity_id": "test_domain.world",
"name": "after update",
"disabled_by": "user",
}
)
msg = await client.receive_json()
assert msg["result"] == {
"config_entry_id": None,
"device_id": None,
"disabled_by": "user",
"platform": "test_platform",
"entity_id": "test_domain.world",
"name": "after update",
}
state = hass.states.get("test_domain.world")
assert state.name == "after update"
assert registry.entities["test_domain.world"].disabled_by == "user"
await client.send_json(
{
"id": 7,
"type": "config/entity_registry/update",
"entity_id": "test_domain.world",
"disabled_by": None,
}
)
@ -147,9 +173,6 @@ async def test_update_entity_name(hass, client):
"name": "after update",
}
state = hass.states.get("test_domain.world")
assert state.name == "after update"
async def test_update_entity_no_changes(hass, client):
"""Test update entity with no changes."""

View File

@ -354,6 +354,26 @@ async def test_update_entity_unique_id_conflict(registry):
assert mock_schedule_save.call_count == 0
async def test_update_entity(registry):
"""Test updating entity."""
entry = registry.async_get_or_create(
"light", "hue", "5678", config_entry_id="mock-id-1"
)
for attr_name, new_value in (
("name", "new name"),
("disabled_by", entity_registry.DISABLED_USER),
):
changes = {attr_name: new_value}
updated_entry = registry.async_update_entity(entry.entity_id, **changes)
assert updated_entry != entry
assert getattr(updated_entry, attr_name) == new_value
assert getattr(updated_entry, attr_name) != getattr(entry, attr_name)
entry = updated_entry
async def test_disabled_by(registry):
"""Test that we can disable an entry when we create it."""
entry = registry.async_get_or_create("light", "hue", "5678", disabled_by="hass")