From 30baf333c358d1311dfd4af5ed90fe942f30f434 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 2 Dec 2020 21:20:14 +0100 Subject: [PATCH] Improve handling of disabled devices (#43864) --- .../components/config/entity_registry.py | 13 ++++ homeassistant/helpers/entity_registry.py | 13 +++- .../components/config/test_entity_registry.py | 63 ++++++++++++++++++- tests/helpers/test_entity_registry.py | 34 ++++++++-- 4 files changed, 113 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py index 73327ecf23c..8d1c488bfa0 100644 --- a/homeassistant/components/config/entity_registry.py +++ b/homeassistant/components/config/entity_registry.py @@ -107,6 +107,19 @@ async def websocket_update_entity(hass, connection, msg): ) return + if "disabled_by" in msg and msg["disabled_by"] is None: + entity = registry.entities[msg["entity_id"]] + if entity.device_id: + device_registry = await hass.helpers.device_registry.async_get_registry() + device = device_registry.async_get(entity.device_id) + if device.disabled: + connection.send_message( + websocket_api.error_message( + msg["id"], "invalid_info", "Device is disabled" + ) + ) + return + try: if changes: entry = registry.async_update_entity(msg["entity_id"], **changes) diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 7e8700e8236..4582fc5f3b6 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -311,11 +311,18 @@ class EntityRegistry: device_registry = await self.hass.helpers.device_registry.async_get_registry() device = device_registry.async_get(event.data["device_id"]) if not device.disabled: + entities = async_entries_for_device( + self, event.data["device_id"], include_disabled_entities=True + ) + for entity in entities: + if entity.disabled_by != DISABLED_DEVICE: + continue + self.async_update_entity( # type: ignore + entity.entity_id, disabled_by=None + ) return - entities = async_entries_for_device( - self, event.data["device_id"], include_disabled_entities=True - ) + entities = async_entries_for_device(self, event.data["device_id"]) for entity in entities: self.async_update_entity( # type: ignore entity.entity_id, disabled_by=DISABLED_DEVICE diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py index a506135c16d..93d33bc9562 100644 --- a/tests/components/config/test_entity_registry.py +++ b/tests/components/config/test_entity_registry.py @@ -7,7 +7,13 @@ from homeassistant.components.config import entity_registry from homeassistant.const import ATTR_ICON from homeassistant.helpers.entity_registry import RegistryEntry -from tests.common import MockConfigEntry, MockEntity, MockEntityPlatform, mock_registry +from tests.common import ( + MockConfigEntry, + MockEntity, + MockEntityPlatform, + mock_device_registry, + mock_registry, +) @pytest.fixture @@ -17,6 +23,12 @@ def client(hass, hass_ws_client): yield hass.loop.run_until_complete(hass_ws_client(hass)) +@pytest.fixture +def device_registry(hass): + """Return an empty, loaded, registry.""" + return mock_device_registry(hass) + + async def test_list_entities(hass, client): """Test list entries.""" entities = OrderedDict() @@ -282,6 +294,55 @@ async def test_update_entity_require_restart(hass, client): } +async def test_enable_entity_disabled_device(hass, client, device_registry): + """Test enabling entity of disabled device.""" + config_entry = MockConfigEntry(domain="test_platform") + config_entry.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id="1234", + connections={("ethernet", "12:34:56:78:90:AB:CD:EF")}, + identifiers={("bridgeid", "0123")}, + manufacturer="manufacturer", + model="model", + disabled_by="user", + ) + + mock_registry( + hass, + { + "test_domain.world": RegistryEntry( + config_entry_id=config_entry.entry_id, + entity_id="test_domain.world", + unique_id="1234", + # Using component.async_add_entities is equal to platform "domain" + platform="test_platform", + device_id=device.id, + ) + }, + ) + platform = MockEntityPlatform(hass) + entity = MockEntity(unique_id="1234") + await platform.async_add_entities([entity]) + + state = hass.states.get("test_domain.world") + assert state is not None + + # UPDATE DISABLED_BY TO NONE + await client.send_json( + { + "id": 8, + "type": "config/entity_registry/update", + "entity_id": "test_domain.world", + "disabled_by": None, + } + ) + + msg = await client.receive_json() + + assert not msg["success"] + + async def test_update_entity_no_changes(hass, client): """Test update entity with no changes.""" mock_registry( diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 960537e784c..19af3715160 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -711,31 +711,53 @@ async def test_remove_device_removes_entities(hass, registry): async def test_disable_device_disables_entities(hass, registry): - """Test that we remove entities tied to a device.""" + """Test that we disable entities tied to a device.""" device_registry = mock_device_registry(hass) config_entry = MockConfigEntry(domain="light") + config_entry.add_to_hass(hass) device_entry = device_registry.async_get_or_create( config_entry_id=config_entry.entry_id, connections={("mac", "12:34:56:AB:CD:EF")}, ) - entry = registry.async_get_or_create( + entry1 = registry.async_get_or_create( "light", "hue", "5678", config_entry=config_entry, device_id=device_entry.id, ) + entry2 = registry.async_get_or_create( + "light", + "hue", + "ABCD", + config_entry=config_entry, + device_id=device_entry.id, + disabled_by="user", + ) - assert not entry.disabled + assert not entry1.disabled + assert entry2.disabled device_registry.async_update_device(device_entry.id, disabled_by="user") await hass.async_block_till_done() - entry = registry.async_get(entry.entity_id) - assert entry.disabled - assert entry.disabled_by == "device" + entry1 = registry.async_get(entry1.entity_id) + assert entry1.disabled + assert entry1.disabled_by == "device" + entry2 = registry.async_get(entry2.entity_id) + assert entry2.disabled + assert entry2.disabled_by == "user" + + device_registry.async_update_device(device_entry.id, disabled_by=None) + await hass.async_block_till_done() + + entry1 = registry.async_get(entry1.entity_id) + assert not entry1.disabled + entry2 = registry.async_get(entry2.entity_id) + assert entry2.disabled + assert entry2.disabled_by == "user" async def test_disabled_entities_excluded_from_entity_list(hass, registry):