diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index e5bf9e9b93d..c3d20fb0f16 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -387,6 +387,7 @@ def entry_json(entry: config_entries.ConfigEntry) -> dict: "source": entry.source, "state": entry.state.value, "supports_options": supports_options, + "supports_remove_device": entry.supports_remove_device, "supports_unload": entry.supports_unload, "pref_disable_new_entities": entry.pref_disable_new_entities, "pref_disable_polling": entry.pref_disable_polling, diff --git a/homeassistant/components/config/device_registry.py b/homeassistant/components/config/device_registry.py index 686fffec252..e811d43d502 100644 --- a/homeassistant/components/config/device_registry.py +++ b/homeassistant/components/config/device_registry.py @@ -1,16 +1,12 @@ """HTTP views to interact with the device registry.""" import voluptuous as vol +from homeassistant import loader from homeassistant.components import websocket_api -from homeassistant.components.websocket_api.decorators import ( - async_response, - require_admin, -) -from homeassistant.core import callback -from homeassistant.helpers.device_registry import ( - DeviceEntryDisabler, - async_get_registry, -) +from homeassistant.components.websocket_api.decorators import require_admin +from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.device_registry import DeviceEntryDisabler, async_get WS_TYPE_LIST = "config/device_registry/list" SCHEMA_WS_LIST = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( @@ -39,13 +35,16 @@ async def async_setup(hass): websocket_api.async_register_command( hass, WS_TYPE_UPDATE, websocket_update_device, SCHEMA_WS_UPDATE ) + websocket_api.async_register_command( + hass, websocket_remove_config_entry_from_device + ) return True -@async_response -async def websocket_list_devices(hass, connection, msg): +@callback +def websocket_list_devices(hass, connection, msg): """Handle list devices command.""" - registry = await async_get_registry(hass) + registry = async_get(hass) connection.send_message( websocket_api.result_message( msg["id"], [_entry_dict(entry) for entry in registry.devices.values()] @@ -54,10 +53,10 @@ async def websocket_list_devices(hass, connection, msg): @require_admin -@async_response -async def websocket_update_device(hass, connection, msg): +@callback +def websocket_update_device(hass, connection, msg): """Handle update area websocket command.""" - registry = await async_get_registry(hass) + registry = async_get(hass) msg.pop("type") msg_id = msg.pop("id") @@ -70,6 +69,57 @@ async def websocket_update_device(hass, connection, msg): connection.send_message(websocket_api.result_message(msg_id, _entry_dict(entry))) +@websocket_api.require_admin +@websocket_api.websocket_command( + { + "type": "config/device_registry/remove_config_entry", + "device_id": str, + "config_entry_id": str, + } +) +@websocket_api.async_response +async def websocket_remove_config_entry_from_device( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict +) -> None: + """Remove config entry from a device.""" + registry = async_get(hass) + config_entry_id = msg["config_entry_id"] + device_id = msg["device_id"] + + if (config_entry := hass.config_entries.async_get_entry(config_entry_id)) is None: + raise HomeAssistantError("Unknown config entry") + + if not config_entry.supports_remove_device: + raise HomeAssistantError("Config entry does not support device removal") + + if (device_entry := registry.async_get(device_id)) is None: + raise HomeAssistantError("Unknown device") + + if config_entry_id not in device_entry.config_entries: + raise HomeAssistantError("Config entry not in device") + + try: + integration = await loader.async_get_integration(hass, config_entry.domain) + component = integration.get_component() + except (ImportError, loader.IntegrationNotFound) as exc: + raise HomeAssistantError("Integration not found") from exc + + if not await component.async_remove_config_entry_device( + hass, config_entry, device_entry + ): + raise HomeAssistantError( + "Failed to remove device entry, rejected by integration" + ) + + entry = registry.async_update_device( + device_id, remove_config_entry_id=config_entry_id + ) + + entry_as_dict = _entry_dict(entry) if entry else None + + connection.send_message(websocket_api.result_message(msg["id"], entry_as_dict)) + + @callback def _entry_dict(entry): """Convert entry to API format.""" diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 57c837178a4..af04ee032dc 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -175,6 +175,7 @@ class ConfigEntry: "options", "unique_id", "supports_unload", + "supports_remove_device", "pref_disable_new_entities", "pref_disable_polling", "source", @@ -257,6 +258,9 @@ class ConfigEntry: # Supports unload self.supports_unload = False + # Supports remove device + self.supports_remove_device = False + # Listeners to call on update self.update_listeners: list[ weakref.ReferenceType[UpdateListenerType] | weakref.WeakMethod @@ -287,6 +291,9 @@ class ConfigEntry: integration = await loader.async_get_integration(hass, self.domain) self.supports_unload = await support_entry_unload(hass, self.domain) + self.supports_remove_device = await support_remove_from_device( + hass, self.domain + ) try: component = integration.get_component() @@ -1615,3 +1622,10 @@ async def support_entry_unload(hass: HomeAssistant, domain: str) -> bool: integration = await loader.async_get_integration(hass, domain) component = integration.get_component() return hasattr(component, "async_unload_entry") + + +async def support_remove_from_device(hass: HomeAssistant, domain: str) -> bool: + """Test if a domain supports being removed from a device.""" + integration = await loader.async_get_integration(hass, domain) + component = integration.get_component() + return hasattr(component, "async_remove_config_entry_device") diff --git a/tests/common.py b/tests/common.py index c8dfb3ed841..bdebc7217a7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -583,6 +583,7 @@ class MockModule: async_migrate_entry=None, async_remove_entry=None, partial_manifest=None, + async_remove_config_entry_device=None, ): """Initialize the mock module.""" self.__name__ = f"homeassistant.components.{domain}" @@ -624,6 +625,9 @@ class MockModule: if async_remove_entry is not None: self.async_remove_entry = async_remove_entry + if async_remove_config_entry_device is not None: + self.async_remove_config_entry_device = async_remove_config_entry_device + def mock_manifest(self): """Generate a mock manifest to represent this module.""" return { diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py index 6608bf3471d..cfc6d8d4907 100644 --- a/tests/components/config/test_config_entries.py +++ b/tests/components/config/test_config_entries.py @@ -12,7 +12,7 @@ from homeassistant.components.config import config_entries from homeassistant.config_entries import HANDLERS from homeassistant.core import callback from homeassistant.generated import config_flows -import homeassistant.helpers.config_validation as cv +from homeassistant.helpers import config_validation as cv from homeassistant.setup import async_setup_component from tests.common import ( @@ -94,6 +94,7 @@ async def test_get_entries(hass, client): "source": "bla", "state": core_ce.ConfigEntryState.NOT_LOADED.value, "supports_options": True, + "supports_remove_device": False, "supports_unload": True, "pref_disable_new_entities": False, "pref_disable_polling": False, @@ -106,6 +107,7 @@ async def test_get_entries(hass, client): "source": "bla2", "state": core_ce.ConfigEntryState.SETUP_ERROR.value, "supports_options": False, + "supports_remove_device": False, "supports_unload": False, "pref_disable_new_entities": False, "pref_disable_polling": False, @@ -118,6 +120,7 @@ async def test_get_entries(hass, client): "source": "bla3", "state": core_ce.ConfigEntryState.NOT_LOADED.value, "supports_options": False, + "supports_remove_device": False, "supports_unload": False, "pref_disable_new_entities": False, "pref_disable_polling": False, @@ -370,6 +373,7 @@ async def test_create_account(hass, client, enable_custom_integrations): "source": core_ce.SOURCE_USER, "state": core_ce.ConfigEntryState.LOADED.value, "supports_options": False, + "supports_remove_device": False, "supports_unload": False, "pref_disable_new_entities": False, "pref_disable_polling": False, @@ -443,6 +447,7 @@ async def test_two_step_flow(hass, client, enable_custom_integrations): "source": core_ce.SOURCE_USER, "state": core_ce.ConfigEntryState.LOADED.value, "supports_options": False, + "supports_remove_device": False, "supports_unload": False, "pref_disable_new_entities": False, "pref_disable_polling": False, diff --git a/tests/components/config/test_device_registry.py b/tests/components/config/test_device_registry.py index f43f9a4d8ce..f923b326100 100644 --- a/tests/components/config/test_device_registry.py +++ b/tests/components/config/test_device_registry.py @@ -3,8 +3,14 @@ import pytest from homeassistant.components.config import device_registry from homeassistant.helpers import device_registry as helpers_dr +from homeassistant.setup import async_setup_component -from tests.common import mock_device_registry +from tests.common import ( + MockConfigEntry, + MockModule, + mock_device_registry, + mock_integration, +) from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401 @@ -126,3 +132,268 @@ async def test_update_device(hass, client, registry, payload_key, payload_value) assert getattr(device, payload_key) == payload_value assert isinstance(device.disabled_by, (helpers_dr.DeviceEntryDisabler, type(None))) + + +async def test_remove_config_entry_from_device(hass, hass_ws_client): + """Test removing config entry from device.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + device_registry = mock_device_registry(hass) + + can_remove = False + + async def async_remove_config_entry_device(hass, config_entry, device_entry): + return can_remove + + mock_integration( + hass, + MockModule( + "comp1", async_remove_config_entry_device=async_remove_config_entry_device + ), + ) + mock_integration( + hass, + MockModule( + "comp2", async_remove_config_entry_device=async_remove_config_entry_device + ), + ) + + entry_1 = MockConfigEntry( + domain="comp1", + title="Test 1", + source="bla", + ) + entry_1.supports_remove_device = True + entry_1.add_to_hass(hass) + + entry_2 = MockConfigEntry( + domain="comp1", + title="Test 1", + source="bla", + ) + entry_2.supports_remove_device = True + entry_2.add_to_hass(hass) + + device_registry.async_get_or_create( + config_entry_id=entry_1.entry_id, + connections={(helpers_dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + device_entry = device_registry.async_get_or_create( + config_entry_id=entry_2.entry_id, + connections={(helpers_dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + assert device_entry.config_entries == {entry_1.entry_id, entry_2.entry_id} + + # Try removing a config entry from the device, it should fail because + # async_remove_config_entry_device returns False + await ws_client.send_json( + { + "id": 5, + "type": "config/device_registry/remove_config_entry", + "config_entry_id": entry_1.entry_id, + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"]["code"] == "unknown_error" + + # Make async_remove_config_entry_device return True + can_remove = True + + # Remove the 1st config entry + await ws_client.send_json( + { + "id": 6, + "type": "config/device_registry/remove_config_entry", + "config_entry_id": entry_1.entry_id, + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + + assert response["success"] + assert response["result"]["config_entries"] == [entry_2.entry_id] + + # Check that the config entry was removed from the device + assert device_registry.async_get(device_entry.id).config_entries == { + entry_2.entry_id + } + + # Remove the 2nd config entry + await ws_client.send_json( + { + "id": 7, + "type": "config/device_registry/remove_config_entry", + "config_entry_id": entry_2.entry_id, + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + + assert response["success"] + assert response["result"] is None + + # This was the last config entry, the device is removed + assert not device_registry.async_get(device_entry.id) + + +async def test_remove_config_entry_from_device_fails(hass, hass_ws_client): + """Test removing config entry from device failing cases.""" + assert await async_setup_component(hass, "config", {}) + ws_client = await hass_ws_client(hass) + device_registry = mock_device_registry(hass) + + async def async_remove_config_entry_device(hass, config_entry, device_entry): + return True + + mock_integration( + hass, + MockModule("comp1"), + ) + mock_integration( + hass, + MockModule( + "comp2", async_remove_config_entry_device=async_remove_config_entry_device + ), + ) + + entry_1 = MockConfigEntry( + domain="comp1", + title="Test 1", + source="bla", + ) + entry_1.add_to_hass(hass) + + entry_2 = MockConfigEntry( + domain="comp2", + title="Test 1", + source="bla", + ) + entry_2.supports_remove_device = True + entry_2.add_to_hass(hass) + + entry_3 = MockConfigEntry( + domain="comp3", + title="Test 1", + source="bla", + ) + entry_3.supports_remove_device = True + entry_3.add_to_hass(hass) + + device_registry.async_get_or_create( + config_entry_id=entry_1.entry_id, + connections={(helpers_dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + device_registry.async_get_or_create( + config_entry_id=entry_2.entry_id, + connections={(helpers_dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + device_entry = device_registry.async_get_or_create( + config_entry_id=entry_3.entry_id, + connections={(helpers_dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + assert device_entry.config_entries == { + entry_1.entry_id, + entry_2.entry_id, + entry_3.entry_id, + } + + fake_entry_id = "abc123" + assert entry_1.entry_id != fake_entry_id + fake_device_id = "abc123" + assert device_entry.id != fake_device_id + + # Try removing a non existing config entry from the device + await ws_client.send_json( + { + "id": 5, + "type": "config/device_registry/remove_config_entry", + "config_entry_id": fake_entry_id, + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"]["code"] == "unknown_error" + assert response["error"]["message"] == "Unknown config entry" + + # Try removing a config entry which does not support removal from the device + await ws_client.send_json( + { + "id": 6, + "type": "config/device_registry/remove_config_entry", + "config_entry_id": entry_1.entry_id, + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"]["code"] == "unknown_error" + assert ( + response["error"]["message"] == "Config entry does not support device removal" + ) + + # Try removing a config entry from a device which does not exist + await ws_client.send_json( + { + "id": 7, + "type": "config/device_registry/remove_config_entry", + "config_entry_id": entry_2.entry_id, + "device_id": fake_device_id, + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"]["code"] == "unknown_error" + assert response["error"]["message"] == "Unknown device" + + # Try removing a config entry from a device which it's not connected to + await ws_client.send_json( + { + "id": 8, + "type": "config/device_registry/remove_config_entry", + "config_entry_id": entry_2.entry_id, + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + + assert response["success"] + assert set(response["result"]["config_entries"]) == { + entry_1.entry_id, + entry_3.entry_id, + } + + await ws_client.send_json( + { + "id": 9, + "type": "config/device_registry/remove_config_entry", + "config_entry_id": entry_2.entry_id, + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"]["code"] == "unknown_error" + assert response["error"]["message"] == "Config entry not in device" + + # Try removing a config entry which can't be loaded from a device - allowed + await ws_client.send_json( + { + "id": 10, + "type": "config/device_registry/remove_config_entry", + "config_entry_id": entry_3.entry_id, + "device_id": device_entry.id, + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"]["code"] == "unknown_error" + assert response["error"]["message"] == "Integration not found"