Allow config entry reloading (#21502)

* Allow config entry reloading

* Fix duplicate test name

* Add comment

* fix typing
This commit is contained in:
Paulus Schoutsen 2019-02-28 20:27:20 -08:00 committed by Andrew Sayre
parent aa30ac52ea
commit ee4be13bda
3 changed files with 317 additions and 44 deletions

View File

@ -117,7 +117,7 @@ async def async_from_config_dict(config: Dict[str, Any],
hass, config, core_config.get(conf_util.CONF_PACKAGES, {})) hass, config, core_config.get(conf_util.CONF_PACKAGES, {}))
hass.config_entries = config_entries.ConfigEntries(hass, config) hass.config_entries = config_entries.ConfigEntries(hass, config)
await hass.config_entries.async_load() await hass.config_entries.async_initialize()
# Filter out the repeating and common config section [homeassistant] # Filter out the repeating and common config section [homeassistant]
components = set(key.split(' ')[0] for key in config.keys() components = set(key.split(' ')[0] for key in config.keys()

View File

@ -119,6 +119,7 @@ should follow the same return values as a normal step.
If the result of the step is to show a form, the user will be able to continue If the result of the step is to show a form, the user will be able to continue
the flow from the config panel. the flow from the config panel.
""" """
import asyncio
import logging import logging
import functools import functools
import uuid import uuid
@ -205,6 +206,11 @@ ENTRY_STATE_NOT_LOADED = 'not_loaded'
# An error occurred when trying to unload the entry # An error occurred when trying to unload the entry
ENTRY_STATE_FAILED_UNLOAD = 'failed_unload' ENTRY_STATE_FAILED_UNLOAD = 'failed_unload'
UNRECOVERABLE_STATES = (
ENTRY_STATE_MIGRATION_ERROR,
ENTRY_STATE_FAILED_UNLOAD,
)
DISCOVERY_NOTIFICATION_ID = 'config_entry_discovery' DISCOVERY_NOTIFICATION_ID = 'config_entry_discovery'
DISCOVERY_SOURCES = ( DISCOVERY_SOURCES = (
SOURCE_DISCOVERY, SOURCE_DISCOVERY,
@ -221,6 +227,18 @@ CONN_CLASS_ASSUMED = 'assumed'
CONN_CLASS_UNKNOWN = 'unknown' CONN_CLASS_UNKNOWN = 'unknown'
class ConfigError(HomeAssistantError):
"""Error while configuring an account."""
class UnknownEntry(ConfigError):
"""Unknown entry specified."""
class OperationNotAllowed(ConfigError):
"""Raised when a config entry operation is not allowed."""
class ConfigEntry: class ConfigEntry:
"""Hold a configuration entry.""" """Hold a configuration entry."""
@ -228,7 +246,7 @@ class ConfigEntry:
'source', 'connection_class', 'state', '_setup_lock', 'source', 'connection_class', 'state', '_setup_lock',
'update_listeners', '_async_cancel_retry_setup') 'update_listeners', '_async_cancel_retry_setup')
def __init__(self, version: str, domain: str, title: str, data: dict, def __init__(self, version: int, domain: str, title: str, data: dict,
source: str, connection_class: str, source: str, connection_class: str,
options: Optional[dict] = None, options: Optional[dict] = None,
entry_id: Optional[str] = None, entry_id: Optional[str] = None,
@ -283,7 +301,7 @@ class ConfigEntry:
result = await component.async_setup_entry(hass, self) result = await component.async_setup_entry(hass, self)
if not isinstance(result, bool): if not isinstance(result, bool):
_LOGGER.error('%s.async_config_entry did not return boolean', _LOGGER.error('%s.async_setup_entry did not return boolean',
component.DOMAIN) component.DOMAIN)
result = False result = False
except ConfigEntryNotReady: except ConfigEntryNotReady:
@ -316,7 +334,7 @@ class ConfigEntry:
else: else:
self.state = ENTRY_STATE_SETUP_ERROR self.state = ENTRY_STATE_SETUP_ERROR
async def async_unload(self, hass, *, component=None): async def async_unload(self, hass, *, component=None) -> bool:
"""Unload an entry. """Unload an entry.
Returns if unload is possible and was successful. Returns if unload is possible and was successful.
@ -325,17 +343,22 @@ class ConfigEntry:
component = getattr(hass.components, self.domain) component = getattr(hass.components, self.domain)
if component.DOMAIN == self.domain: if component.DOMAIN == self.domain:
if self._async_cancel_retry_setup is not None: if self.state in UNRECOVERABLE_STATES:
self._async_cancel_retry_setup() return False
self.state = ENTRY_STATE_NOT_LOADED
return True
if self.state != ENTRY_STATE_LOADED: if self.state != ENTRY_STATE_LOADED:
if self._async_cancel_retry_setup is not None:
self._async_cancel_retry_setup()
self._async_cancel_retry_setup = None
self.state = ENTRY_STATE_NOT_LOADED
return True return True
supports_unload = hasattr(component, 'async_unload_entry') supports_unload = hasattr(component, 'async_unload_entry')
if not supports_unload: if not supports_unload:
if component.DOMAIN == self.domain:
self.state = ENTRY_STATE_FAILED_UNLOAD
return False return False
try: try:
@ -420,14 +443,6 @@ class ConfigEntry:
} }
class ConfigError(HomeAssistantError):
"""Error while configuring an account."""
class UnknownEntry(ConfigError):
"""Unknown entry specified."""
class ConfigEntries: class ConfigEntries:
"""Manage the configuration entries. """Manage the configuration entries.
@ -474,34 +489,33 @@ class ConfigEntries:
async def async_remove(self, entry_id): async def async_remove(self, entry_id):
"""Remove an entry.""" """Remove an entry."""
found = None entry = self.async_get_entry(entry_id)
for index, entry in enumerate(self._entries):
if entry.entry_id == entry_id:
found = index
break
if found is None: if entry is None:
raise UnknownEntry raise UnknownEntry
entry = self._entries.pop(found) if entry.state in UNRECOVERABLE_STATES:
unload_success = entry.state != ENTRY_STATE_FAILED_UNLOAD
else:
unload_success = await self.async_unload(entry_id)
self._entries.remove(entry)
self._async_schedule_save() self._async_schedule_save()
unloaded = await entry.async_unload(self.hass) dev_reg, ent_reg = await asyncio.gather(
self.hass.helpers.device_registry.async_get_registry(),
self.hass.helpers.entity_registry.async_get_registry(),
)
device_registry = await \ dev_reg.async_clear_config_entry(entry_id)
self.hass.helpers.device_registry.async_get_registry() ent_reg.async_clear_config_entry(entry_id)
device_registry.async_clear_config_entry(entry_id)
entity_registry = await \
self.hass.helpers.entity_registry.async_get_registry()
entity_registry.async_clear_config_entry(entry_id)
return { return {
'require_restart': not unloaded 'require_restart': not unload_success
} }
async def async_load(self) -> None: async def async_initialize(self) -> None:
"""Handle loading the config.""" """Initialize config entry config."""
# Migrating for config entries stored before 0.73 # Migrating for config entries stored before 0.73
config = await self.hass.helpers.storage.async_migrator( config = await self.hass.helpers.storage.async_migrator(
self.hass.config.path(PATH_CONFIG), self._store, self.hass.config.path(PATH_CONFIG), self._store,
@ -527,6 +541,56 @@ class ConfigEntries:
options=entry.get('options')) options=entry.get('options'))
for entry in config['entries']] for entry in config['entries']]
async def async_setup(self, entry_id: str) -> bool:
"""Set up a config entry.
Return True if entry has been successfully loaded.
"""
entry = self.async_get_entry(entry_id)
if entry is None:
raise UnknownEntry
if entry.state != ENTRY_STATE_NOT_LOADED:
raise OperationNotAllowed
# Setup Component if not set up yet
if entry.domain in self.hass.config.components:
await entry.async_setup(self.hass)
else:
# Setting up the component will set up all its config entries
result = await async_setup_component(
self.hass, entry.domain, self._hass_config)
if not result:
return result
return entry.state == ENTRY_STATE_LOADED
async def async_unload(self, entry_id: str) -> bool:
"""Unload a config entry."""
entry = self.async_get_entry(entry_id)
if entry is None:
raise UnknownEntry
if entry.state in UNRECOVERABLE_STATES:
raise OperationNotAllowed
return await entry.async_unload(self.hass)
async def async_reload(self, entry_id: str) -> bool:
"""Reload an entry.
If an entry was not loaded, will just load.
"""
unload_result = await self.async_unload(entry_id)
if not unload_result:
return unload_result
return await self.async_setup(entry_id)
@callback @callback
def async_update_entry(self, entry, *, data=_UNDEF, options=_UNDEF): def async_update_entry(self, entry, *, data=_UNDEF, options=_UNDEF):
"""Update a config entry.""" """Update a config entry."""
@ -597,14 +661,7 @@ class ConfigEntries:
self._entries.append(entry) self._entries.append(entry)
self._async_schedule_save() self._async_schedule_save()
# Setup entry await self.async_setup(entry.entry_id)
if entry.domain in self.hass.config.components:
# Component already set up, just need to call setup_entry
await entry.async_setup(self.hass)
else:
# Setting up component will also load the entries
await async_setup_component(
self.hass, entry.domain, self._hass_config)
result['result'] = entry result['result'] = entry
return result return result

View File

@ -407,7 +407,7 @@ async def test_saving_and_loading(hass):
# Now load written data in new config manager # Now load written data in new config manager
manager = config_entries.ConfigEntries(hass, {}) manager = config_entries.ConfigEntries(hass, {})
await manager.async_load() await manager.async_initialize()
# Ensure same order # Ensure same order
for orig, loaded in zip(hass.config_entries.async_entries(), for orig, loaded in zip(hass.config_entries.async_entries(),
@ -518,7 +518,7 @@ async def test_loading_default_config(hass):
manager = config_entries.ConfigEntries(hass, {}) manager = config_entries.ConfigEntries(hass, {})
with patch('homeassistant.util.json.open', side_effect=FileNotFoundError): with patch('homeassistant.util.json.open', side_effect=FileNotFoundError):
await manager.async_load() await manager.async_initialize()
assert len(manager.async_entries()) == 0 assert len(manager.async_entries()) == 0
@ -650,3 +650,219 @@ async def test_entry_options(hass, manager):
assert entry.options == { assert entry.options == {
'second': True 'second': True
} }
async def test_entry_setup_succeed(hass, manager):
"""Test that we can setup an entry."""
entry = MockConfigEntry(
domain='comp',
state=config_entries.ENTRY_STATE_NOT_LOADED
)
entry.add_to_hass(hass)
mock_setup = MagicMock(return_value=mock_coro(True))
mock_setup_entry = MagicMock(return_value=mock_coro(True))
loader.set_component(hass, 'comp', MockModule(
'comp',
async_setup=mock_setup,
async_setup_entry=mock_setup_entry
))
assert await manager.async_setup(entry.entry_id)
assert len(mock_setup.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
assert entry.state == config_entries.ENTRY_STATE_LOADED
@pytest.mark.parametrize('state', (
config_entries.ENTRY_STATE_LOADED,
config_entries.ENTRY_STATE_SETUP_ERROR,
config_entries.ENTRY_STATE_MIGRATION_ERROR,
config_entries.ENTRY_STATE_SETUP_RETRY,
config_entries.ENTRY_STATE_FAILED_UNLOAD,
))
async def test_entry_setup_invalid_state(hass, manager, state):
"""Test that we cannot setup an entry with invalid state."""
entry = MockConfigEntry(
domain='comp',
state=state
)
entry.add_to_hass(hass)
mock_setup = MagicMock(return_value=mock_coro(True))
mock_setup_entry = MagicMock(return_value=mock_coro(True))
loader.set_component(hass, 'comp', MockModule(
'comp',
async_setup=mock_setup,
async_setup_entry=mock_setup_entry
))
with pytest.raises(config_entries.OperationNotAllowed):
assert await manager.async_setup(entry.entry_id)
assert len(mock_setup.mock_calls) == 0
assert len(mock_setup_entry.mock_calls) == 0
assert entry.state == state
async def test_entry_unload_succeed(hass, manager):
"""Test that we can unload an entry."""
entry = MockConfigEntry(
domain='comp',
state=config_entries.ENTRY_STATE_LOADED
)
entry.add_to_hass(hass)
async_unload_entry = MagicMock(return_value=mock_coro(True))
loader.set_component(hass, 'comp', MockModule(
'comp',
async_unload_entry=async_unload_entry
))
assert await manager.async_unload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 1
assert entry.state == config_entries.ENTRY_STATE_NOT_LOADED
@pytest.mark.parametrize('state', (
config_entries.ENTRY_STATE_NOT_LOADED,
config_entries.ENTRY_STATE_SETUP_ERROR,
config_entries.ENTRY_STATE_SETUP_RETRY,
))
async def test_entry_unload_failed_to_load(hass, manager, state):
"""Test that we can unload an entry."""
entry = MockConfigEntry(
domain='comp',
state=state,
)
entry.add_to_hass(hass)
async_unload_entry = MagicMock(return_value=mock_coro(True))
loader.set_component(hass, 'comp', MockModule(
'comp',
async_unload_entry=async_unload_entry
))
assert await manager.async_unload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 0
assert entry.state == config_entries.ENTRY_STATE_NOT_LOADED
@pytest.mark.parametrize('state', (
config_entries.ENTRY_STATE_MIGRATION_ERROR,
config_entries.ENTRY_STATE_FAILED_UNLOAD,
))
async def test_entry_unload_invalid_state(hass, manager, state):
"""Test that we cannot unload an entry with invalid state."""
entry = MockConfigEntry(
domain='comp',
state=state
)
entry.add_to_hass(hass)
async_unload_entry = MagicMock(return_value=mock_coro(True))
loader.set_component(hass, 'comp', MockModule(
'comp',
async_unload_entry=async_unload_entry
))
with pytest.raises(config_entries.OperationNotAllowed):
assert await manager.async_unload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 0
assert entry.state == state
async def test_entry_reload_succeed(hass, manager):
"""Test that we can reload an entry."""
entry = MockConfigEntry(
domain='comp',
state=config_entries.ENTRY_STATE_LOADED
)
entry.add_to_hass(hass)
async_setup = MagicMock(return_value=mock_coro(True))
async_setup_entry = MagicMock(return_value=mock_coro(True))
async_unload_entry = MagicMock(return_value=mock_coro(True))
loader.set_component(hass, 'comp', MockModule(
'comp',
async_setup=async_setup,
async_setup_entry=async_setup_entry,
async_unload_entry=async_unload_entry
))
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 1
assert len(async_setup.mock_calls) == 1
assert len(async_setup_entry.mock_calls) == 1
assert entry.state == config_entries.ENTRY_STATE_LOADED
@pytest.mark.parametrize('state', (
config_entries.ENTRY_STATE_NOT_LOADED,
config_entries.ENTRY_STATE_SETUP_ERROR,
config_entries.ENTRY_STATE_SETUP_RETRY,
))
async def test_entry_reload_not_loaded(hass, manager, state):
"""Test that we can reload an entry."""
entry = MockConfigEntry(
domain='comp',
state=state
)
entry.add_to_hass(hass)
async_setup = MagicMock(return_value=mock_coro(True))
async_setup_entry = MagicMock(return_value=mock_coro(True))
async_unload_entry = MagicMock(return_value=mock_coro(True))
loader.set_component(hass, 'comp', MockModule(
'comp',
async_setup=async_setup,
async_setup_entry=async_setup_entry,
async_unload_entry=async_unload_entry
))
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 0
assert len(async_setup.mock_calls) == 1
assert len(async_setup_entry.mock_calls) == 1
assert entry.state == config_entries.ENTRY_STATE_LOADED
@pytest.mark.parametrize('state', (
config_entries.ENTRY_STATE_MIGRATION_ERROR,
config_entries.ENTRY_STATE_FAILED_UNLOAD,
))
async def test_entry_reload_error(hass, manager, state):
"""Test that we can reload an entry."""
entry = MockConfigEntry(
domain='comp',
state=state
)
entry.add_to_hass(hass)
async_setup = MagicMock(return_value=mock_coro(True))
async_setup_entry = MagicMock(return_value=mock_coro(True))
async_unload_entry = MagicMock(return_value=mock_coro(True))
loader.set_component(hass, 'comp', MockModule(
'comp',
async_setup=async_setup,
async_setup_entry=async_setup_entry,
async_unload_entry=async_unload_entry
))
with pytest.raises(config_entries.OperationNotAllowed):
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 0
assert len(async_setup.mock_calls) == 0
assert len(async_setup_entry.mock_calls) == 0
assert entry.state == state