diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index b8d9944d7af..b45b6abe468 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -1,7 +1,6 @@ """Http views to control the config manager.""" import aiohttp.web_exceptions import voluptuous as vol -import voluptuous_serialize from homeassistant import config_entries, data_entry_flow from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES, POLICY_EDIT @@ -10,7 +9,6 @@ from homeassistant.components.http import HomeAssistantView from homeassistant.const import HTTP_FORBIDDEN, HTTP_NOT_FOUND from homeassistant.core import callback from homeassistant.exceptions import Unauthorized -import homeassistant.helpers.config_validation as cv from homeassistant.helpers.data_entry_flow import ( FlowManagerIndexView, FlowManagerResourceView, @@ -30,6 +28,7 @@ async def async_setup(hass): hass.http.register_view(OptionManagerFlowIndexView(hass.config_entries.options)) hass.http.register_view(OptionManagerFlowResourceView(hass.config_entries.options)) + hass.components.websocket_api.async_register_command(config_entry_disable) hass.components.websocket_api.async_register_command(config_entry_update) hass.components.websocket_api.async_register_command(config_entries_progress) hass.components.websocket_api.async_register_command(system_options_list) @@ -39,24 +38,6 @@ async def async_setup(hass): return True -def _prepare_json(result): - """Convert result for JSON.""" - if result["type"] != data_entry_flow.RESULT_TYPE_FORM: - return result - - data = result.copy() - - schema = data["data_schema"] - if schema is None: - data["data_schema"] = [] - else: - data["data_schema"] = voluptuous_serialize.convert( - schema, custom_serializer=cv.custom_serializer - ) - - return data - - class ConfigManagerEntryIndexView(HomeAssistantView): """View to get available config entries.""" @@ -265,6 +246,21 @@ async def system_options_list(hass, connection, msg): connection.send_result(msg["id"], entry.system_options.as_dict()) +def send_entry_not_found(connection, msg_id): + """Send Config entry not found error.""" + connection.send_error( + msg_id, websocket_api.const.ERR_NOT_FOUND, "Config entry not found" + ) + + +def get_entry(hass, connection, entry_id, msg_id): + """Get entry, send error message if it doesn't exist.""" + entry = hass.config_entries.async_get_entry(entry_id) + if entry is None: + send_entry_not_found(connection, msg_id) + return entry + + @websocket_api.require_admin @websocket_api.async_response @websocket_api.websocket_command( @@ -279,13 +275,10 @@ async def system_options_update(hass, connection, msg): changes = dict(msg) changes.pop("id") changes.pop("type") - entry_id = changes.pop("entry_id") - entry = hass.config_entries.async_get_entry(entry_id) + changes.pop("entry_id") + entry = get_entry(hass, connection, msg["entry_id"], msg["id"]) if entry is None: - connection.send_error( - msg["id"], websocket_api.const.ERR_NOT_FOUND, "Config entry not found" - ) return hass.config_entries.async_update_entry(entry, system_options=changes) @@ -302,20 +295,47 @@ async def config_entry_update(hass, connection, msg): changes = dict(msg) changes.pop("id") changes.pop("type") - entry_id = changes.pop("entry_id") - - entry = hass.config_entries.async_get_entry(entry_id) + changes.pop("entry_id") + entry = get_entry(hass, connection, msg["entry_id"], msg["id"]) if entry is None: - connection.send_error( - msg["id"], websocket_api.const.ERR_NOT_FOUND, "Config entry not found" - ) return hass.config_entries.async_update_entry(entry, **changes) connection.send_result(msg["id"], entry_json(entry)) +@websocket_api.require_admin +@websocket_api.async_response +@websocket_api.websocket_command( + { + "type": "config_entries/disable", + "entry_id": str, + # We only allow setting disabled_by user via API. + "disabled_by": vol.Any("user", None), + } +) +async def config_entry_disable(hass, connection, msg): + """Disable config entry.""" + disabled_by = msg["disabled_by"] + + result = False + try: + result = await hass.config_entries.async_set_disabled_by( + msg["entry_id"], disabled_by + ) + except config_entries.OperationNotAllowed: + # Failed to unload the config entry + pass + except config_entries.UnknownEntry: + send_entry_not_found(connection, msg["id"]) + return + + result = {"require_restart": not result} + + connection.send_result(msg["id"], result) + + @websocket_api.require_admin @websocket_api.async_response @websocket_api.websocket_command( @@ -333,9 +353,7 @@ async def ignore_config_flow(hass, connection, msg): ) if flow is None: - connection.send_error( - msg["id"], websocket_api.const.ERR_NOT_FOUND, "Config entry not found" - ) + send_entry_not_found(connection, msg["id"]) return if "unique_id" not in flow["context"]: @@ -357,7 +375,7 @@ def entry_json(entry: config_entries.ConfigEntry) -> dict: """Return JSON value of a config entry.""" handler = config_entries.HANDLERS.get(entry.domain) supports_options = ( - # Guard in case handler is no longer registered (custom compnoent etc) + # Guard in case handler is no longer registered (custom component etc) handler is not None # pylint: disable=comparison-with-callable and handler.async_get_options_flow @@ -372,4 +390,5 @@ def entry_json(entry: config_entries.ConfigEntry) -> dict: "connection_class": entry.connection_class, "supports_options": supports_options, "supports_unload": entry.supports_unload, + "disabled_by": entry.disabled_by, } diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 7225b7c375d..dbc0dd01454 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -11,6 +11,7 @@ import weakref import attr from homeassistant import data_entry_flow, loader +from homeassistant.const import EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError from homeassistant.helpers import entity_registry @@ -68,6 +69,8 @@ ENTRY_STATE_SETUP_RETRY = "setup_retry" ENTRY_STATE_NOT_LOADED = "not_loaded" # An error occurred when trying to unload the entry ENTRY_STATE_FAILED_UNLOAD = "failed_unload" +# The config entry is disabled +ENTRY_STATE_DISABLED = "disabled" UNRECOVERABLE_STATES = (ENTRY_STATE_MIGRATION_ERROR, ENTRY_STATE_FAILED_UNLOAD) @@ -92,6 +95,8 @@ CONN_CLASS_LOCAL_POLL = "local_poll" CONN_CLASS_ASSUMED = "assumed" CONN_CLASS_UNKNOWN = "unknown" +DISABLED_USER = "user" + RELOAD_AFTER_UPDATE_DELAY = 30 @@ -126,6 +131,7 @@ class ConfigEntry: "source", "connection_class", "state", + "disabled_by", "_setup_lock", "update_listeners", "_async_cancel_retry_setup", @@ -144,6 +150,7 @@ class ConfigEntry: unique_id: Optional[str] = None, entry_id: Optional[str] = None, state: str = ENTRY_STATE_NOT_LOADED, + disabled_by: Optional[str] = None, ) -> None: """Initialize a config entry.""" # Unique id of the config entry @@ -179,6 +186,9 @@ class ConfigEntry: # Unique ID of this entry. self.unique_id = unique_id + # Config entry is disabled + self.disabled_by = disabled_by + # Supports unload self.supports_unload = False @@ -198,7 +208,7 @@ class ConfigEntry: tries: int = 0, ) -> None: """Set up an entry.""" - if self.source == SOURCE_IGNORE: + if self.source == SOURCE_IGNORE or self.disabled_by: return if integration is None: @@ -441,6 +451,7 @@ class ConfigEntry: "source": self.source, "connection_class": self.connection_class, "unique_id": self.unique_id, + "disabled_by": self.disabled_by, } @@ -711,6 +722,8 @@ class ConfigEntries: system_options=entry.get("system_options", {}), # New in 0.104 unique_id=entry.get("unique_id"), + # New in 2021.3 + disabled_by=entry.get("disabled_by"), ) for entry in config["entries"] ] @@ -759,13 +772,42 @@ class ConfigEntries: If an entry was not loaded, will just load. """ + entry = self.async_get_entry(entry_id) + + if entry is None: + raise UnknownEntry + unload_result = await self.async_unload(entry_id) - if not unload_result: + if not unload_result or entry.disabled_by: return unload_result return await self.async_setup(entry_id) + async def async_set_disabled_by( + self, entry_id: str, disabled_by: Optional[str] + ) -> bool: + """Disable an entry. + + If disabled_by is changed, the config entry will be reloaded. + """ + entry = self.async_get_entry(entry_id) + + if entry is None: + raise UnknownEntry + + if entry.disabled_by == disabled_by: + return True + + entry.disabled_by = disabled_by + self._async_schedule_save() + + self.hass.bus.async_fire( + EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED, {"config_entry_id": entry_id} + ) + + return await self.async_reload(entry_id) + @callback def async_update_entry( self, diff --git a/homeassistant/const.py b/homeassistant/const.py index 4406c8bdfc3..a0aafaad3ce 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -202,6 +202,7 @@ CONF_ZONE = "zone" # #### EVENTS #### EVENT_CALL_SERVICE = "call_service" EVENT_COMPONENT_LOADED = "component_loaded" +EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED = "config_entry_disabled_by_updated" EVENT_CORE_CONFIG_UPDATE = "core_config_updated" EVENT_HOMEASSISTANT_CLOSE = "homeassistant_close" EVENT_HOMEASSISTANT_START = "homeassistant_start" diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 3dd44364604..705f6cdd89a 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -6,7 +6,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union, import attr -from homeassistant.const import EVENT_HOMEASSISTANT_STARTED +from homeassistant.const import ( + EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED, + EVENT_HOMEASSISTANT_STARTED, +) from homeassistant.core import Event, callback from homeassistant.loader import bind_hass import homeassistant.util.uuid as uuid_util @@ -37,6 +40,7 @@ IDX_IDENTIFIERS = "identifiers" REGISTERED_DEVICE = "registered" DELETED_DEVICE = "deleted" +DISABLED_CONFIG_ENTRY = "config_entry" DISABLED_INTEGRATION = "integration" DISABLED_USER = "user" @@ -65,6 +69,7 @@ class DeviceEntry: default=None, validator=attr.validators.in_( ( + DISABLED_CONFIG_ENTRY, DISABLED_INTEGRATION, DISABLED_USER, None, @@ -138,6 +143,10 @@ class DeviceRegistry: self.hass = hass self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._clear_index() + self.hass.bus.async_listen( + EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED, + self.async_config_entry_disabled_by_changed, + ) @callback def async_get(self, device_id: str) -> Optional[DeviceEntry]: @@ -609,6 +618,38 @@ class DeviceRegistry: if area_id == device.area_id: self._async_update_device(dev_id, area_id=None) + @callback + def async_config_entry_disabled_by_changed(self, event: Event) -> None: + """Handle a config entry being disabled or enabled. + + Disable devices in the registry that are associated to a config entry when + the config entry is disabled. + """ + config_entry = self.hass.config_entries.async_get_entry( + event.data["config_entry_id"] + ) + + # The config entry may be deleted already if the event handling is late + if not config_entry: + return + + if not config_entry.disabled_by: + devices = async_entries_for_config_entry( + self, event.data["config_entry_id"] + ) + for device in devices: + if device.disabled_by != DISABLED_CONFIG_ENTRY: + continue + self.async_update_device(device.id, disabled_by=None) + return + + devices = async_entries_for_config_entry(self, event.data["config_entry_id"]) + for device in devices: + if device.disabled: + # Entity already disabled, do not overwrite + continue + self.async_update_device(device.id, disabled_by=DISABLED_CONFIG_ENTRY) + @callback def async_get(hass: HomeAssistantType) -> DeviceRegistry: diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 0938ea9165f..c86bd64d73e 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -31,6 +31,7 @@ from homeassistant.const import ( ATTR_RESTORED, ATTR_SUPPORTED_FEATURES, ATTR_UNIT_OF_MEASUREMENT, + EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED, EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE, ) @@ -157,6 +158,10 @@ class EntityRegistry: self.hass.bus.async_listen( EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified ) + self.hass.bus.async_listen( + EVENT_CONFIG_ENTRY_DISABLED_BY_UPDATED, + self.async_config_entry_disabled_by_changed, + ) @callback def async_get_device_class_lookup(self, domain_device_classes: set) -> dict: @@ -349,10 +354,49 @@ class EntityRegistry: self.async_update_entity(entity.entity_id, disabled_by=None) return + if device.disabled_by == dr.DISABLED_CONFIG_ENTRY: + # Handled by async_config_entry_disabled + return + + # Fetch entities which are not already disabled entities = async_entries_for_device(self, event.data["device_id"]) for entity in entities: self.async_update_entity(entity.entity_id, disabled_by=DISABLED_DEVICE) + @callback + def async_config_entry_disabled_by_changed(self, event: Event) -> None: + """Handle a config entry being disabled or enabled. + + Disable entities in the registry that are associated to a config entry when + the config entry is disabled. + """ + config_entry = self.hass.config_entries.async_get_entry( + event.data["config_entry_id"] + ) + + # The config entry may be deleted already if the event handling is late + if not config_entry: + return + + if not config_entry.disabled_by: + entities = async_entries_for_config_entry( + self, event.data["config_entry_id"] + ) + for entity in entities: + if entity.disabled_by != DISABLED_CONFIG_ENTRY: + continue + self.async_update_entity(entity.entity_id, disabled_by=None) + return + + entities = async_entries_for_config_entry(self, event.data["config_entry_id"]) + for entity in entities: + if entity.disabled: + # Entity already disabled, do not overwrite + continue + self.async_update_entity( + entity.entity_id, disabled_by=DISABLED_CONFIG_ENTRY + ) + @callback def async_update_entity( self, diff --git a/tests/common.py b/tests/common.py index c07716dbfc9..52d368853b3 100644 --- a/tests/common.py +++ b/tests/common.py @@ -749,6 +749,7 @@ class MockConfigEntry(config_entries.ConfigEntry): system_options={}, connection_class=config_entries.CONN_CLASS_UNKNOWN, unique_id=None, + disabled_by=None, ): """Initialize a mock config entry.""" kwargs = { @@ -761,6 +762,7 @@ class MockConfigEntry(config_entries.ConfigEntry): "title": title, "connection_class": connection_class, "unique_id": unique_id, + "disabled_by": disabled_by, } if source is not None: kwargs["source"] = source diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py index 87b1559a21b..6bb1f1885eb 100644 --- a/tests/components/config/test_config_entries.py +++ b/tests/components/config/test_config_entries.py @@ -68,6 +68,12 @@ async def test_get_entries(hass, client): state=core_ce.ENTRY_STATE_LOADED, connection_class=core_ce.CONN_CLASS_ASSUMED, ).add_to_hass(hass) + MockConfigEntry( + domain="comp3", + title="Test 3", + source="bla3", + disabled_by="user", + ).add_to_hass(hass) resp = await client.get("/api/config/config_entries/entry") assert resp.status == 200 @@ -83,6 +89,7 @@ async def test_get_entries(hass, client): "connection_class": "local_poll", "supports_options": True, "supports_unload": True, + "disabled_by": None, }, { "domain": "comp2", @@ -92,6 +99,17 @@ async def test_get_entries(hass, client): "connection_class": "assumed", "supports_options": False, "supports_unload": False, + "disabled_by": None, + }, + { + "domain": "comp3", + "title": "Test 3", + "source": "bla3", + "state": "not_loaded", + "connection_class": "unknown", + "supports_options": False, + "supports_unload": False, + "disabled_by": "user", }, ] @@ -680,6 +698,25 @@ async def test_update_system_options(hass, hass_ws_client): assert entry.system_options.disable_new_entities +async def test_update_system_options_nonexisting(hass, hass_ws_client): + """Test that we can update entry.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + + await ws_client.send_json( + { + "id": 5, + "type": "config_entries/system_options/update", + "entry_id": "non_existing", + "disable_new_entities": True, + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"]["code"] == "not_found" + + async def test_update_entry(hass, hass_ws_client): """Test that we can update entry.""" assert await async_setup_component(hass, "config", {}) @@ -722,6 +759,83 @@ async def test_update_entry_nonexisting(hass, hass_ws_client): assert response["error"]["code"] == "not_found" +async def test_disable_entry(hass, hass_ws_client): + """Test that we can disable entry.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + + entry = MockConfigEntry(domain="demo", state="loaded") + entry.add_to_hass(hass) + assert entry.disabled_by is None + + # Disable + await ws_client.send_json( + { + "id": 5, + "type": "config_entries/disable", + "entry_id": entry.entry_id, + "disabled_by": "user", + } + ) + response = await ws_client.receive_json() + + assert response["success"] + assert response["result"] == {"require_restart": True} + assert entry.disabled_by == "user" + assert entry.state == "failed_unload" + + # Enable + await ws_client.send_json( + { + "id": 6, + "type": "config_entries/disable", + "entry_id": entry.entry_id, + "disabled_by": None, + } + ) + response = await ws_client.receive_json() + + assert response["success"] + assert response["result"] == {"require_restart": True} + assert entry.disabled_by is None + assert entry.state == "failed_unload" + + # Enable again -> no op + await ws_client.send_json( + { + "id": 7, + "type": "config_entries/disable", + "entry_id": entry.entry_id, + "disabled_by": None, + } + ) + response = await ws_client.receive_json() + + assert response["success"] + assert response["result"] == {"require_restart": False} + assert entry.disabled_by is None + assert entry.state == "failed_unload" + + +async def test_disable_entry_nonexisting(hass, hass_ws_client): + """Test that we can disable entry.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + + await ws_client.send_json( + { + "id": 5, + "type": "config_entries/disable", + "entry_id": "non_existing", + "disabled_by": "user", + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"]["code"] == "not_found" + + async def test_ignore_flow(hass, hass_ws_client): """Test we can ignore a flow.""" assert await async_setup_component(hass, "config", {}) @@ -763,3 +877,22 @@ async def test_ignore_flow(hass, hass_ws_client): assert entry.source == "ignore" assert entry.unique_id == "mock-unique-id" assert entry.title == "Test Integration" + + +async def test_ignore_flow_nonexisting(hass, hass_ws_client): + """Test we can ignore a flow.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + + await ws_client.send_json( + { + "id": 5, + "type": "config_entries/ignore_flow", + "flow_id": "non_existing", + "title": "Test Integration", + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"]["code"] == "not_found" diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index bc0e1c3bec9..965ebcd3e23 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -1209,3 +1209,41 @@ async def test_verify_suggested_area_does_not_overwrite_area_id( suggested_area="New Game Room", ) assert entry2.area_id == game_room_area.id + + +async def test_disable_config_entry_disables_devices(hass, registry): + """Test that we disable entities tied to a config entry.""" + config_entry = MockConfigEntry(domain="light") + config_entry.add_to_hass(hass) + + entry1 = registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={("mac", "12:34:56:AB:CD:EF")}, + ) + entry2 = registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={("mac", "34:56:AB:CD:EF:12")}, + disabled_by="user", + ) + + assert not entry1.disabled + assert entry2.disabled + + await hass.config_entries.async_set_disabled_by(config_entry.entry_id, "user") + await hass.async_block_till_done() + + entry1 = registry.async_get(entry1.id) + assert entry1.disabled + assert entry1.disabled_by == "config_entry" + entry2 = registry.async_get(entry2.id) + assert entry2.disabled + assert entry2.disabled_by == "user" + + await hass.config_entries.async_set_disabled_by(config_entry.entry_id, None) + await hass.async_block_till_done() + + entry1 = registry.async_get(entry1.id) + assert not entry1.disabled + entry2 = registry.async_get(entry2.id) + assert entry2.disabled + assert entry2.disabled_by == "user" diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 71cfb331591..86cdab82238 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -757,9 +757,18 @@ async def test_disable_device_disables_entities(hass, registry): device_id=device_entry.id, disabled_by="user", ) + entry3 = registry.async_get_or_create( + "light", + "hue", + "EFGH", + config_entry=config_entry, + device_id=device_entry.id, + disabled_by="config_entry", + ) assert not entry1.disabled assert entry2.disabled + assert entry3.disabled device_registry.async_update_device(device_entry.id, disabled_by="user") await hass.async_block_till_done() @@ -770,6 +779,9 @@ async def test_disable_device_disables_entities(hass, registry): entry2 = registry.async_get(entry2.entity_id) assert entry2.disabled assert entry2.disabled_by == "user" + entry3 = registry.async_get(entry3.entity_id) + assert entry3.disabled + assert entry3.disabled_by == "config_entry" device_registry.async_update_device(device_entry.id, disabled_by=None) await hass.async_block_till_done() @@ -779,6 +791,74 @@ async def test_disable_device_disables_entities(hass, registry): entry2 = registry.async_get(entry2.entity_id) assert entry2.disabled assert entry2.disabled_by == "user" + entry3 = registry.async_get(entry3.entity_id) + assert entry3.disabled + assert entry3.disabled_by == "config_entry" + + +async def test_disable_config_entry_disables_entities(hass, registry): + """Test that we disable entities tied to a config entry.""" + device_registry = mock_device_registry(hass) + config_entry = MockConfigEntry(domain="light") + config_entry.add_to_hass(hass) + + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={("mac", "12:34:56:AB:CD:EF")}, + ) + + entry1 = registry.async_get_or_create( + "light", + "hue", + "5678", + config_entry=config_entry, + device_id=device_entry.id, + ) + entry2 = registry.async_get_or_create( + "light", + "hue", + "ABCD", + config_entry=config_entry, + device_id=device_entry.id, + disabled_by="user", + ) + entry3 = registry.async_get_or_create( + "light", + "hue", + "EFGH", + config_entry=config_entry, + device_id=device_entry.id, + disabled_by="device", + ) + + assert not entry1.disabled + assert entry2.disabled + assert entry3.disabled + + await hass.config_entries.async_set_disabled_by(config_entry.entry_id, "user") + await hass.async_block_till_done() + + entry1 = registry.async_get(entry1.entity_id) + assert entry1.disabled + assert entry1.disabled_by == "config_entry" + entry2 = registry.async_get(entry2.entity_id) + assert entry2.disabled + assert entry2.disabled_by == "user" + entry3 = registry.async_get(entry3.entity_id) + assert entry3.disabled + assert entry3.disabled_by == "device" + + await hass.config_entries.async_set_disabled_by(config_entry.entry_id, None) + await hass.async_block_till_done() + + entry1 = registry.async_get(entry1.entity_id) + assert not entry1.disabled + entry2 = registry.async_get(entry2.entity_id) + assert entry2.disabled + assert entry2.disabled_by == "user" + # The device was re-enabled, so entity disabled by the device will be re-enabled too + entry3 = registry.async_get(entry3.entity_id) + assert not entry3.disabled_by async def test_disabled_entities_excluded_from_entity_list(hass, registry): diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 435f2a11cc2..8a479a802e4 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -1108,6 +1108,110 @@ async def test_entry_reload_error(hass, manager, state): assert entry.state == state +async def test_entry_disable_succeed(hass, manager): + """Test that we can disable an entry.""" + entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED) + entry.add_to_hass(hass) + + async_setup = AsyncMock(return_value=True) + async_setup_entry = AsyncMock(return_value=True) + async_unload_entry = AsyncMock(return_value=True) + + mock_integration( + hass, + MockModule( + "comp", + async_setup=async_setup, + async_setup_entry=async_setup_entry, + async_unload_entry=async_unload_entry, + ), + ) + mock_entity_platform(hass, "config_flow.comp", None) + + # Disable + assert await manager.async_set_disabled_by( + entry.entry_id, config_entries.DISABLED_USER + ) + assert len(async_unload_entry.mock_calls) == 1 + assert len(async_setup.mock_calls) == 0 + assert len(async_setup_entry.mock_calls) == 0 + assert entry.state == config_entries.ENTRY_STATE_NOT_LOADED + + # Enable + assert await manager.async_set_disabled_by(entry.entry_id, None) + assert len(async_unload_entry.mock_calls) == 1 + assert len(async_setup.mock_calls) == 1 + assert len(async_setup_entry.mock_calls) == 1 + assert entry.state == config_entries.ENTRY_STATE_LOADED + + +async def test_entry_disable_without_reload_support(hass, manager): + """Test that we can disable an entry without reload support.""" + entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED) + entry.add_to_hass(hass) + + async_setup = AsyncMock(return_value=True) + async_setup_entry = AsyncMock(return_value=True) + + mock_integration( + hass, + MockModule( + "comp", + async_setup=async_setup, + async_setup_entry=async_setup_entry, + ), + ) + mock_entity_platform(hass, "config_flow.comp", None) + + # Disable + assert not await manager.async_set_disabled_by( + entry.entry_id, config_entries.DISABLED_USER + ) + assert len(async_setup.mock_calls) == 0 + assert len(async_setup_entry.mock_calls) == 0 + assert entry.state == config_entries.ENTRY_STATE_FAILED_UNLOAD + + # Enable + with pytest.raises(config_entries.OperationNotAllowed): + await manager.async_set_disabled_by(entry.entry_id, None) + assert len(async_setup.mock_calls) == 0 + assert len(async_setup_entry.mock_calls) == 0 + assert entry.state == config_entries.ENTRY_STATE_FAILED_UNLOAD + + +async def test_entry_enable_without_reload_support(hass, manager): + """Test that we can disable an entry without reload support.""" + entry = MockConfigEntry(domain="comp", disabled_by=config_entries.DISABLED_USER) + entry.add_to_hass(hass) + + async_setup = AsyncMock(return_value=True) + async_setup_entry = AsyncMock(return_value=True) + + mock_integration( + hass, + MockModule( + "comp", + async_setup=async_setup, + async_setup_entry=async_setup_entry, + ), + ) + mock_entity_platform(hass, "config_flow.comp", None) + + # Enable + assert await manager.async_set_disabled_by(entry.entry_id, None) + assert len(async_setup.mock_calls) == 1 + assert len(async_setup_entry.mock_calls) == 1 + assert entry.state == config_entries.ENTRY_STATE_LOADED + + # Disable + assert not await manager.async_set_disabled_by( + entry.entry_id, config_entries.DISABLED_USER + ) + assert len(async_setup.mock_calls) == 1 + assert len(async_setup_entry.mock_calls) == 1 + assert entry.state == config_entries.ENTRY_STATE_FAILED_UNLOAD + + async def test_init_custom_integration(hass): """Test initializing flow for custom integration.""" integration = loader.Integration(