From b3a47722f0d5c6693e29af415a9fad4d3d956566 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 16 Feb 2018 14:07:38 -0800 Subject: [PATCH] Initial support for Config Entries (#12079) * Introduce Config Entries * Rebase fail * Address comments * Address more comments * RequestDataValidator moved --- homeassistant/bootstrap.py | 7 +- homeassistant/components/config/__init__.py | 10 +- .../components/config/config_entries.py | 182 ++++++ .../components/config_entry_example.py | 102 ++++ homeassistant/config_entries.py | 516 ++++++++++++++++++ homeassistant/setup.py | 10 +- requirements_all.txt | 3 + requirements_test_all.txt | 3 + script/gen_requirements_all.py | 1 + tests/common.py | 43 +- .../binary_sensor/test_command_line.py | 11 - .../components/config/test_config_entries.py | 317 +++++++++++ tests/components/sensor/test_command_line.py | 11 - tests/components/test_config_entry_example.py | 38 ++ tests/test_config_entries.py | 397 ++++++++++++++ 15 files changed, 1622 insertions(+), 29 deletions(-) create mode 100644 homeassistant/components/config/config_entries.py create mode 100644 homeassistant/components/config_entry_example.py create mode 100644 homeassistant/config_entries.py create mode 100644 tests/components/config/test_config_entries.py create mode 100644 tests/components/test_config_entry_example.py create mode 100644 tests/test_config_entries.py diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index c5b01916d8c..4971cbccc9c 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -12,7 +12,8 @@ from typing import Any, Optional, Dict import voluptuous as vol from homeassistant import ( - core, config as conf_util, loader, components as core_components) + core, config as conf_util, config_entries, loader, + components as core_components) from homeassistant.components import persistent_notification from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE from homeassistant.setup import async_setup_component @@ -123,9 +124,13 @@ def async_from_config_dict(config: Dict[str, Any], new_config[key] = value or {} config = new_config + hass.config_entries = config_entries.ConfigEntries(hass, config) + yield from hass.config_entries.async_load() + # Filter out the repeating and common config section [homeassistant] components = set(key.split(' ')[0] for key in config.keys() if key != core.DOMAIN) + components.update(hass.config_entries.async_domains()) # setup components # pylint: disable=not-an-iterable diff --git a/homeassistant/components/config/__init__.py b/homeassistant/components/config/__init__.py index c45e3561c47..7f2041249e0 100644 --- a/homeassistant/components/config/__init__.py +++ b/homeassistant/components/config/__init__.py @@ -14,15 +14,23 @@ from homeassistant.util.yaml import load_yaml, dump DOMAIN = 'config' DEPENDENCIES = ['http'] SECTIONS = ('core', 'customize', 'group', 'hassbian', 'automation', 'script') -ON_DEMAND = ('zwave') +ON_DEMAND = ('zwave',) +FEATURE_FLAGS = ('hidden_entries',) @asyncio.coroutine def async_setup(hass, config): """Set up the config component.""" + global SECTIONS + yield from hass.components.frontend.async_register_built_in_panel( 'config', 'config', 'mdi:settings') + # Temporary way of allowing people to opt-in for unreleased config sections + for key, value in config.get(DOMAIN, {}).items(): + if key in FEATURE_FLAGS and value: + SECTIONS += (key,) + @asyncio.coroutine def setup_panel(panel_name): """Set up a panel.""" diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py new file mode 100644 index 00000000000..d33e97b9e88 --- /dev/null +++ b/homeassistant/components/config/config_entries.py @@ -0,0 +1,182 @@ +"""Http views to control the config manager.""" +import asyncio + +import voluptuous as vol + +from homeassistant import config_entries +from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http.data_validator import RequestDataValidator + + +REQUIREMENTS = ['voluptuous-serialize==0.1'] + + +@asyncio.coroutine +def async_setup(hass): + """Enable the Home Assistant views.""" + hass.http.register_view(ConfigManagerEntryIndexView) + hass.http.register_view(ConfigManagerEntryResourceView) + hass.http.register_view(ConfigManagerFlowIndexView) + hass.http.register_view(ConfigManagerFlowResourceView) + hass.http.register_view(ConfigManagerAvailableFlowView) + return True + + +def _prepare_json(result): + """Convert result for JSON.""" + if result['type'] != config_entries.RESULT_TYPE_FORM: + return result + + import voluptuous_serialize + + data = result.copy() + + schema = data['data_schema'] + if schema is None: + data['data_schema'] = [] + else: + data['data_schema'] = voluptuous_serialize.convert(schema) + + return data + + +class ConfigManagerEntryIndexView(HomeAssistantView): + """View to get available config entries.""" + + url = '/api/config/config_entries/entry' + name = 'api:config:config_entries:entry' + + @asyncio.coroutine + def get(self, request): + """List flows in progress.""" + hass = request.app['hass'] + return self.json([{ + 'entry_id': entry.entry_id, + 'domain': entry.domain, + 'title': entry.title, + 'source': entry.source, + 'state': entry.state, + } for entry in hass.config_entries.async_entries()]) + + +class ConfigManagerEntryResourceView(HomeAssistantView): + """View to interact with a config entry.""" + + url = '/api/config/config_entries/entry/{entry_id}' + name = 'api:config:config_entries:entry:resource' + + @asyncio.coroutine + def delete(self, request, entry_id): + """Delete a config entry.""" + hass = request.app['hass'] + + try: + result = yield from hass.config_entries.async_remove(entry_id) + except config_entries.UnknownEntry: + return self.json_message('Invalid entry specified', 404) + + return self.json(result) + + +class ConfigManagerFlowIndexView(HomeAssistantView): + """View to create config flows.""" + + url = '/api/config/config_entries/flow' + name = 'api:config:config_entries:flow' + + @asyncio.coroutine + def get(self, request): + """List flows that are in progress but not started by a user. + + Example of a non-user initiated flow is a discovered Hue hub that + requires user interaction to finish setup. + """ + hass = request.app['hass'] + + return self.json([ + flow for flow in hass.config_entries.flow.async_progress() + if flow['source'] != config_entries.SOURCE_USER]) + + @asyncio.coroutine + @RequestDataValidator(vol.Schema({ + vol.Required('domain'): str, + })) + def post(self, request, data): + """Handle a POST request.""" + hass = request.app['hass'] + + try: + result = yield from hass.config_entries.flow.async_init( + data['domain']) + except config_entries.UnknownHandler: + return self.json_message('Invalid handler specified', 404) + except config_entries.UnknownStep: + return self.json_message('Handler does not support init', 400) + + result = _prepare_json(result) + + return self.json(result) + + +class ConfigManagerFlowResourceView(HomeAssistantView): + """View to interact with the flow manager.""" + + url = '/api/config/config_entries/flow/{flow_id}' + name = 'api:config:config_entries:flow:resource' + + @asyncio.coroutine + def get(self, request, flow_id): + """Get the current state of a flow.""" + hass = request.app['hass'] + + try: + result = yield from hass.config_entries.flow.async_configure( + flow_id) + except config_entries.UnknownFlow: + return self.json_message('Invalid flow specified', 404) + + result = _prepare_json(result) + + return self.json(result) + + @asyncio.coroutine + @RequestDataValidator(vol.Schema(dict), allow_empty=True) + def post(self, request, flow_id, data): + """Handle a POST request.""" + hass = request.app['hass'] + + try: + result = yield from hass.config_entries.flow.async_configure( + flow_id, data) + except config_entries.UnknownFlow: + return self.json_message('Invalid flow specified', 404) + except vol.Invalid: + return self.json_message('User input malformed', 400) + + result = _prepare_json(result) + + return self.json(result) + + @asyncio.coroutine + def delete(self, request, flow_id): + """Cancel a flow in progress.""" + hass = request.app['hass'] + + try: + hass.config_entries.async_abort(flow_id) + except config_entries.UnknownFlow: + return self.json_message('Invalid flow specified', 404) + + return self.json_message('Flow aborted') + + +class ConfigManagerAvailableFlowView(HomeAssistantView): + """View to query available flows.""" + + url = '/api/config/config_entries/flow_handlers' + name = 'api:config:config_entries:flow_handlers' + + @asyncio.coroutine + def get(self, request): + """List available flow handlers.""" + return self.json(config_entries.FLOWS) diff --git a/homeassistant/components/config_entry_example.py b/homeassistant/components/config_entry_example.py new file mode 100644 index 00000000000..2d5ea728ff3 --- /dev/null +++ b/homeassistant/components/config_entry_example.py @@ -0,0 +1,102 @@ +"""Example component to show how config entries work.""" + +import asyncio + +import voluptuous as vol + +from homeassistant import config_entries +from homeassistant.const import ATTR_FRIENDLY_NAME +from homeassistant.util import slugify + + +DOMAIN = 'config_entry_example' + + +@asyncio.coroutine +def async_setup(hass, config): + """Setup for our example component.""" + return True + + +@asyncio.coroutine +def async_setup_entry(hass, entry): + """Initialize an entry.""" + entity_id = '{}.{}'.format(DOMAIN, entry.data['object_id']) + hass.states.async_set(entity_id, 'loaded', { + ATTR_FRIENDLY_NAME: entry.data['name'] + }) + + # Indicate setup was successful. + return True + + +@asyncio.coroutine +def async_unload_entry(hass, entry): + """Unload an entry.""" + entity_id = '{}.{}'.format(DOMAIN, entry.data['object_id']) + hass.states.async_remove(entity_id) + + # Indicate unload was successful. + return True + + +@config_entries.HANDLERS.register(DOMAIN) +class ExampleConfigFlow(config_entries.ConfigFlowHandler): + """Handle an example configuration flow.""" + + VERSION = 1 + + def __init__(self): + """Initialize a Hue config handler.""" + self.object_id = None + + @asyncio.coroutine + def async_step_init(self, user_input=None): + """Start config flow.""" + errors = None + if user_input is not None: + object_id = user_input['object_id'] + + if object_id != '' and object_id == slugify(object_id): + self.object_id = user_input['object_id'] + return (yield from self.async_step_name()) + + errors = { + 'object_id': 'Invalid object id.' + } + + return self.async_show_form( + title='Pick object id', + step_id='init', + description="Please enter an object_id for the test entity.", + data_schema=vol.Schema({ + 'object_id': str + }), + errors=errors + ) + + @asyncio.coroutine + def async_step_name(self, user_input=None): + """Ask user to enter the name.""" + errors = None + if user_input is not None: + name = user_input['name'] + + if name != '': + return self.async_create_entry( + title=name, + data={ + 'name': name, + 'object_id': self.object_id, + } + ) + + return self.async_show_form( + title='Name of the entity', + step_id='name', + description="Please enter a name for the test entity.", + data_schema=vol.Schema({ + 'name': str + }), + errors=errors + ) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py new file mode 100644 index 00000000000..7b5d23d284f --- /dev/null +++ b/homeassistant/config_entries.py @@ -0,0 +1,516 @@ +"""The Config Manager is responsible for managing configuration for components. + +The Config Manager allows for creating config entries to be consumed by +components. Each entry is created via a Config Flow Handler, as defined by each +component. + +During startup, Home Assistant will setup the entries during the normal setup +of a component. It will first call the normal setup and then call the method +`async_setup_entry(hass, entry)` for each entry. The same method is called when +Home Assistant is running while a config entry is created. + +## Config Flows + +A component needs to define a Config Handler to allow the user to create config +entries for that component. A config flow will manage the creation of entries +from user input, discovery or other sources (like hassio). + +When a config flow is started for a domain, the handler will be instantiated +and receives a unique id. The instance of this handler will be reused for every +interaction of the user with this flow. This makes it possible to store +instance variables on the handler. + +Before instantiating the handler, Home Assistant will make sure to load all +dependencies and install the requirements of the component. + +At a minimum, each config flow will have to define a version number and the +'init' step. + + @config_entries.HANDLERS.register(DOMAIN) + class ExampleConfigFlow(config_entries.ConfigFlowHandler): + + VERSION = 1 + + async def async_step_init(self, user_input=None): + … + +The 'init' step is the first step of a flow and is called when a user +starts a new flow. Each step has three different possible results: "Show Form", +"Abort" and "Create Entry". + +### Show Form + +This will show a form to the user to fill in. You define the current step, +a title, a description and the schema of the data that needs to be returned. + + async def async_step_init(self, user_input=None): + # Use OrderedDict to guarantee order of the form shown to the user + data_schema = OrderedDict() + data_schema[vol.Required('username')] = str + data_schema[vol.Required('password')] = str + + return self.async_show_form( + step_id='init', + title='Account Info', + data_schema=vol.Schema(data_schema) + ) + +After the user has filled in the form, the step method will be called again and +the user input is passed in. If the validation of the user input fails , you +can return a dictionary with errors. Each key in the dictionary refers to a +field name that contains the error. Use the key 'base' if you want to show a +generic error. + + async def async_step_init(self, user_input=None): + errors = None + if user_input is not None: + # Validate user input + if valid: + return self.create_entry(…) + + errors['base'] = 'Unable to reach authentication server.' + + return self.async_show_form(…) + +If the user input passes validation, you can again return one of the three +return values. If you want to navigate the user to the next step, return the +return value of that step: + + return (await self.async_step_account()) + +### Abort + +When the result is "Abort", a message will be shown to the user and the +configuration flow is finished. + + return self.async_abort( + reason='This device is not supported by Home Assistant.' + ) + +### Create Entry + +When the result is "Create Entry", an entry will be created and stored in Home +Assistant, a success message is shown to the user and the flow is finished. + +## Initializing a config flow from an external source + +You might want to initialize a config flow programmatically. For example, if +we discover a device on the network that requires user interaction to finish +setup. To do so, pass a source parameter and optional user input to the init +step: + + await hass.config_entries.flow.async_init( + 'hue', source='discovery', data=discovery_info) + +The config flow handler will need to add a step to support the source. The step +should follow the same return values as a normal step. + + async def async_step_discovery(info): + +If the result of the step is to show a form, the user will be able to continue +the flow from the config panel. +""" +import asyncio +import logging +import os +import uuid + +from .core import callback +from .exceptions import HomeAssistantError +from .setup import async_setup_component, async_process_deps_reqs +from .util.json import load_json, save_json +from .util.decorator import Registry + + +_LOGGER = logging.getLogger(__name__) +HANDLERS = Registry() +# Components that have config flows. In future we will auto-generate this list. +FLOWS = [ + 'config_entry_example' +] + +SOURCE_USER = 'user' +SOURCE_DISCOVERY = 'discovery' + +PATH_CONFIG = '.config_entries.json' + +SAVE_DELAY = 1 + +RESULT_TYPE_FORM = 'form' +RESULT_TYPE_CREATE_ENTRY = 'create_entry' +RESULT_TYPE_ABORT = 'abort' + +ENTRY_STATE_LOADED = 'loaded' +ENTRY_STATE_SETUP_ERROR = 'setup_error' +ENTRY_STATE_NOT_LOADED = 'not_loaded' +ENTRY_STATE_FAILED_UNLOAD = 'failed_unload' + + +class ConfigEntry: + """Hold a configuration entry.""" + + __slots__ = ('entry_id', 'version', 'domain', 'title', 'data', 'source', + 'state') + + def __init__(self, version, domain, title, data, source, entry_id=None, + state=ENTRY_STATE_NOT_LOADED): + """Initialize a config entry.""" + # Unique id of the config entry + self.entry_id = entry_id or uuid.uuid4().hex + + # Version of the configuration. + self.version = version + + # Domain the configuration belongs to + self.domain = domain + + # Title of the configuration + self.title = title + + # Config data + self.data = data + + # Source of the configuration (user, discovery, cloud) + self.source = source + + # State of the entry (LOADED, NOT_LOADED) + self.state = state + + @asyncio.coroutine + def async_setup(self, hass, *, component=None): + """Set up an entry.""" + if component is None: + component = getattr(hass.components, self.domain) + + try: + result = yield from component.async_setup_entry(hass, self) + + if not isinstance(result, bool): + _LOGGER.error('%s.async_config_entry did not return boolean', + self.domain) + result = False + except Exception: # pylint: disable=broad-except + _LOGGER.exception('Error setting up entry %s for %s', + self.title, self.domain) + result = False + + if result: + self.state = ENTRY_STATE_LOADED + else: + self.state = ENTRY_STATE_SETUP_ERROR + + @asyncio.coroutine + def async_unload(self, hass): + """Unload an entry. + + Returns if unload is possible and was successful. + """ + component = getattr(hass.components, self.domain) + + supports_unload = hasattr(component, 'async_unload_entry') + + if not supports_unload: + return False + + try: + result = yield from component.async_unload_entry(hass, self) + + if not isinstance(result, bool): + _LOGGER.error('%s.async_unload_entry did not return boolean', + self.domain) + result = False + + return result + except Exception: # pylint: disable=broad-except + _LOGGER.exception('Error unloading entry %s for %s', + self.title, self.domain) + self.state = ENTRY_STATE_FAILED_UNLOAD + return False + + def as_dict(self): + """Return dictionary version of this entry.""" + return { + 'entry_id': self.entry_id, + 'version': self.version, + 'domain': self.domain, + 'title': self.title, + 'data': self.data, + 'source': self.source, + } + + +class ConfigError(HomeAssistantError): + """Error while configuring an account.""" + + +class UnknownEntry(ConfigError): + """Unknown entry specified.""" + + +class UnknownHandler(ConfigError): + """Unknown handler specified.""" + + +class UnknownFlow(ConfigError): + """Uknown flow specified.""" + + +class UnknownStep(ConfigError): + """Unknown step specified.""" + + +class ConfigEntries: + """Manage the configuration entries. + + An instance of this object is available via `hass.config_entries`. + """ + + def __init__(self, hass, hass_config): + """Initialize the entry manager.""" + self.hass = hass + self.flow = FlowManager(hass, hass_config, self._async_add_entry) + self._hass_config = hass_config + self._entries = None + self._sched_save = None + + @callback + def async_domains(self): + """Return domains for which we have entries.""" + seen = set() + result = [] + + for entry in self._entries: + if entry.domain not in seen: + seen.add(entry.domain) + result.append(entry.domain) + + return result + + @callback + def async_entries(self, domain=None): + """Return all entries or entries for a specific domain.""" + if domain is None: + return list(self._entries) + return [entry for entry in self._entries if entry.domain == domain] + + @asyncio.coroutine + def async_remove(self, entry_id): + """Remove an entry.""" + found = None + for index, entry in enumerate(self._entries): + if entry.entry_id == entry_id: + found = index + break + + if found is None: + raise UnknownEntry + + entry = self._entries.pop(found) + self._async_schedule_save() + + unloaded = yield from entry.async_unload(self.hass) + + return { + 'require_restart': not unloaded + } + + @asyncio.coroutine + def async_load(self): + """Load the config.""" + path = self.hass.config.path(PATH_CONFIG) + if not os.path.isfile(path): + self._entries = [] + return + + entries = yield from self.hass.async_add_job(load_json, path) + self._entries = [ConfigEntry(**entry) for entry in entries] + + @asyncio.coroutine + def _async_add_entry(self, entry): + """Add an entry.""" + self._entries.append(entry) + self._async_schedule_save() + + # Setup entry + if entry.domain in self.hass.config.components: + # Component already set up, just need to call setup_entry + yield from entry.async_setup(self.hass) + else: + # Setting up component will also load the entries + yield from async_setup_component( + self.hass, entry.domain, self._hass_config) + + @callback + def _async_schedule_save(self): + """Schedule saving the entity registry.""" + if self._sched_save is not None: + self._sched_save.cancel() + + self._sched_save = self.hass.loop.call_later( + SAVE_DELAY, self.hass.async_add_job, self._async_save + ) + + @asyncio.coroutine + def _async_save(self): + """Save the entity registry to a file.""" + self._sched_save = None + data = [entry.as_dict() for entry in self._entries] + + yield from self.hass.async_add_job( + save_json, self.hass.config.path(PATH_CONFIG), data) + + +class FlowManager: + """Manage all the config flows that are in progress.""" + + def __init__(self, hass, hass_config, async_add_entry): + """Initialize the flow manager.""" + self.hass = hass + self._hass_config = hass_config + self._progress = {} + self._async_add_entry = async_add_entry + + @callback + def async_progress(self): + """Return the flows in progress.""" + return [{ + 'flow_id': flow.flow_id, + 'domain': flow.domain, + 'source': flow.source, + } for flow in self._progress.values()] + + @asyncio.coroutine + def async_init(self, domain, *, source=SOURCE_USER, data=None): + """Start a configuration flow.""" + handler = HANDLERS.get(domain) + + if handler is None: + # This will load the component and thus register the handler + component = getattr(self.hass.components, domain) + handler = HANDLERS.get(domain) + + if handler is None: + raise self.hass.helpers.UnknownHandler + + # Make sure requirements and dependencies of component are resolved + yield from async_process_deps_reqs( + self.hass, self._hass_config, domain, component) + + flow_id = uuid.uuid4().hex + flow = self._progress[flow_id] = handler() + flow.hass = self.hass + flow.domain = domain + flow.flow_id = flow_id + flow.source = source + + if source == SOURCE_USER: + step = 'init' + else: + step = source + + return (yield from self._async_handle_step(flow, step, data)) + + @asyncio.coroutine + def async_configure(self, flow_id, user_input=None): + """Start or continue a configuration flow.""" + flow = self._progress.get(flow_id) + + if flow is None: + raise UnknownFlow + + step_id, data_schema = flow.cur_step + + if data_schema is not None and user_input is not None: + user_input = data_schema(user_input) + + return (yield from self._async_handle_step( + flow, step_id, user_input)) + + @callback + def async_abort(self, flow_id): + """Abort a flow.""" + if self._progress.pop(flow_id, None) is None: + raise UnknownFlow + + @asyncio.coroutine + def _async_handle_step(self, flow, step_id, user_input): + """Handle a step of a flow.""" + method = "async_step_{}".format(step_id) + + if not hasattr(flow, method): + self._progress.pop(flow.flow_id) + raise UnknownStep("Handler {} doesn't support step {}".format( + flow.__class__.__name__, step_id)) + + result = yield from getattr(flow, method)(user_input) + + if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY, + RESULT_TYPE_ABORT): + raise ValueError( + 'Handler returned incorrect type: {}'.format(result['type'])) + + if result['type'] == RESULT_TYPE_FORM: + flow.cur_step = (result.pop('step_id'), result['data_schema']) + return result + + # Abort and Success results both finish the flow + self._progress.pop(flow.flow_id) + + if result['type'] == RESULT_TYPE_ABORT: + return result + + entry = ConfigEntry( + version=flow.VERSION, + domain=flow.domain, + title=result['title'], + data=result.pop('data'), + source=flow.source + ) + yield from self._async_add_entry(entry) + return result + + +class ConfigFlowHandler: + """Handle the configuration flow of a component.""" + + # Set by flow manager + flow_id = None + hass = None + source = SOURCE_USER + cur_step = None + + # Set by dev + # VERSION + + @callback + def async_show_form(self, *, title, step_id, description=None, + data_schema=None, errors=None): + """Return the definition of a form to gather user input.""" + return { + 'type': RESULT_TYPE_FORM, + 'flow_id': self.flow_id, + 'title': title, + 'step_id': step_id, + 'description': description, + 'data_schema': data_schema, + 'errors': errors, + } + + @callback + def async_create_entry(self, *, title, data): + """Finish config flow and create a config entry.""" + return { + 'type': RESULT_TYPE_CREATE_ENTRY, + 'flow_id': self.flow_id, + 'title': title, + 'data': data, + } + + @callback + def async_abort(self, *, reason): + """Abort the config flow.""" + return { + 'type': RESULT_TYPE_ABORT, + 'flow_id': self.flow_id, + 'reason': reason + } diff --git a/homeassistant/setup.py b/homeassistant/setup.py index 2c69fdefeee..5a8681e82fd 100644 --- a/homeassistant/setup.py +++ b/homeassistant/setup.py @@ -123,7 +123,7 @@ def _async_setup_component(hass: core.HomeAssistant, return False try: - yield from _process_deps_reqs(hass, config, domain, component) + yield from async_process_deps_reqs(hass, config, domain, component) except HomeAssistantError as err: log_error(str(err)) return False @@ -165,6 +165,9 @@ def _async_setup_component(hass: core.HomeAssistant, loader.set_component(domain, None) return False + for entry in hass.config_entries.async_entries(domain): + yield from entry.async_setup(hass, component=component) + hass.config.components.add(component.DOMAIN) # Cleanup @@ -206,7 +209,8 @@ def async_prepare_setup_platform(hass: core.HomeAssistant, config, domain: str, return platform try: - yield from _process_deps_reqs(hass, config, platform_path, platform) + yield from async_process_deps_reqs( + hass, config, platform_path, platform) except HomeAssistantError as err: log_error(str(err)) return None @@ -215,7 +219,7 @@ def async_prepare_setup_platform(hass: core.HomeAssistant, config, domain: str, @asyncio.coroutine -def _process_deps_reqs(hass, config, name, module): +def async_process_deps_reqs(hass, config, name, module): """Process all dependencies and requirements for a module. Module is a Python module of either a component or platform. diff --git a/requirements_all.txt b/requirements_all.txt index a77021d2297..24156b517a8 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1198,6 +1198,9 @@ uvcclient==0.10.1 # homeassistant.components.climate.venstar venstarcolortouch==0.6 +# homeassistant.components.config.config_entries +voluptuous-serialize==0.1 + # homeassistant.components.volvooncall volvooncall==0.4.0 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index ecde5a5fc9e..4155fea78be 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -175,6 +175,9 @@ statsd==3.2.1 # homeassistant.components.camera.uvc uvcclient==0.10.1 +# homeassistant.components.config.config_entries +voluptuous-serialize==0.1 + # homeassistant.components.vultr vultr==0.1.2 diff --git a/script/gen_requirements_all.py b/script/gen_requirements_all.py index 9c510d8339e..42acee96206 100755 --- a/script/gen_requirements_all.py +++ b/script/gen_requirements_all.py @@ -82,6 +82,7 @@ TEST_REQUIREMENTS = ( 'sqlalchemy', 'statsd', 'uvcclient', + 'voluptuous-serialize', 'warrant', 'yahoo-finance', 'pythonwhois', diff --git a/tests/common.py b/tests/common.py index 1b79d15b319..6fee7b1bec0 100644 --- a/tests/common.py +++ b/tests/common.py @@ -9,7 +9,7 @@ import logging import threading from contextlib import contextmanager -from homeassistant import core as ha, loader +from homeassistant import core as ha, loader, config_entries from homeassistant.setup import setup_component, async_setup_component from homeassistant.config import async_process_component_config from homeassistant.helpers import ( @@ -109,6 +109,9 @@ def get_test_home_assistant(): def async_test_home_assistant(loop): """Return a Home Assistant object pointing at test config dir.""" hass = ha.HomeAssistant(loop) + hass.config_entries = config_entries.ConfigEntries(hass, {}) + hass.config_entries._entries = [] + hass.config.async_load = Mock() INSTANCES.append(hass) orig_async_add_job = hass.async_add_job @@ -305,7 +308,8 @@ class MockModule(object): # pylint: disable=invalid-name def __init__(self, domain=None, dependencies=None, setup=None, requirements=None, config_schema=None, platform_schema=None, - async_setup=None): + async_setup=None, async_setup_entry=None, + async_unload_entry=None): """Initialize the mock module.""" self.DOMAIN = domain self.DEPENDENCIES = dependencies or [] @@ -327,6 +331,12 @@ class MockModule(object): if setup is None and async_setup is None: self.async_setup = mock_coro_func(True) + if async_setup_entry is not None: + self.async_setup_entry = async_setup_entry + + if async_unload_entry is not None: + self.async_unload_entry = async_unload_entry + class MockPlatform(object): """Provide a fake platform.""" @@ -402,6 +412,35 @@ class MockToggleDevice(entity.ToggleEntity): return None +class MockConfigEntry(config_entries.ConfigEntry): + """Helper for creating config entries that adds some defaults.""" + + def __init__(self, *, domain='test', data=None, version=0, entry_id=None, + source=config_entries.SOURCE_USER, title='Mock Title', + state=None): + """Initialize a mock config entry.""" + kwargs = { + 'entry_id': entry_id or 'mock-id', + 'domain': domain, + 'data': data or {}, + 'version': version, + 'title': title + } + if source is not None: + kwargs['source'] = source + if state is not None: + kwargs['state'] = state + super().__init__(**kwargs) + + def add_to_hass(self, hass): + """Test helper to add entry to hass.""" + hass.config_entries._entries.append(self) + + def add_to_manager(self, manager): + """Test helper to add entry to entry manager.""" + manager._entries.append(self) + + def patch_yaml_files(files_dict, endswith=True): """Patch load_yaml with a dictionary of yaml files.""" # match using endswith, start search with longest string diff --git a/tests/components/binary_sensor/test_command_line.py b/tests/components/binary_sensor/test_command_line.py index f35e6f08452..d01b62e4c12 100644 --- a/tests/components/binary_sensor/test_command_line.py +++ b/tests/components/binary_sensor/test_command_line.py @@ -3,7 +3,6 @@ import unittest from homeassistant.const import (STATE_ON, STATE_OFF) from homeassistant.components.binary_sensor import command_line -from homeassistant import setup from homeassistant.helpers import template from tests.common import get_test_home_assistant @@ -42,16 +41,6 @@ class TestCommandSensorBinarySensor(unittest.TestCase): self.assertEqual('Test', entity.name) self.assertEqual(STATE_ON, entity.state) - def test_setup_bad_config(self): - """Test the setup with a bad configuration.""" - config = {'name': 'test', - 'platform': 'not_command_line', - } - - self.assertFalse(setup.setup_component(self.hass, 'test', { - 'command_line': config, - })) - def test_template(self): """Test setting the state with a template.""" data = command_line.CommandSensorData(self.hass, 'echo 10') diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py new file mode 100644 index 00000000000..1551ba74319 --- /dev/null +++ b/tests/components/config/test_config_entries.py @@ -0,0 +1,317 @@ +"""Test config entries API.""" + +import asyncio +from collections import OrderedDict +from unittest.mock import patch + +import pytest +import voluptuous as vol + +from homeassistant import config_entries as core_ce +from homeassistant.config_entries import ConfigFlowHandler, HANDLERS +from homeassistant.setup import async_setup_component +from homeassistant.components.config import config_entries +from homeassistant.loader import set_component + +from tests.common import MockConfigEntry, MockModule, mock_coro_func + + +@pytest.fixture +def client(hass, test_client): + """Fixture that can interact with the config manager API.""" + hass.loop.run_until_complete(async_setup_component(hass, 'http', {})) + hass.loop.run_until_complete(config_entries.async_setup(hass)) + yield hass.loop.run_until_complete(test_client(hass.http.app)) + + +@asyncio.coroutine +def test_get_entries(hass, client): + """Test get entries.""" + MockConfigEntry( + domain='comp', + title='Test 1', + source='bla' + ).add_to_hass(hass) + MockConfigEntry( + domain='comp2', + title='Test 2', + source='bla2', + state=core_ce.ENTRY_STATE_LOADED, + ).add_to_hass(hass) + resp = yield from client.get('/api/config/config_entries/entry') + assert resp.status == 200 + data = yield from resp.json() + for entry in data: + entry.pop('entry_id') + assert data == [ + { + 'domain': 'comp', + 'title': 'Test 1', + 'source': 'bla', + 'state': 'not_loaded' + }, + { + 'domain': 'comp2', + 'title': 'Test 2', + 'source': 'bla2', + 'state': 'loaded', + }, + ] + + +@asyncio.coroutine +def test_remove_entry(hass, client): + """Test removing an entry via the API.""" + entry = MockConfigEntry(domain='demo') + entry.add_to_hass(hass) + resp = yield from client.delete( + '/api/config/config_entries/entry/{}'.format(entry.entry_id)) + assert resp.status == 200 + data = yield from resp.json() + assert data == { + 'require_restart': True + } + assert len(hass.config_entries.async_entries()) == 0 + + +@asyncio.coroutine +def test_available_flows(hass, client): + """Test querying the available flows.""" + with patch.object(core_ce, 'FLOWS', ['hello', 'world']): + resp = yield from client.get( + '/api/config/config_entries/flow_handlers') + assert resp.status == 200 + data = yield from resp.json() + assert data == ['hello', 'world'] + + +############################ +# FLOW MANAGER API TESTS # +############################ + + +@asyncio.coroutine +def test_initialize_flow(hass, client): + """Test we can initialize a flow.""" + class TestFlow(ConfigFlowHandler): + @asyncio.coroutine + def async_step_init(self, user_input=None): + schema = OrderedDict() + schema[vol.Required('username')] = str + schema[vol.Required('password')] = str + + return self.async_show_form( + title='test-title', + step_id='init', + description='test-description', + data_schema=schema, + errors={ + 'username': 'Should be unique.' + } + ) + + with patch.dict(HANDLERS, {'test': TestFlow}): + resp = yield from client.post('/api/config/config_entries/flow', + json={'domain': 'test'}) + + assert resp.status == 200 + data = yield from resp.json() + + data.pop('flow_id') + + assert data == { + 'type': 'form', + 'title': 'test-title', + 'description': 'test-description', + 'data_schema': [ + { + 'name': 'username', + 'required': True, + 'type': 'string' + }, + { + 'name': 'password', + 'required': True, + 'type': 'string' + } + ], + 'errors': { + 'username': 'Should be unique.' + } + } + + +@asyncio.coroutine +def test_abort(hass, client): + """Test a flow that aborts.""" + class TestFlow(ConfigFlowHandler): + @asyncio.coroutine + def async_step_init(self, user_input=None): + return self.async_abort(reason='bla') + + with patch.dict(HANDLERS, {'test': TestFlow}): + resp = yield from client.post('/api/config/config_entries/flow', + json={'domain': 'test'}) + + assert resp.status == 200 + data = yield from resp.json() + data.pop('flow_id') + assert data == { + 'reason': 'bla', + 'type': 'abort' + } + + +@asyncio.coroutine +def test_create_account(hass, client): + """Test a flow that creates an account.""" + set_component( + 'test', MockModule('test', async_setup_entry=mock_coro_func(True))) + + class TestFlow(ConfigFlowHandler): + VERSION = 1 + + @asyncio.coroutine + def async_step_init(self, user_input=None): + return self.async_create_entry( + title='Test Entry', + data={'secret': 'account_token'} + ) + + with patch.dict(HANDLERS, {'test': TestFlow}): + resp = yield from client.post('/api/config/config_entries/flow', + json={'domain': 'test'}) + + assert resp.status == 200 + data = yield from resp.json() + data.pop('flow_id') + assert data == { + 'title': 'Test Entry', + 'type': 'create_entry' + } + + +@asyncio.coroutine +def test_two_step_flow(hass, client): + """Test we can finish a two step flow.""" + set_component( + 'test', MockModule('test', async_setup_entry=mock_coro_func(True))) + + class TestFlow(ConfigFlowHandler): + VERSION = 1 + + @asyncio.coroutine + def async_step_init(self, user_input=None): + return self.async_show_form( + title='test-title', + step_id='account', + data_schema=vol.Schema({ + 'user_title': str + })) + + @asyncio.coroutine + def async_step_account(self, user_input=None): + return self.async_create_entry( + title=user_input['user_title'], + data={'secret': 'account_token'} + ) + + with patch.dict(HANDLERS, {'test': TestFlow}): + resp = yield from client.post('/api/config/config_entries/flow', + json={'domain': 'test'}) + assert resp.status == 200 + data = yield from resp.json() + flow_id = data.pop('flow_id') + assert data == { + 'type': 'form', + 'title': 'test-title', + 'description': None, + 'data_schema': [ + { + 'name': 'user_title', + 'type': 'string' + } + ], + 'errors': None + } + + with patch.dict(HANDLERS, {'test': TestFlow}): + resp = yield from client.post( + '/api/config/config_entries/flow/{}'.format(flow_id), + json={'user_title': 'user-title'}) + assert resp.status == 200 + data = yield from resp.json() + data.pop('flow_id') + assert data == { + 'type': 'create_entry', + 'title': 'user-title', + } + + +@asyncio.coroutine +def test_get_progress_index(hass, client): + """Test querying for the flows that are in progress.""" + class TestFlow(ConfigFlowHandler): + VERSION = 5 + + @asyncio.coroutine + def async_step_hassio(self, info): + return (yield from self.async_step_account()) + + @asyncio.coroutine + def async_step_account(self, user_input=None): + return self.async_show_form( + step_id='account', + title='Finish setup' + ) + + with patch.dict(HANDLERS, {'test': TestFlow}): + form = yield from hass.config_entries.flow.async_init( + 'test', source='hassio') + + resp = yield from client.get('/api/config/config_entries/flow') + assert resp.status == 200 + data = yield from resp.json() + assert data == [ + { + 'flow_id': form['flow_id'], + 'domain': 'test', + 'source': 'hassio' + } + ] + + +@asyncio.coroutine +def test_get_progress_flow(hass, client): + """Test we can query the API for same result as we get from init a flow.""" + class TestFlow(ConfigFlowHandler): + @asyncio.coroutine + def async_step_init(self, user_input=None): + schema = OrderedDict() + schema[vol.Required('username')] = str + schema[vol.Required('password')] = str + + return self.async_show_form( + title='test-title', + step_id='init', + description='test-description', + data_schema=schema, + errors={ + 'username': 'Should be unique.' + } + ) + + with patch.dict(HANDLERS, {'test': TestFlow}): + resp = yield from client.post('/api/config/config_entries/flow', + json={'domain': 'test'}) + + assert resp.status == 200 + data = yield from resp.json() + + resp2 = yield from client.get( + '/api/config/config_entries/flow/{}'.format(data['flow_id'])) + + assert resp2.status == 200 + data2 = yield from resp2.json() + + assert data == data2 diff --git a/tests/components/sensor/test_command_line.py b/tests/components/sensor/test_command_line.py index 6eb97b41e11..bc073a04c47 100644 --- a/tests/components/sensor/test_command_line.py +++ b/tests/components/sensor/test_command_line.py @@ -3,7 +3,6 @@ import unittest from homeassistant.helpers.template import Template from homeassistant.components.sensor import command_line -from homeassistant import setup from tests.common import get_test_home_assistant @@ -40,16 +39,6 @@ class TestCommandSensorSensor(unittest.TestCase): self.assertEqual('in', entity.unit_of_measurement) self.assertEqual('5', entity.state) - def test_setup_bad_config(self): - """Test setup with a bad configuration.""" - config = {'name': 'test', - 'platform': 'not_command_line', - } - - self.assertFalse(setup.setup_component(self.hass, 'test', { - 'command_line': config, - })) - def test_template(self): """Test command sensor with template.""" data = command_line.CommandSensorData(self.hass, 'echo 50') diff --git a/tests/components/test_config_entry_example.py b/tests/components/test_config_entry_example.py new file mode 100644 index 00000000000..31084384c31 --- /dev/null +++ b/tests/components/test_config_entry_example.py @@ -0,0 +1,38 @@ +"""Test the config entry example component.""" +import asyncio + +from homeassistant import config_entries + + +@asyncio.coroutine +def test_flow_works(hass): + """Test that the config flow works.""" + result = yield from hass.config_entries.flow.async_init( + 'config_entry_example') + + assert result['type'] == config_entries.RESULT_TYPE_FORM + + result = yield from hass.config_entries.flow.async_configure( + result['flow_id'], { + 'object_id': 'bla' + }) + + assert result['type'] == config_entries.RESULT_TYPE_FORM + + result = yield from hass.config_entries.flow.async_configure( + result['flow_id'], { + 'name': 'Hello' + }) + + assert result['type'] == config_entries.RESULT_TYPE_CREATE_ENTRY + state = hass.states.get('config_entry_example.bla') + assert state is not None + assert state.name == 'Hello' + assert 'config_entry_example' in hass.config.components + assert len(hass.config_entries.async_entries()) == 1 + + # Test removing entry. + entry = hass.config_entries.async_entries()[0] + yield from hass.config_entries.async_remove(entry.entry_id) + state = hass.states.get('config_entry_example.bla') + assert state is None diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py new file mode 100644 index 00000000000..3a1fe1d9d3e --- /dev/null +++ b/tests/test_config_entries.py @@ -0,0 +1,397 @@ +"""Test the config manager.""" +import asyncio +from unittest.mock import MagicMock, patch, mock_open + +import pytest +import voluptuous as vol + +from homeassistant import config_entries, loader +from homeassistant.setup import async_setup_component + +from tests.common import MockModule, mock_coro, MockConfigEntry + + +@pytest.fixture +def manager(hass): + """Fixture of a loaded config manager.""" + manager = config_entries.ConfigEntries(hass, {}) + manager._entries = [] + hass.config_entries = manager + return manager + + +@asyncio.coroutine +def test_call_setup_entry(hass): + """Test we call .setup_entry.""" + MockConfigEntry(domain='comp').add_to_hass(hass) + + mock_setup_entry = MagicMock(return_value=mock_coro(True)) + + loader.set_component( + 'comp', + MockModule('comp', async_setup_entry=mock_setup_entry)) + + result = yield from async_setup_component(hass, 'comp', {}) + assert result + assert len(mock_setup_entry.mock_calls) == 1 + + +@asyncio.coroutine +def test_remove_entry(manager): + """Test that we can remove an entry.""" + mock_unload_entry = MagicMock(return_value=mock_coro(True)) + + loader.set_component( + 'test', + MockModule('comp', async_unload_entry=mock_unload_entry)) + + MockConfigEntry(domain='test', entry_id='test1').add_to_manager(manager) + MockConfigEntry(domain='test', entry_id='test2').add_to_manager(manager) + MockConfigEntry(domain='test', entry_id='test3').add_to_manager(manager) + + assert [item.entry_id for item in manager.async_entries()] == \ + ['test1', 'test2', 'test3'] + + result = yield from manager.async_remove('test2') + + assert result == { + 'require_restart': False + } + assert [item.entry_id for item in manager.async_entries()] == \ + ['test1', 'test3'] + + assert len(mock_unload_entry.mock_calls) == 1 + + +@asyncio.coroutine +def test_remove_entry_raises(manager): + """Test if a component raises while removing entry.""" + @asyncio.coroutine + def mock_unload_entry(hass, entry): + """Mock unload entry function.""" + raise Exception("BROKEN") + + loader.set_component( + 'test', + MockModule('comp', async_unload_entry=mock_unload_entry)) + + MockConfigEntry(domain='test', entry_id='test1').add_to_manager(manager) + MockConfigEntry(domain='test', entry_id='test2').add_to_manager(manager) + MockConfigEntry(domain='test', entry_id='test3').add_to_manager(manager) + + assert [item.entry_id for item in manager.async_entries()] == \ + ['test1', 'test2', 'test3'] + + result = yield from manager.async_remove('test2') + + assert result == { + 'require_restart': True + } + assert [item.entry_id for item in manager.async_entries()] == \ + ['test1', 'test3'] + + +@asyncio.coroutine +def test_add_entry_calls_setup_entry(hass, manager): + """Test we call setup_config_entry.""" + mock_setup_entry = MagicMock(return_value=mock_coro(True)) + + loader.set_component( + 'comp', + MockModule('comp', async_setup_entry=mock_setup_entry)) + + class TestFlow(config_entries.ConfigFlowHandler): + + VERSION = 1 + + @asyncio.coroutine + def async_step_init(self, user_input=None): + return self.async_create_entry( + title='title', + data={ + 'token': 'supersecret' + }) + + with patch.dict(config_entries.HANDLERS, {'comp': TestFlow}): + yield from manager.flow.async_init('comp') + yield from hass.async_block_till_done() + + assert len(mock_setup_entry.mock_calls) == 1 + p_hass, p_entry = mock_setup_entry.mock_calls[0][1] + + assert p_hass is hass + assert p_entry.data == { + 'token': 'supersecret' + } + + +@asyncio.coroutine +def test_entries_gets_entries(manager): + """Test entries are filtered by domain.""" + MockConfigEntry(domain='test').add_to_manager(manager) + entry1 = MockConfigEntry(domain='test2') + entry1.add_to_manager(manager) + entry2 = MockConfigEntry(domain='test2') + entry2.add_to_manager(manager) + + assert manager.async_entries('test2') == [entry1, entry2] + + +@asyncio.coroutine +def test_domains_gets_uniques(manager): + """Test we only return each domain once.""" + MockConfigEntry(domain='test').add_to_manager(manager) + MockConfigEntry(domain='test2').add_to_manager(manager) + MockConfigEntry(domain='test2').add_to_manager(manager) + MockConfigEntry(domain='test').add_to_manager(manager) + MockConfigEntry(domain='test3').add_to_manager(manager) + + assert manager.async_domains() == ['test', 'test2', 'test3'] + + +@asyncio.coroutine +def test_saving_and_loading(hass): + """Test that we're saving and loading correctly.""" + class TestFlow(config_entries.ConfigFlowHandler): + VERSION = 5 + + @asyncio.coroutine + def async_step_init(self, user_input=None): + return self.async_create_entry( + title='Test Title', + data={ + 'token': 'abcd' + } + ) + + with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): + yield from hass.config_entries.flow.async_init('test') + + class Test2Flow(config_entries.ConfigFlowHandler): + VERSION = 3 + + @asyncio.coroutine + def async_step_init(self, user_input=None): + return self.async_create_entry( + title='Test 2 Title', + data={ + 'username': 'bla' + } + ) + + json_path = 'homeassistant.util.json.open' + + with patch('homeassistant.config_entries.HANDLERS.get', + return_value=Test2Flow), \ + patch.object(config_entries, 'SAVE_DELAY', 0): + yield from hass.config_entries.flow.async_init('test') + + with patch(json_path, mock_open(), create=True) as mock_write: + # To trigger the call_later + yield from asyncio.sleep(0, loop=hass.loop) + # To execute the save + yield from hass.async_block_till_done() + + # Mock open calls are: open file, context enter, write, context leave + written = mock_write.mock_calls[2][1][0] + + # Now load written data in new config manager + manager = config_entries.ConfigEntries(hass, {}) + + with patch('os.path.isfile', return_value=True), \ + patch(json_path, mock_open(read_data=written), create=True): + yield from manager.async_load() + + # Ensure same order + for orig, loaded in zip(hass.config_entries.async_entries(), + manager.async_entries()): + assert orig.version == loaded.version + assert orig.domain == loaded.domain + assert orig.title == loaded.title + assert orig.data == loaded.data + assert orig.source == loaded.source + + +####################### +# FLOW MANAGER TESTS # +####################### + +@asyncio.coroutine +def test_configure_reuses_handler_instance(manager): + """Test that we reuse instances.""" + class TestFlow(config_entries.ConfigFlowHandler): + handle_count = 0 + + @asyncio.coroutine + def async_step_init(self, user_input=None): + self.handle_count += 1 + return self.async_show_form( + title=str(self.handle_count), + step_id='init') + + with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): + form = yield from manager.flow.async_init('test') + assert form['title'] == '1' + form = yield from manager.flow.async_configure(form['flow_id']) + assert form['title'] == '2' + assert len(manager.flow.async_progress()) == 1 + assert len(manager.async_entries()) == 0 + + +@asyncio.coroutine +def test_configure_two_steps(manager): + """Test that we reuse instances.""" + class TestFlow(config_entries.ConfigFlowHandler): + VERSION = 1 + + @asyncio.coroutine + def async_step_init(self, user_input=None): + if user_input is not None: + self.init_data = user_input + return self.async_step_second() + return self.async_show_form( + title='title', + step_id='init', + data_schema=vol.Schema([str]) + ) + + @asyncio.coroutine + def async_step_second(self, user_input=None): + if user_input is not None: + return self.async_create_entry( + title='Test Entry', + data=self.init_data + user_input + ) + return self.async_show_form( + title='title', + step_id='second', + data_schema=vol.Schema([str]) + ) + + with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): + form = yield from manager.flow.async_init('test') + + with pytest.raises(vol.Invalid): + form = yield from manager.flow.async_configure( + form['flow_id'], 'INCORRECT-DATA') + + form = yield from manager.flow.async_configure( + form['flow_id'], ['INIT-DATA']) + form = yield from manager.flow.async_configure( + form['flow_id'], ['SECOND-DATA']) + assert form['type'] == config_entries.RESULT_TYPE_CREATE_ENTRY + assert len(manager.flow.async_progress()) == 0 + assert len(manager.async_entries()) == 1 + entry = manager.async_entries()[0] + assert entry.domain == 'test' + assert entry.data == ['INIT-DATA', 'SECOND-DATA'] + + +@asyncio.coroutine +def test_show_form(manager): + """Test that abort removes the flow from progress.""" + schema = vol.Schema({ + vol.Required('username'): str, + vol.Required('password'): str + }) + + class TestFlow(config_entries.ConfigFlowHandler): + @asyncio.coroutine + def async_step_init(self, user_input=None): + return self.async_show_form( + title='Hello form', + step_id='init', + description='test-description', + data_schema=schema, + errors={ + 'username': 'Should be unique.' + } + ) + + with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): + form = yield from manager.flow.async_init('test') + assert form['type'] == 'form' + assert form['title'] == 'Hello form' + assert form['description'] == 'test-description' + assert form['data_schema'] is schema + assert form['errors'] == { + 'username': 'Should be unique.' + } + + +@asyncio.coroutine +def test_abort_removes_instance(manager): + """Test that abort removes the flow from progress.""" + class TestFlow(config_entries.ConfigFlowHandler): + is_new = True + + @asyncio.coroutine + def async_step_init(self, user_input=None): + old = self.is_new + self.is_new = False + return self.async_abort(reason=str(old)) + + with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): + form = yield from manager.flow.async_init('test') + assert form['reason'] == 'True' + assert len(manager.flow.async_progress()) == 0 + assert len(manager.async_entries()) == 0 + form = yield from manager.flow.async_init('test') + assert form['reason'] == 'True' + assert len(manager.flow.async_progress()) == 0 + assert len(manager.async_entries()) == 0 + + +@asyncio.coroutine +def test_create_saves_data(manager): + """Test creating a config entry.""" + class TestFlow(config_entries.ConfigFlowHandler): + VERSION = 5 + + @asyncio.coroutine + def async_step_init(self, user_input=None): + return self.async_create_entry( + title='Test Title', + data='Test Data' + ) + + with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): + yield from manager.flow.async_init('test') + assert len(manager.flow.async_progress()) == 0 + assert len(manager.async_entries()) == 1 + + entry = manager.async_entries()[0] + assert entry.version == 5 + assert entry.domain == 'test' + assert entry.title == 'Test Title' + assert entry.data == 'Test Data' + assert entry.source == config_entries.SOURCE_USER + + +@asyncio.coroutine +def test_discovery_init_flow(manager): + """Test a flow initialized by discovery.""" + class TestFlow(config_entries.ConfigFlowHandler): + VERSION = 5 + + @asyncio.coroutine + def async_step_discovery(self, info): + return self.async_create_entry(title=info['id'], data=info) + + data = { + 'id': 'hello', + 'token': 'secret' + } + + with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): + yield from manager.flow.async_init( + 'test', source=config_entries.SOURCE_DISCOVERY, data=data) + assert len(manager.flow.async_progress()) == 0 + assert len(manager.async_entries()) == 1 + + entry = manager.async_entries()[0] + assert entry.version == 5 + assert entry.domain == 'test' + assert entry.title == 'hello' + assert entry.data == data + assert entry.source == config_entries.SOURCE_DISCOVERY