diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 69696561303..1036c02fd0d 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -378,6 +378,17 @@ class ConfigEntry: self.state = ENTRY_STATE_FAILED_UNLOAD return False + async def async_remove(self, hass: HomeAssistant) -> None: + """Invoke remove callback on component.""" + component = getattr(hass.components, self.domain) + if not hasattr(component, 'async_remove_entry'): + return + try: + await component.async_remove_entry(hass, self) + except Exception: # pylint: disable=broad-except + _LOGGER.exception('Error calling entry remove callback %s for %s', + self.title, component.DOMAIN) + async def async_migrate(self, hass: HomeAssistant) -> bool: """Migrate an entry. @@ -499,6 +510,8 @@ class ConfigEntries: else: unload_success = await self.async_unload(entry_id) + await entry.async_remove(self.hass) + self._entries.remove(entry) self._async_schedule_save() diff --git a/tests/common.py b/tests/common.py index a55546da73b..8681db1b4f3 100644 --- a/tests/common.py +++ b/tests/common.py @@ -452,7 +452,7 @@ class MockModule: requirements=None, config_schema=None, platform_schema=None, platform_schema_base=None, async_setup=None, async_setup_entry=None, async_unload_entry=None, - async_migrate_entry=None): + async_migrate_entry=None, async_remove_entry=None): """Initialize the mock module.""" self.__name__ = 'homeassistant.components.{}'.format(domain) self.DOMAIN = domain @@ -487,6 +487,9 @@ class MockModule: if async_migrate_entry is not None: self.async_migrate_entry = async_migrate_entry + if async_remove_entry is not None: + self.async_remove_entry = async_remove_entry + class MockPlatform: """Provide a fake platform.""" diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index e7a5b763796..324db971583 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -173,6 +173,9 @@ async def test_remove_entry(hass, manager): assert result return result + mock_remove_entry = MagicMock( + side_effect=lambda *args, **kwargs: mock_coro()) + entity = MockEntity( unique_id='1234', name='Test Entity', @@ -185,7 +188,8 @@ async def test_remove_entry(hass, manager): loader.set_component(hass, 'test', MockModule( 'test', async_setup_entry=mock_setup_entry, - async_unload_entry=mock_unload_entry + async_unload_entry=mock_unload_entry, + async_remove_entry=mock_remove_entry )) loader.set_component( hass, 'light.test', @@ -227,6 +231,9 @@ async def test_remove_entry(hass, manager): 'require_restart': False } + # Check the remove callback was invoked. + assert mock_remove_entry.call_count == 1 + # Check that config entry was removed. assert [item.entry_id for item in manager.async_entries()] == \ ['test1', 'test3'] @@ -241,6 +248,43 @@ async def test_remove_entry(hass, manager): assert entity_entry.config_entry_id is None +async def test_remove_entry_handles_callback_error(hass, manager): + """Test that exceptions in the remove callback are handled.""" + mock_setup_entry = MagicMock(return_value=mock_coro(True)) + mock_unload_entry = MagicMock(return_value=mock_coro(True)) + mock_remove_entry = MagicMock( + side_effect=lambda *args, **kwargs: mock_coro()) + loader.set_component(hass, 'test', MockModule( + 'test', + async_setup_entry=mock_setup_entry, + async_unload_entry=mock_unload_entry, + async_remove_entry=mock_remove_entry + )) + entry = MockConfigEntry( + domain='test', + entry_id='test1', + ) + entry.add_to_manager(manager) + # Check all config entries exist + assert [item.entry_id for item in manager.async_entries()] == \ + ['test1'] + # Setup entry + await entry.async_setup(hass) + await hass.async_block_till_done() + + # Remove entry + result = await manager.async_remove('test1') + await hass.async_block_till_done() + # Check that unload went well and so no need to restart + assert result == { + 'require_restart': False + } + # Check the remove callback was invoked. + assert mock_remove_entry.call_count == 1 + # Check that config entry was removed. + assert [item.entry_id for item in manager.async_entries()] == [] + + @asyncio.coroutine def test_remove_entry_raises(hass, manager): """Test if a component raises while removing entry."""