diff --git a/homeassistant/components/derivative/__init__.py b/homeassistant/components/derivative/__init__.py index 5eb499b0efd..6d539817875 100644 --- a/homeassistant/components/derivative/__init__.py +++ b/homeassistant/components/derivative/__init__.py @@ -29,6 +29,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: options={**entry.options, CONF_SOURCE: source_entity_id}, ) + async def source_entity_removed() -> None: + # The source entity has been removed, we need to clean the device links. + async_remove_stale_devices_links_keep_entity_device(hass, entry.entry_id, None) + entity_registry = er.async_get(hass) entry.async_on_unload( async_handle_source_entity_changes( @@ -42,6 +46,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass, entry.options[CONF_SOURCE] ), source_entity_id_or_uuid=entry.options[CONF_SOURCE], + source_entity_removed=source_entity_removed, ) ) await hass.config_entries.async_forward_entry_setups(entry, (Platform.SENSOR,)) diff --git a/homeassistant/components/switch_as_x/__init__.py b/homeassistant/components/switch_as_x/__init__.py index 6f21b032da3..6bd2c4e6482 100644 --- a/homeassistant/components/switch_as_x/__init__.py +++ b/homeassistant/components/switch_as_x/__init__.py @@ -60,6 +60,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: options={**entry.options, CONF_ENTITY_ID: source_entity_id}, ) + async def source_entity_removed() -> None: + # The source entity has been removed, we remove the config entry because + # switch_as_x does not allow replacing the wrapped entity. + await hass.config_entries.async_remove(entry.entry_id) + entry.async_on_unload( async_handle_source_entity_changes( hass, @@ -70,6 +75,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: set_source_entity_id_or_uuid=set_source_entity_id_or_uuid, source_device_id=async_add_to_device(hass, entry, entity_id), source_entity_id_or_uuid=entry.options[CONF_ENTITY_ID], + source_entity_removed=source_entity_removed, ) ) entry.async_on_unload(entry.add_update_listener(config_entry_update_listener)) diff --git a/homeassistant/helpers/device.py b/homeassistant/helpers/device.py index a7d888900b1..f1404bb068b 100644 --- a/homeassistant/helpers/device.py +++ b/homeassistant/helpers/device.py @@ -62,7 +62,7 @@ def async_device_info_to_link_from_device_id( def async_remove_stale_devices_links_keep_entity_device( hass: HomeAssistant, entry_id: str, - source_entity_id_or_uuid: str, + source_entity_id_or_uuid: str | None, ) -> None: """Remove entry_id from all devices except that of source_entity_id_or_uuid. @@ -73,7 +73,9 @@ def async_remove_stale_devices_links_keep_entity_device( async_remove_stale_devices_links_keep_current_device( hass=hass, entry_id=entry_id, - current_device_id=async_entity_id_to_device_id(hass, source_entity_id_or_uuid), + current_device_id=async_entity_id_to_device_id(hass, source_entity_id_or_uuid) + if source_entity_id_or_uuid + else None, ) diff --git a/homeassistant/helpers/helper_integration.py b/homeassistant/helpers/helper_integration.py index 4f39ef4c843..37aa246178e 100644 --- a/homeassistant/helpers/helper_integration.py +++ b/homeassistant/helpers/helper_integration.py @@ -2,7 +2,8 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Coroutine +from typing import Any from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, valid_entity_id @@ -18,6 +19,7 @@ def async_handle_source_entity_changes( set_source_entity_id_or_uuid: Callable[[str], None], source_device_id: str | None, source_entity_id_or_uuid: str, + source_entity_removed: Callable[[], Coroutine[Any, Any, None]], ) -> CALLBACK_TYPE: """Handle changes to a helper entity's source entity. @@ -34,6 +36,14 @@ def async_handle_source_entity_changes( - Source entity removed from the device: The helper entity is updated to link to no device, and the helper config entry removed from the old device. Then the helper config entry is reloaded. + + :param get_helper_entity: A function which returns the helper entity's entity ID, + or None if the helper entity does not exist. + :param set_source_entity_id_or_uuid: A function which updates the source entity + ID or UUID, e.g., in the helper config entry options. + :param source_entity_removed: A function which is called when the source entity + is removed. This can be used to clean up any resources related to the source + entity or ask the user to select a new source entity. """ async def async_registry_updated( @@ -44,7 +54,7 @@ def async_handle_source_entity_changes( data = event.data if data["action"] == "remove": - await hass.config_entries.async_remove(helper_config_entry_id) + await source_entity_removed() if data["action"] != "update": return diff --git a/tests/components/derivative/test_init.py b/tests/components/derivative/test_init.py index f75d5940da7..d237703eb2e 100644 --- a/tests/components/derivative/test_init.py +++ b/tests/components/derivative/test_init.py @@ -268,17 +268,17 @@ async def test_async_handle_source_entity_changes_source_entity_removed( ) await hass.async_block_till_done() await hass.async_block_till_done() - mock_unload_entry.assert_called_once() + mock_unload_entry.assert_not_called() # Check that the derivative config entry is removed from the device sensor_device = device_registry.async_get(sensor_device.id) assert derivative_config_entry.entry_id not in sensor_device.config_entries - # Check that the derivative config entry is removed - assert derivative_config_entry.entry_id not in hass.config_entries.async_entry_ids() + # Check that the derivative config entry is not removed + assert derivative_config_entry.entry_id in hass.config_entries.async_entry_ids() # Check we got the expected events - assert events == ["remove"] + assert events == ["update"] async def test_async_handle_source_entity_changes_source_entity_removed_from_device( diff --git a/tests/helpers/test_helper_integration.py b/tests/helpers/test_helper_integration.py index 25d490c27bb..12433894dc7 100644 --- a/tests/helpers/test_helper_integration.py +++ b/tests/helpers/test_helper_integration.py @@ -132,11 +132,17 @@ def async_unload_entry() -> AsyncMock: @pytest.fixture -def set_source_entity_id_or_uuid() -> AsyncMock: - """Fixture to mock async_unload_entry.""" +def set_source_entity_id_or_uuid() -> Mock: + """Fixture to mock set_source_entity_id_or_uuid.""" return Mock() +@pytest.fixture +def source_entity_removed() -> AsyncMock: + """Fixture to mock source_entity_removed.""" + return AsyncMock() + + @pytest.fixture def mock_helper_integration( hass: HomeAssistant, @@ -146,6 +152,7 @@ def mock_helper_integration( async_remove_entry: AsyncMock, async_unload_entry: AsyncMock, set_source_entity_id_or_uuid: Mock, + source_entity_removed: AsyncMock, ) -> None: """Mock the helper integration.""" @@ -164,6 +171,7 @@ def mock_helper_integration( set_source_entity_id_or_uuid=set_source_entity_id_or_uuid, source_device_id=source_entity_entry.device_id, source_entity_id_or_uuid=helper_config_entry.options["source"], + source_entity_removed=source_entity_removed, ) return True @@ -206,6 +214,7 @@ async def test_async_handle_source_entity_changes_source_entity_removed( async_remove_entry: AsyncMock, async_unload_entry: AsyncMock, set_source_entity_id_or_uuid: Mock, + source_entity_removed: AsyncMock, ) -> None: """Test the helper config entry is removed when the source entity is removed.""" # Add the helper config entry to the source device @@ -238,20 +247,21 @@ async def test_async_handle_source_entity_changes_source_entity_removed( await hass.async_block_till_done() await hass.async_block_till_done() - # Check that the helper config entry is unloaded and removed - async_unload_entry.assert_called_once() - async_remove_entry.assert_called_once() + # Check that the source_entity_removed callback was called + source_entity_removed.assert_called_once() + async_unload_entry.assert_not_called() + async_remove_entry.assert_not_called() set_source_entity_id_or_uuid.assert_not_called() - # Check that the helper config entry is removed from the device + # Check that the helper config entry is not removed from the device source_device = device_registry.async_get(source_device.id) - assert helper_config_entry.entry_id not in source_device.config_entries + assert helper_config_entry.entry_id in source_device.config_entries - # Check that the helper config entry is removed - assert helper_config_entry.entry_id not in hass.config_entries.async_entry_ids() + # Check that the helper config entry is not removed + assert helper_config_entry.entry_id in hass.config_entries.async_entry_ids() # Check we got the expected events - assert events == ["remove"] + assert events == [] @pytest.mark.parametrize("use_entity_registry_id", [True, False])