diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py index 431723893c1..125b2260f08 100644 --- a/homeassistant/components/config/entity_registry.py +++ b/homeassistant/components/config/entity_registry.py @@ -11,51 +11,18 @@ from homeassistant.components.websocket_api.decorators import ( ) from homeassistant.helpers import config_validation as cv -WS_TYPE_LIST = "config/entity_registry/list" -SCHEMA_WS_LIST = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - {vol.Required("type"): WS_TYPE_LIST} -) - -WS_TYPE_GET = "config/entity_registry/get" -SCHEMA_WS_GET = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - {vol.Required("type"): WS_TYPE_GET, vol.Required("entity_id"): cv.entity_id} -) - -WS_TYPE_UPDATE = "config/entity_registry/update" -SCHEMA_WS_UPDATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - { - vol.Required("type"): WS_TYPE_UPDATE, - vol.Required("entity_id"): cv.entity_id, - # If passed in, we update value. Passing None will remove old value. - vol.Optional("name"): vol.Any(str, None), - vol.Optional("new_entity_id"): str, - } -) - -WS_TYPE_REMOVE = "config/entity_registry/remove" -SCHEMA_WS_REMOVE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - {vol.Required("type"): WS_TYPE_REMOVE, vol.Required("entity_id"): cv.entity_id} -) - async def async_setup(hass): """Enable the Entity Registry views.""" - hass.components.websocket_api.async_register_command( - WS_TYPE_LIST, websocket_list_entities, SCHEMA_WS_LIST - ) - hass.components.websocket_api.async_register_command( - WS_TYPE_GET, websocket_get_entity, SCHEMA_WS_GET - ) - hass.components.websocket_api.async_register_command( - WS_TYPE_UPDATE, websocket_update_entity, SCHEMA_WS_UPDATE - ) - hass.components.websocket_api.async_register_command( - WS_TYPE_REMOVE, websocket_remove_entity, SCHEMA_WS_REMOVE - ) + hass.components.websocket_api.async_register_command(websocket_list_entities) + hass.components.websocket_api.async_register_command(websocket_get_entity) + hass.components.websocket_api.async_register_command(websocket_update_entity) + hass.components.websocket_api.async_register_command(websocket_remove_entity) return True @async_response +@websocket_api.websocket_command({vol.Required("type"): "config/entity_registry/list"}) async def websocket_list_entities(hass, connection, msg): """Handle list registry entries command. @@ -70,6 +37,12 @@ async def websocket_list_entities(hass, connection, msg): @async_response +@websocket_api.websocket_command( + { + vol.Required("type"): "config/entity_registry/get", + vol.Required("entity_id"): cv.entity_id, + } +) async def websocket_get_entity(hass, connection, msg): """Handle get entity registry entry command. @@ -89,6 +62,17 @@ async def websocket_get_entity(hass, connection, msg): @require_admin @async_response +@websocket_api.websocket_command( + { + vol.Required("type"): "config/entity_registry/update", + vol.Required("entity_id"): cv.entity_id, + # If passed in, we update value. Passing None will remove old value. + vol.Optional("name"): vol.Any(str, None), + vol.Optional("new_entity_id"): str, + # We only allow setting disabled_by user via API. + vol.Optional("disabled_by"): vol.Any("user", None), + } +) async def websocket_update_entity(hass, connection, msg): """Handle update entity websocket command. @@ -107,6 +91,9 @@ async def websocket_update_entity(hass, connection, msg): if "name" in msg: changes["name"] = msg["name"] + if "disabled_by" in msg: + changes["disabled_by"] = msg["disabled_by"] + if "new_entity_id" in msg and msg["new_entity_id"] != msg["entity_id"]: changes["new_entity_id"] = msg["new_entity_id"] if hass.states.get(msg["new_entity_id"]) is not None: @@ -132,6 +119,12 @@ async def websocket_update_entity(hass, connection, msg): @require_admin @async_response +@websocket_api.websocket_command( + { + vol.Required("type"): "config/entity_registry/remove", + vol.Required("entity_id"): cv.entity_id, + } +) async def websocket_remove_entity(hass, connection, msg): """Handle remove entity websocket command. diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 97cc213aa66..8ef41eef9f8 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -201,7 +201,13 @@ class EntityRegistry: @callback def async_update_entity( - self, entity_id, *, name=_UNDEF, new_entity_id=_UNDEF, new_unique_id=_UNDEF + self, + entity_id, + *, + name=_UNDEF, + new_entity_id=_UNDEF, + new_unique_id=_UNDEF, + disabled_by=_UNDEF, ): """Update properties of an entity.""" return self._async_update_entity( @@ -209,6 +215,7 @@ class EntityRegistry: name=name, new_entity_id=new_entity_id, new_unique_id=new_unique_id, + disabled_by=disabled_by, ) @callback @@ -221,20 +228,21 @@ class EntityRegistry: new_entity_id=_UNDEF, device_id=_UNDEF, new_unique_id=_UNDEF, + disabled_by=_UNDEF, ): """Private facing update properties method.""" old = self.entities[entity_id] changes = {} - if name is not _UNDEF and name != old.name: - changes["name"] = name - - if config_entry_id is not _UNDEF and config_entry_id != old.config_entry_id: - changes["config_entry_id"] = config_entry_id - - if device_id is not _UNDEF and device_id != old.device_id: - changes["device_id"] = device_id + for attr_name, value in ( + ("name", name), + ("config_entry_id", config_entry_id), + ("device_id", device_id), + ("disabled_by", disabled_by), + ): + if value is not _UNDEF and value != getattr(old, attr_name): + changes[attr_name] = value if new_entity_id is not _UNDEF and new_entity_id != old.entity_id: if self.async_is_registered(new_entity_id): diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py index 5f8c6f51acb..f18abe9b0e2 100644 --- a/tests/components/config/test_entity_registry.py +++ b/tests/components/config/test_entity_registry.py @@ -105,9 +105,9 @@ async def test_get_entity(hass, client): } -async def test_update_entity_name(hass, client): - """Test updating entity name.""" - mock_registry( +async def test_update_entity(hass, client): + """Test updating entity.""" + registry = mock_registry( hass, { "test_domain.world": RegistryEntry( @@ -133,6 +133,32 @@ async def test_update_entity_name(hass, client): "type": "config/entity_registry/update", "entity_id": "test_domain.world", "name": "after update", + "disabled_by": "user", + } + ) + + msg = await client.receive_json() + + assert msg["result"] == { + "config_entry_id": None, + "device_id": None, + "disabled_by": "user", + "platform": "test_platform", + "entity_id": "test_domain.world", + "name": "after update", + } + + state = hass.states.get("test_domain.world") + assert state.name == "after update" + + assert registry.entities["test_domain.world"].disabled_by == "user" + + await client.send_json( + { + "id": 7, + "type": "config/entity_registry/update", + "entity_id": "test_domain.world", + "disabled_by": None, } ) @@ -147,9 +173,6 @@ async def test_update_entity_name(hass, client): "name": "after update", } - state = hass.states.get("test_domain.world") - assert state.name == "after update" - async def test_update_entity_no_changes(hass, client): """Test update entity with no changes.""" diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 88131a58de0..ce05e914b3d 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -354,6 +354,26 @@ async def test_update_entity_unique_id_conflict(registry): assert mock_schedule_save.call_count == 0 +async def test_update_entity(registry): + """Test updating entity.""" + entry = registry.async_get_or_create( + "light", "hue", "5678", config_entry_id="mock-id-1" + ) + + for attr_name, new_value in ( + ("name", "new name"), + ("disabled_by", entity_registry.DISABLED_USER), + ): + changes = {attr_name: new_value} + updated_entry = registry.async_update_entity(entry.entity_id, **changes) + + assert updated_entry != entry + assert getattr(updated_entry, attr_name) == new_value + assert getattr(updated_entry, attr_name) != getattr(entry, attr_name) + + entry = updated_entry + + async def test_disabled_by(registry): """Test that we can disable an entry when we create it.""" entry = registry.async_get_or_create("light", "hue", "5678", disabled_by="hass")