Add WS API for removing a config entry from a device (#66188)

* Add WS API for removing a config entry from a device

* Address review comments

* Address review comments

* Remove entity cleanup from ConfigEntries

* Update + add tests

* Improve comments in test

* Add negative test

* Refactor according to review comments

* Add back async_remove_config_entry_device

* Remove unnecessary error handling

* Simplify error handling
This commit is contained in:
Erik Montnemery 2022-02-21 10:11:18 +01:00 committed by GitHub
parent 7a39c769f0
commit c496748125
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 362 additions and 17 deletions

View File

@ -387,6 +387,7 @@ def entry_json(entry: config_entries.ConfigEntry) -> dict:
"source": entry.source, "source": entry.source,
"state": entry.state.value, "state": entry.state.value,
"supports_options": supports_options, "supports_options": supports_options,
"supports_remove_device": entry.supports_remove_device,
"supports_unload": entry.supports_unload, "supports_unload": entry.supports_unload,
"pref_disable_new_entities": entry.pref_disable_new_entities, "pref_disable_new_entities": entry.pref_disable_new_entities,
"pref_disable_polling": entry.pref_disable_polling, "pref_disable_polling": entry.pref_disable_polling,

View File

@ -1,16 +1,12 @@
"""HTTP views to interact with the device registry.""" """HTTP views to interact with the device registry."""
import voluptuous as vol import voluptuous as vol
from homeassistant import loader
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.decorators import ( from homeassistant.components.websocket_api.decorators import require_admin
async_response, from homeassistant.core import HomeAssistant, callback
require_admin, from homeassistant.exceptions import HomeAssistantError
) from homeassistant.helpers.device_registry import DeviceEntryDisabler, async_get
from homeassistant.core import callback
from homeassistant.helpers.device_registry import (
DeviceEntryDisabler,
async_get_registry,
)
WS_TYPE_LIST = "config/device_registry/list" WS_TYPE_LIST = "config/device_registry/list"
SCHEMA_WS_LIST = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( SCHEMA_WS_LIST = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
@ -39,13 +35,16 @@ async def async_setup(hass):
websocket_api.async_register_command( websocket_api.async_register_command(
hass, WS_TYPE_UPDATE, websocket_update_device, SCHEMA_WS_UPDATE hass, WS_TYPE_UPDATE, websocket_update_device, SCHEMA_WS_UPDATE
) )
websocket_api.async_register_command(
hass, websocket_remove_config_entry_from_device
)
return True return True
@async_response @callback
async def websocket_list_devices(hass, connection, msg): def websocket_list_devices(hass, connection, msg):
"""Handle list devices command.""" """Handle list devices command."""
registry = await async_get_registry(hass) registry = async_get(hass)
connection.send_message( connection.send_message(
websocket_api.result_message( websocket_api.result_message(
msg["id"], [_entry_dict(entry) for entry in registry.devices.values()] msg["id"], [_entry_dict(entry) for entry in registry.devices.values()]
@ -54,10 +53,10 @@ async def websocket_list_devices(hass, connection, msg):
@require_admin @require_admin
@async_response @callback
async def websocket_update_device(hass, connection, msg): def websocket_update_device(hass, connection, msg):
"""Handle update area websocket command.""" """Handle update area websocket command."""
registry = await async_get_registry(hass) registry = async_get(hass)
msg.pop("type") msg.pop("type")
msg_id = msg.pop("id") msg_id = msg.pop("id")
@ -70,6 +69,57 @@ async def websocket_update_device(hass, connection, msg):
connection.send_message(websocket_api.result_message(msg_id, _entry_dict(entry))) connection.send_message(websocket_api.result_message(msg_id, _entry_dict(entry)))
@websocket_api.require_admin
@websocket_api.websocket_command(
{
"type": "config/device_registry/remove_config_entry",
"device_id": str,
"config_entry_id": str,
}
)
@websocket_api.async_response
async def websocket_remove_config_entry_from_device(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Remove config entry from a device."""
registry = async_get(hass)
config_entry_id = msg["config_entry_id"]
device_id = msg["device_id"]
if (config_entry := hass.config_entries.async_get_entry(config_entry_id)) is None:
raise HomeAssistantError("Unknown config entry")
if not config_entry.supports_remove_device:
raise HomeAssistantError("Config entry does not support device removal")
if (device_entry := registry.async_get(device_id)) is None:
raise HomeAssistantError("Unknown device")
if config_entry_id not in device_entry.config_entries:
raise HomeAssistantError("Config entry not in device")
try:
integration = await loader.async_get_integration(hass, config_entry.domain)
component = integration.get_component()
except (ImportError, loader.IntegrationNotFound) as exc:
raise HomeAssistantError("Integration not found") from exc
if not await component.async_remove_config_entry_device(
hass, config_entry, device_entry
):
raise HomeAssistantError(
"Failed to remove device entry, rejected by integration"
)
entry = registry.async_update_device(
device_id, remove_config_entry_id=config_entry_id
)
entry_as_dict = _entry_dict(entry) if entry else None
connection.send_message(websocket_api.result_message(msg["id"], entry_as_dict))
@callback @callback
def _entry_dict(entry): def _entry_dict(entry):
"""Convert entry to API format.""" """Convert entry to API format."""

View File

@ -175,6 +175,7 @@ class ConfigEntry:
"options", "options",
"unique_id", "unique_id",
"supports_unload", "supports_unload",
"supports_remove_device",
"pref_disable_new_entities", "pref_disable_new_entities",
"pref_disable_polling", "pref_disable_polling",
"source", "source",
@ -257,6 +258,9 @@ class ConfigEntry:
# Supports unload # Supports unload
self.supports_unload = False self.supports_unload = False
# Supports remove device
self.supports_remove_device = False
# Listeners to call on update # Listeners to call on update
self.update_listeners: list[ self.update_listeners: list[
weakref.ReferenceType[UpdateListenerType] | weakref.WeakMethod weakref.ReferenceType[UpdateListenerType] | weakref.WeakMethod
@ -287,6 +291,9 @@ class ConfigEntry:
integration = await loader.async_get_integration(hass, self.domain) integration = await loader.async_get_integration(hass, self.domain)
self.supports_unload = await support_entry_unload(hass, self.domain) self.supports_unload = await support_entry_unload(hass, self.domain)
self.supports_remove_device = await support_remove_from_device(
hass, self.domain
)
try: try:
component = integration.get_component() component = integration.get_component()
@ -1615,3 +1622,10 @@ async def support_entry_unload(hass: HomeAssistant, domain: str) -> bool:
integration = await loader.async_get_integration(hass, domain) integration = await loader.async_get_integration(hass, domain)
component = integration.get_component() component = integration.get_component()
return hasattr(component, "async_unload_entry") return hasattr(component, "async_unload_entry")
async def support_remove_from_device(hass: HomeAssistant, domain: str) -> bool:
"""Test if a domain supports being removed from a device."""
integration = await loader.async_get_integration(hass, domain)
component = integration.get_component()
return hasattr(component, "async_remove_config_entry_device")

View File

@ -583,6 +583,7 @@ class MockModule:
async_migrate_entry=None, async_migrate_entry=None,
async_remove_entry=None, async_remove_entry=None,
partial_manifest=None, partial_manifest=None,
async_remove_config_entry_device=None,
): ):
"""Initialize the mock module.""" """Initialize the mock module."""
self.__name__ = f"homeassistant.components.{domain}" self.__name__ = f"homeassistant.components.{domain}"
@ -624,6 +625,9 @@ class MockModule:
if async_remove_entry is not None: if async_remove_entry is not None:
self.async_remove_entry = async_remove_entry self.async_remove_entry = async_remove_entry
if async_remove_config_entry_device is not None:
self.async_remove_config_entry_device = async_remove_config_entry_device
def mock_manifest(self): def mock_manifest(self):
"""Generate a mock manifest to represent this module.""" """Generate a mock manifest to represent this module."""
return { return {

View File

@ -12,7 +12,7 @@ from homeassistant.components.config import config_entries
from homeassistant.config_entries import HANDLERS from homeassistant.config_entries import HANDLERS
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.generated import config_flows from homeassistant.generated import config_flows
import homeassistant.helpers.config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import ( from tests.common import (
@ -94,6 +94,7 @@ async def test_get_entries(hass, client):
"source": "bla", "source": "bla",
"state": core_ce.ConfigEntryState.NOT_LOADED.value, "state": core_ce.ConfigEntryState.NOT_LOADED.value,
"supports_options": True, "supports_options": True,
"supports_remove_device": False,
"supports_unload": True, "supports_unload": True,
"pref_disable_new_entities": False, "pref_disable_new_entities": False,
"pref_disable_polling": False, "pref_disable_polling": False,
@ -106,6 +107,7 @@ async def test_get_entries(hass, client):
"source": "bla2", "source": "bla2",
"state": core_ce.ConfigEntryState.SETUP_ERROR.value, "state": core_ce.ConfigEntryState.SETUP_ERROR.value,
"supports_options": False, "supports_options": False,
"supports_remove_device": False,
"supports_unload": False, "supports_unload": False,
"pref_disable_new_entities": False, "pref_disable_new_entities": False,
"pref_disable_polling": False, "pref_disable_polling": False,
@ -118,6 +120,7 @@ async def test_get_entries(hass, client):
"source": "bla3", "source": "bla3",
"state": core_ce.ConfigEntryState.NOT_LOADED.value, "state": core_ce.ConfigEntryState.NOT_LOADED.value,
"supports_options": False, "supports_options": False,
"supports_remove_device": False,
"supports_unload": False, "supports_unload": False,
"pref_disable_new_entities": False, "pref_disable_new_entities": False,
"pref_disable_polling": False, "pref_disable_polling": False,
@ -370,6 +373,7 @@ async def test_create_account(hass, client, enable_custom_integrations):
"source": core_ce.SOURCE_USER, "source": core_ce.SOURCE_USER,
"state": core_ce.ConfigEntryState.LOADED.value, "state": core_ce.ConfigEntryState.LOADED.value,
"supports_options": False, "supports_options": False,
"supports_remove_device": False,
"supports_unload": False, "supports_unload": False,
"pref_disable_new_entities": False, "pref_disable_new_entities": False,
"pref_disable_polling": False, "pref_disable_polling": False,
@ -443,6 +447,7 @@ async def test_two_step_flow(hass, client, enable_custom_integrations):
"source": core_ce.SOURCE_USER, "source": core_ce.SOURCE_USER,
"state": core_ce.ConfigEntryState.LOADED.value, "state": core_ce.ConfigEntryState.LOADED.value,
"supports_options": False, "supports_options": False,
"supports_remove_device": False,
"supports_unload": False, "supports_unload": False,
"pref_disable_new_entities": False, "pref_disable_new_entities": False,
"pref_disable_polling": False, "pref_disable_polling": False,

View File

@ -3,8 +3,14 @@ import pytest
from homeassistant.components.config import device_registry from homeassistant.components.config import device_registry
from homeassistant.helpers import device_registry as helpers_dr from homeassistant.helpers import device_registry as helpers_dr
from homeassistant.setup import async_setup_component
from tests.common import mock_device_registry from tests.common import (
MockConfigEntry,
MockModule,
mock_device_registry,
mock_integration,
)
from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401 from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401
@ -126,3 +132,268 @@ async def test_update_device(hass, client, registry, payload_key, payload_value)
assert getattr(device, payload_key) == payload_value assert getattr(device, payload_key) == payload_value
assert isinstance(device.disabled_by, (helpers_dr.DeviceEntryDisabler, type(None))) assert isinstance(device.disabled_by, (helpers_dr.DeviceEntryDisabler, type(None)))
async def test_remove_config_entry_from_device(hass, hass_ws_client):
"""Test removing config entry from device."""
assert await async_setup_component(hass, "config", {})
ws_client = await hass_ws_client(hass)
device_registry = mock_device_registry(hass)
can_remove = False
async def async_remove_config_entry_device(hass, config_entry, device_entry):
return can_remove
mock_integration(
hass,
MockModule(
"comp1", async_remove_config_entry_device=async_remove_config_entry_device
),
)
mock_integration(
hass,
MockModule(
"comp2", async_remove_config_entry_device=async_remove_config_entry_device
),
)
entry_1 = MockConfigEntry(
domain="comp1",
title="Test 1",
source="bla",
)
entry_1.supports_remove_device = True
entry_1.add_to_hass(hass)
entry_2 = MockConfigEntry(
domain="comp1",
title="Test 1",
source="bla",
)
entry_2.supports_remove_device = True
entry_2.add_to_hass(hass)
device_registry.async_get_or_create(
config_entry_id=entry_1.entry_id,
connections={(helpers_dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
device_entry = device_registry.async_get_or_create(
config_entry_id=entry_2.entry_id,
connections={(helpers_dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
assert device_entry.config_entries == {entry_1.entry_id, entry_2.entry_id}
# Try removing a config entry from the device, it should fail because
# async_remove_config_entry_device returns False
await ws_client.send_json(
{
"id": 5,
"type": "config/device_registry/remove_config_entry",
"config_entry_id": entry_1.entry_id,
"device_id": device_entry.id,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "unknown_error"
# Make async_remove_config_entry_device return True
can_remove = True
# Remove the 1st config entry
await ws_client.send_json(
{
"id": 6,
"type": "config/device_registry/remove_config_entry",
"config_entry_id": entry_1.entry_id,
"device_id": device_entry.id,
}
)
response = await ws_client.receive_json()
assert response["success"]
assert response["result"]["config_entries"] == [entry_2.entry_id]
# Check that the config entry was removed from the device
assert device_registry.async_get(device_entry.id).config_entries == {
entry_2.entry_id
}
# Remove the 2nd config entry
await ws_client.send_json(
{
"id": 7,
"type": "config/device_registry/remove_config_entry",
"config_entry_id": entry_2.entry_id,
"device_id": device_entry.id,
}
)
response = await ws_client.receive_json()
assert response["success"]
assert response["result"] is None
# This was the last config entry, the device is removed
assert not device_registry.async_get(device_entry.id)
async def test_remove_config_entry_from_device_fails(hass, hass_ws_client):
"""Test removing config entry from device failing cases."""
assert await async_setup_component(hass, "config", {})
ws_client = await hass_ws_client(hass)
device_registry = mock_device_registry(hass)
async def async_remove_config_entry_device(hass, config_entry, device_entry):
return True
mock_integration(
hass,
MockModule("comp1"),
)
mock_integration(
hass,
MockModule(
"comp2", async_remove_config_entry_device=async_remove_config_entry_device
),
)
entry_1 = MockConfigEntry(
domain="comp1",
title="Test 1",
source="bla",
)
entry_1.add_to_hass(hass)
entry_2 = MockConfigEntry(
domain="comp2",
title="Test 1",
source="bla",
)
entry_2.supports_remove_device = True
entry_2.add_to_hass(hass)
entry_3 = MockConfigEntry(
domain="comp3",
title="Test 1",
source="bla",
)
entry_3.supports_remove_device = True
entry_3.add_to_hass(hass)
device_registry.async_get_or_create(
config_entry_id=entry_1.entry_id,
connections={(helpers_dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
device_registry.async_get_or_create(
config_entry_id=entry_2.entry_id,
connections={(helpers_dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
device_entry = device_registry.async_get_or_create(
config_entry_id=entry_3.entry_id,
connections={(helpers_dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
assert device_entry.config_entries == {
entry_1.entry_id,
entry_2.entry_id,
entry_3.entry_id,
}
fake_entry_id = "abc123"
assert entry_1.entry_id != fake_entry_id
fake_device_id = "abc123"
assert device_entry.id != fake_device_id
# Try removing a non existing config entry from the device
await ws_client.send_json(
{
"id": 5,
"type": "config/device_registry/remove_config_entry",
"config_entry_id": fake_entry_id,
"device_id": device_entry.id,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "unknown_error"
assert response["error"]["message"] == "Unknown config entry"
# Try removing a config entry which does not support removal from the device
await ws_client.send_json(
{
"id": 6,
"type": "config/device_registry/remove_config_entry",
"config_entry_id": entry_1.entry_id,
"device_id": device_entry.id,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "unknown_error"
assert (
response["error"]["message"] == "Config entry does not support device removal"
)
# Try removing a config entry from a device which does not exist
await ws_client.send_json(
{
"id": 7,
"type": "config/device_registry/remove_config_entry",
"config_entry_id": entry_2.entry_id,
"device_id": fake_device_id,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "unknown_error"
assert response["error"]["message"] == "Unknown device"
# Try removing a config entry from a device which it's not connected to
await ws_client.send_json(
{
"id": 8,
"type": "config/device_registry/remove_config_entry",
"config_entry_id": entry_2.entry_id,
"device_id": device_entry.id,
}
)
response = await ws_client.receive_json()
assert response["success"]
assert set(response["result"]["config_entries"]) == {
entry_1.entry_id,
entry_3.entry_id,
}
await ws_client.send_json(
{
"id": 9,
"type": "config/device_registry/remove_config_entry",
"config_entry_id": entry_2.entry_id,
"device_id": device_entry.id,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "unknown_error"
assert response["error"]["message"] == "Config entry not in device"
# Try removing a config entry which can't be loaded from a device - allowed
await ws_client.send_json(
{
"id": 10,
"type": "config/device_registry/remove_config_entry",
"config_entry_id": entry_3.entry_id,
"device_id": device_entry.id,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "unknown_error"
assert response["error"]["message"] == "Integration not found"