From 60508f72158bcf5fd1b553f03a2d8a15917bc085 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 13 Apr 2018 10:14:53 -0400 Subject: [PATCH] Extract config flow to own module (#13840) * Extract config flow to own module * Lint * fix lint * fix typo * ConfigFlowHandler -> FlowHandler * Rename to data_entry_flow --- .../components/config/config_entries.py | 20 +- homeassistant/components/deconz/__init__.py | 4 +- homeassistant/components/discovery.py | 4 +- homeassistant/components/hue/config_flow.py | 4 +- homeassistant/config_entries.py | 199 +++--------------- homeassistant/data_entry_flow.py | 180 ++++++++++++++++ tests/common.py | 4 +- .../components/config/test_config_entries.py | 21 +- tests/components/test_discovery.py | 4 +- tests/test_config_entries.py | 190 +---------------- tests/test_data_entry_flow.py | 186 ++++++++++++++++ 11 files changed, 428 insertions(+), 388 deletions(-) create mode 100644 homeassistant/data_entry_flow.py create mode 100644 tests/test_data_entry_flow.py diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index aa42325b75b..967317134c2 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -3,7 +3,7 @@ import asyncio import voluptuous as vol -from homeassistant import config_entries +from homeassistant import config_entries, data_entry_flow from homeassistant.components.http import HomeAssistantView from homeassistant.components.http.data_validator import RequestDataValidator @@ -24,7 +24,7 @@ def async_setup(hass): def _prepare_json(result): """Convert result for JSON.""" - if result['type'] != config_entries.RESULT_TYPE_FORM: + if result['type'] != data_entry_flow.RESULT_TYPE_FORM: return result import voluptuous_serialize @@ -94,8 +94,8 @@ class ConfigManagerFlowIndexView(HomeAssistantView): hass = request.app['hass'] return self.json([ - flow for flow in hass.config_entries.flow.async_progress() - if flow['source'] != config_entries.SOURCE_USER]) + flw for flw in hass.config_entries.flow.async_progress() + if flw['source'] != data_entry_flow.SOURCE_USER]) @RequestDataValidator(vol.Schema({ vol.Required('domain'): str, @@ -108,9 +108,9 @@ class ConfigManagerFlowIndexView(HomeAssistantView): try: result = yield from hass.config_entries.flow.async_init( data['domain']) - except config_entries.UnknownHandler: + except data_entry_flow.UnknownHandler: return self.json_message('Invalid handler specified', 404) - except config_entries.UnknownStep: + except data_entry_flow.UnknownStep: return self.json_message('Handler does not support init', 400) result = _prepare_json(result) @@ -126,13 +126,13 @@ class ConfigManagerFlowResourceView(HomeAssistantView): @asyncio.coroutine def get(self, request, flow_id): - """Get the current state of a flow.""" + """Get the current state of a data_entry_flow.""" hass = request.app['hass'] try: result = yield from hass.config_entries.flow.async_configure( flow_id) - except config_entries.UnknownFlow: + except data_entry_flow.UnknownFlow: return self.json_message('Invalid flow specified', 404) result = _prepare_json(result) @@ -148,7 +148,7 @@ class ConfigManagerFlowResourceView(HomeAssistantView): try: result = yield from hass.config_entries.flow.async_configure( flow_id, data) - except config_entries.UnknownFlow: + except data_entry_flow.UnknownFlow: return self.json_message('Invalid flow specified', 404) except vol.Invalid: return self.json_message('User input malformed', 400) @@ -164,7 +164,7 @@ class ConfigManagerFlowResourceView(HomeAssistantView): try: hass.config_entries.flow.async_abort(flow_id) - except config_entries.UnknownFlow: + except data_entry_flow.UnknownFlow: return self.json_message('Invalid flow specified', 404) return self.json_message('Flow aborted') diff --git a/homeassistant/components/deconz/__init__.py b/homeassistant/components/deconz/__init__.py index 85ba271ec3a..04cd42ca620 100644 --- a/homeassistant/components/deconz/__init__.py +++ b/homeassistant/components/deconz/__init__.py @@ -8,7 +8,7 @@ import logging import voluptuous as vol -from homeassistant import config_entries +from homeassistant import config_entries, data_entry_flow from homeassistant.components.discovery import SERVICE_DECONZ from homeassistant.const import ( CONF_API_KEY, CONF_HOST, CONF_PORT, EVENT_HOMEASSISTANT_STOP) @@ -191,7 +191,7 @@ async def async_request_configuration(hass, config, deconz_config): @config_entries.HANDLERS.register(DOMAIN) -class DeconzFlowHandler(config_entries.ConfigFlowHandler): +class DeconzFlowHandler(data_entry_flow.FlowHandler): """Handle a deCONZ config flow.""" VERSION = 1 diff --git a/homeassistant/components/discovery.py b/homeassistant/components/discovery.py index 677a13d6a9d..693cd3d90f1 100644 --- a/homeassistant/components/discovery.py +++ b/homeassistant/components/discovery.py @@ -13,7 +13,7 @@ import os import voluptuous as vol -from homeassistant import config_entries +from homeassistant import data_entry_flow from homeassistant.core import callback from homeassistant.const import EVENT_HOMEASSISTANT_START import homeassistant.helpers.config_validation as cv @@ -119,7 +119,7 @@ async def async_setup(hass, config): if service in CONFIG_ENTRY_HANDLERS: await hass.config_entries.flow.async_init( CONFIG_ENTRY_HANDLERS[service], - source=config_entries.SOURCE_DISCOVERY, + source=data_entry_flow.SOURCE_DISCOVERY, data=info ) return diff --git a/homeassistant/components/hue/config_flow.py b/homeassistant/components/hue/config_flow.py index 11e399c984d..af67a594495 100644 --- a/homeassistant/components/hue/config_flow.py +++ b/homeassistant/components/hue/config_flow.py @@ -6,7 +6,7 @@ import os import async_timeout import voluptuous as vol -from homeassistant import config_entries +from homeassistant import config_entries, data_entry_flow from homeassistant.core import callback from homeassistant.helpers import aiohttp_client @@ -41,7 +41,7 @@ def _find_username_from_config(hass, filename): @config_entries.HANDLERS.register(DOMAIN) -class HueFlowHandler(config_entries.ConfigFlowHandler): +class HueFlowHandler(data_entry_flow.FlowHandler): """Handle a Hue config flow.""" VERSION = 1 diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index e2e45cb5819..d06bf8f1f8f 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -27,7 +27,7 @@ At a minimum, each config flow will have to define a version number and the 'init' step. @config_entries.HANDLERS.register(DOMAIN) - class ExampleConfigFlow(config_entries.ConfigFlowHandler): + class ExampleConfigFlow(config_entries.FlowHandler): VERSION = 1 @@ -117,6 +117,7 @@ import uuid from .core import callback from .exceptions import HomeAssistantError +from .data_entry_flow import FlowManager from .setup import async_setup_component, async_process_deps_reqs from .util.json import load_json, save_json from .util.decorator import Registry @@ -130,17 +131,11 @@ FLOWS = [ 'hue', ] -SOURCE_USER = 'user' -SOURCE_DISCOVERY = 'discovery' PATH_CONFIG = '.config_entries.json' SAVE_DELAY = 1 -RESULT_TYPE_FORM = 'form' -RESULT_TYPE_CREATE_ENTRY = 'create_entry' -RESULT_TYPE_ABORT = 'abort' - ENTRY_STATE_LOADED = 'loaded' ENTRY_STATE_SETUP_ERROR = 'setup_error' ENTRY_STATE_NOT_LOADED = 'not_loaded' @@ -251,18 +246,6 @@ class UnknownEntry(ConfigError): """Unknown entry specified.""" -class UnknownHandler(ConfigError): - """Unknown handler specified.""" - - -class UnknownFlow(ConfigError): - """Uknown flow specified.""" - - -class UnknownStep(ConfigError): - """Unknown step specified.""" - - class ConfigEntries: """Manage the configuration entries. @@ -272,7 +255,8 @@ class ConfigEntries: def __init__(self, hass, hass_config): """Initialize the entry manager.""" self.hass = hass - self.flow = FlowManager(hass, hass_config, self._async_add_entry) + self.flow = FlowManager(hass, HANDLERS, self._async_missing_handler, + self._async_save_entry) self._hass_config = hass_config self._entries = None self._sched_save = None @@ -357,8 +341,15 @@ class ConfigEntries: await entry.async_unload( self.hass, component=getattr(self.hass.components, component)) - async def _async_add_entry(self, entry): + async def _async_save_entry(self, result): """Add an entry.""" + entry = ConfigEntry( + version=result['version'], + domain=result['domain'], + title=result['title'], + data=result['data'], + source=result['source'], + ) self._entries.append(entry) self._async_schedule_save() @@ -371,6 +362,18 @@ class ConfigEntries: await async_setup_component( self.hass, entry.domain, self._hass_config) + async def _async_missing_handler(self, domain): + """Called when a flow handler is not loaded.""" + # This will load the component and thus register the handler + component = getattr(self.hass.components, domain) + + if domain not in HANDLERS: + return + + # Make sure requirements and dependencies of component are resolved + await async_process_deps_reqs( + self.hass, self._hass_config, domain, component) + @callback def _async_schedule_save(self): """Schedule saving the entity registry.""" @@ -388,157 +391,3 @@ class ConfigEntries: await self.hass.async_add_job( save_json, self.hass.config.path(PATH_CONFIG), data) - - -class FlowManager: - """Manage all the config flows that are in progress.""" - - def __init__(self, hass, hass_config, async_add_entry): - """Initialize the flow manager.""" - self.hass = hass - self._hass_config = hass_config - self._progress = {} - self._async_add_entry = async_add_entry - - @callback - def async_progress(self): - """Return the flows in progress.""" - return [{ - 'flow_id': flow.flow_id, - 'domain': flow.domain, - 'source': flow.source, - } for flow in self._progress.values()] - - async def async_init(self, domain, *, source=SOURCE_USER, data=None): - """Start a configuration flow.""" - handler = HANDLERS.get(domain) - - if handler is None: - # This will load the component and thus register the handler - component = getattr(self.hass.components, domain) - handler = HANDLERS.get(domain) - - if handler is None: - raise UnknownHandler - - # Make sure requirements and dependencies of component are resolved - await async_process_deps_reqs( - self.hass, self._hass_config, domain, component) - - flow_id = uuid.uuid4().hex - flow = self._progress[flow_id] = handler() - flow.hass = self.hass - flow.domain = domain - flow.flow_id = flow_id - flow.source = source - - if source == SOURCE_USER: - step = 'init' - else: - step = source - - return await self._async_handle_step(flow, step, data) - - async def async_configure(self, flow_id, user_input=None): - """Start or continue a configuration flow.""" - flow = self._progress.get(flow_id) - - if flow is None: - raise UnknownFlow - - step_id, data_schema = flow.cur_step - - if data_schema is not None and user_input is not None: - user_input = data_schema(user_input) - - return await self._async_handle_step( - flow, step_id, user_input) - - @callback - def async_abort(self, flow_id): - """Abort a flow.""" - if self._progress.pop(flow_id, None) is None: - raise UnknownFlow - - async def _async_handle_step(self, flow, step_id, user_input): - """Handle a step of a flow.""" - method = "async_step_{}".format(step_id) - - if not hasattr(flow, method): - self._progress.pop(flow.flow_id) - raise UnknownStep("Handler {} doesn't support step {}".format( - flow.__class__.__name__, step_id)) - - result = await getattr(flow, method)(user_input) - - if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY, - RESULT_TYPE_ABORT): - raise ValueError( - 'Handler returned incorrect type: {}'.format(result['type'])) - - if result['type'] == RESULT_TYPE_FORM: - flow.cur_step = (result['step_id'], result['data_schema']) - return result - - # Abort and Success results both finish the flow - self._progress.pop(flow.flow_id) - - if result['type'] == RESULT_TYPE_ABORT: - return result - - entry = ConfigEntry( - version=flow.VERSION, - domain=flow.domain, - title=result['title'], - data=result.pop('data'), - source=flow.source - ) - await self._async_add_entry(entry) - return result - - -class ConfigFlowHandler: - """Handle the configuration flow of a component.""" - - # Set by flow manager - flow_id = None - hass = None - domain = None - source = SOURCE_USER - cur_step = None - - # Set by dev - # VERSION - - @callback - def async_show_form(self, *, step_id, data_schema=None, errors=None): - """Return the definition of a form to gather user input.""" - return { - 'type': RESULT_TYPE_FORM, - 'flow_id': self.flow_id, - 'domain': self.domain, - 'step_id': step_id, - 'data_schema': data_schema, - 'errors': errors, - } - - @callback - def async_create_entry(self, *, title, data): - """Finish config flow and create a config entry.""" - return { - 'type': RESULT_TYPE_CREATE_ENTRY, - 'flow_id': self.flow_id, - 'domain': self.domain, - 'title': title, - 'data': data, - } - - @callback - def async_abort(self, *, reason): - """Abort the config flow.""" - return { - 'type': RESULT_TYPE_ABORT, - 'flow_id': self.flow_id, - 'domain': self.domain, - 'reason': reason - } diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py new file mode 100644 index 00000000000..5644481210c --- /dev/null +++ b/homeassistant/data_entry_flow.py @@ -0,0 +1,180 @@ +"""Classes to help gather user submissions.""" +import logging +import uuid + +from .core import callback +from .exceptions import HomeAssistantError + +_LOGGER = logging.getLogger(__name__) + +SOURCE_USER = 'user' +SOURCE_DISCOVERY = 'discovery' + +RESULT_TYPE_FORM = 'form' +RESULT_TYPE_CREATE_ENTRY = 'create_entry' +RESULT_TYPE_ABORT = 'abort' + + +class FlowError(HomeAssistantError): + """Error while configuring an account.""" + + +class UnknownHandler(FlowError): + """Unknown handler specified.""" + + +class UnknownFlow(FlowError): + """Uknown flow specified.""" + + +class UnknownStep(FlowError): + """Unknown step specified.""" + + +class FlowManager: + """Manage all the flows that are in progress.""" + + def __init__(self, hass, handlers, async_missing_handler, + async_save_entry): + """Initialize the flow manager.""" + self.hass = hass + self._handlers = handlers + self._progress = {} + self._async_missing_handler = async_missing_handler + self._async_save_entry = async_save_entry + + @callback + def async_progress(self): + """Return the flows in progress.""" + return [{ + 'flow_id': flow.flow_id, + 'domain': flow.domain, + 'source': flow.source, + } for flow in self._progress.values()] + + async def async_init(self, domain, *, source=SOURCE_USER, data=None): + """Start a configuration flow.""" + handler = self._handlers.get(domain) + + if handler is None: + await self._async_missing_handler(domain) + handler = self._handlers.get(domain) + + if handler is None: + raise UnknownHandler + + flow_id = uuid.uuid4().hex + flow = self._progress[flow_id] = handler() + flow.hass = self.hass + flow.domain = domain + flow.flow_id = flow_id + flow.source = source + + if source == SOURCE_USER: + step = 'init' + else: + step = source + + return await self._async_handle_step(flow, step, data) + + async def async_configure(self, flow_id, user_input=None): + """Start or continue a configuration flow.""" + flow = self._progress.get(flow_id) + + if flow is None: + raise UnknownFlow + + step_id, data_schema = flow.cur_step + + if data_schema is not None and user_input is not None: + user_input = data_schema(user_input) + + return await self._async_handle_step( + flow, step_id, user_input) + + @callback + def async_abort(self, flow_id): + """Abort a flow.""" + if self._progress.pop(flow_id, None) is None: + raise UnknownFlow + + async def _async_handle_step(self, flow, step_id, user_input): + """Handle a step of a flow.""" + method = "async_step_{}".format(step_id) + + if not hasattr(flow, method): + self._progress.pop(flow.flow_id) + raise UnknownStep("Handler {} doesn't support step {}".format( + flow.__class__.__name__, step_id)) + + result = await getattr(flow, method)(user_input) + + if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY, + RESULT_TYPE_ABORT): + raise ValueError( + 'Handler returned incorrect type: {}'.format(result['type'])) + + if result['type'] == RESULT_TYPE_FORM: + flow.cur_step = (result['step_id'], result['data_schema']) + return result + + # Abort and Success results both finish the flow + self._progress.pop(flow.flow_id) + + if result['type'] == RESULT_TYPE_ABORT: + return result + + # We pass a copy of the result because we're going to mutate our + # version afterwards and don't want to cause unexpected bugs. + await self._async_save_entry(dict(result)) + result.pop('data') + return result + + +class FlowHandler: + """Handle the configuration flow of a component.""" + + # Set by flow manager + flow_id = None + hass = None + domain = None + source = SOURCE_USER + cur_step = None + + # Set by developer + VERSION = 1 + + @callback + def async_show_form(self, *, step_id, data_schema=None, errors=None): + """Return the definition of a form to gather user input.""" + return { + 'type': RESULT_TYPE_FORM, + 'flow_id': self.flow_id, + 'domain': self.domain, + 'step_id': step_id, + 'data_schema': data_schema, + 'errors': errors, + } + + @callback + def async_create_entry(self, *, title, data): + """Finish config flow and create a config entry.""" + return { + 'version': self.VERSION, + 'type': RESULT_TYPE_CREATE_ENTRY, + 'flow_id': self.flow_id, + 'domain': self.domain, + 'title': title, + 'data': data, + 'source': self.source, + } + + @callback + def async_abort(self, *, reason): + """Abort the config flow.""" + return { + 'type': RESULT_TYPE_ABORT, + 'flow_id': self.flow_id, + 'domain': self.domain, + 'reason': reason + } diff --git a/tests/common.py b/tests/common.py index 54c214da4e9..67fd8bab23f 100644 --- a/tests/common.py +++ b/tests/common.py @@ -10,7 +10,7 @@ import logging import threading from contextlib import contextmanager -from homeassistant import core as ha, loader, config_entries +from homeassistant import core as ha, loader, data_entry_flow, config_entries from homeassistant.setup import setup_component, async_setup_component from homeassistant.config import async_process_component_config from homeassistant.helpers import ( @@ -455,7 +455,7 @@ class MockConfigEntry(config_entries.ConfigEntry): """Helper for creating config entries that adds some defaults.""" def __init__(self, *, domain='test', data=None, version=0, entry_id=None, - source=config_entries.SOURCE_USER, title='Mock Title', + source=data_entry_flow.SOURCE_USER, title='Mock Title', state=None): """Initialize a mock config entry.""" kwargs = { diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py index cfe6b12baac..d6490763951 100644 --- a/tests/components/config/test_config_entries.py +++ b/tests/components/config/test_config_entries.py @@ -8,7 +8,8 @@ import pytest import voluptuous as vol from homeassistant import config_entries as core_ce -from homeassistant.config_entries import ConfigFlowHandler, HANDLERS +from homeassistant.config_entries import HANDLERS +from homeassistant.data_entry_flow import FlowHandler from homeassistant.setup import async_setup_component from homeassistant.components.config import config_entries from homeassistant.loader import set_component @@ -93,7 +94,7 @@ def test_available_flows(hass, client): @asyncio.coroutine def test_initialize_flow(hass, client): """Test we can initialize a flow.""" - class TestFlow(ConfigFlowHandler): + class TestFlow(FlowHandler): @asyncio.coroutine def async_step_init(self, user_input=None): schema = OrderedDict() @@ -142,7 +143,7 @@ def test_initialize_flow(hass, client): @asyncio.coroutine def test_abort(hass, client): """Test a flow that aborts.""" - class TestFlow(ConfigFlowHandler): + class TestFlow(FlowHandler): @asyncio.coroutine def async_step_init(self, user_input=None): return self.async_abort(reason='bla') @@ -167,7 +168,7 @@ def test_create_account(hass, client): set_component( 'test', MockModule('test', async_setup_entry=mock_coro_func(True))) - class TestFlow(ConfigFlowHandler): + class TestFlow(FlowHandler): VERSION = 1 @asyncio.coroutine @@ -187,7 +188,9 @@ def test_create_account(hass, client): assert data == { 'domain': 'test', 'title': 'Test Entry', - 'type': 'create_entry' + 'type': 'create_entry', + 'source': 'user', + 'version': 1, } @@ -197,7 +200,7 @@ def test_two_step_flow(hass, client): set_component( 'test', MockModule('test', async_setup_entry=mock_coro_func(True))) - class TestFlow(ConfigFlowHandler): + class TestFlow(FlowHandler): VERSION = 1 @asyncio.coroutine @@ -245,13 +248,15 @@ def test_two_step_flow(hass, client): 'domain': 'test', 'type': 'create_entry', 'title': 'user-title', + 'version': 1, + 'source': 'user', } @asyncio.coroutine def test_get_progress_index(hass, client): """Test querying for the flows that are in progress.""" - class TestFlow(ConfigFlowHandler): + class TestFlow(FlowHandler): VERSION = 5 @asyncio.coroutine @@ -283,7 +288,7 @@ def test_get_progress_index(hass, client): @asyncio.coroutine def test_get_progress_flow(hass, client): """Test we can query the API for same result as we get from init a flow.""" - class TestFlow(ConfigFlowHandler): + class TestFlow(FlowHandler): @asyncio.coroutine def async_step_init(self, user_input=None): schema = OrderedDict() diff --git a/tests/components/test_discovery.py b/tests/components/test_discovery.py index b4c80bf3210..f3f63654e8b 100644 --- a/tests/components/test_discovery.py +++ b/tests/components/test_discovery.py @@ -5,7 +5,7 @@ from unittest.mock import patch, MagicMock import pytest -from homeassistant import config_entries +from homeassistant import data_entry_flow from homeassistant.bootstrap import async_setup_component from homeassistant.components import discovery from homeassistant.util.dt import utcnow @@ -174,5 +174,5 @@ async def test_discover_config_flow(hass): assert len(m_init.mock_calls) == 1 args, kwargs = m_init.mock_calls[0][1:] assert args == ('mock-component',) - assert kwargs['source'] == config_entries.SOURCE_DISCOVERY + assert kwargs['source'] == data_entry_flow.SOURCE_DISCOVERY assert kwargs['data'] == discovery_info diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index b9b39b11c13..94b1dcb47da 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -3,9 +3,8 @@ import asyncio from unittest.mock import MagicMock, patch, mock_open import pytest -import voluptuous as vol -from homeassistant import config_entries, loader +from homeassistant import config_entries, loader, data_entry_flow from homeassistant.setup import async_setup_component from tests.common import MockModule, mock_coro, MockConfigEntry @@ -100,7 +99,7 @@ def test_add_entry_calls_setup_entry(hass, manager): 'comp', MockModule('comp', async_setup_entry=mock_setup_entry)) - class TestFlow(config_entries.ConfigFlowHandler): + class TestFlow(data_entry_flow.FlowHandler): VERSION = 1 @@ -112,7 +111,7 @@ def test_add_entry_calls_setup_entry(hass, manager): 'token': 'supersecret' }) - with patch.dict(config_entries.HANDLERS, {'comp': TestFlow}): + with patch.dict(config_entries.HANDLERS, {'comp': TestFlow, 'beer': 5}): yield from manager.flow.async_init('comp') yield from hass.async_block_till_done() @@ -152,7 +151,7 @@ def test_domains_gets_uniques(manager): @asyncio.coroutine def test_saving_and_loading(hass): """Test that we're saving and loading correctly.""" - class TestFlow(config_entries.ConfigFlowHandler): + class TestFlow(data_entry_flow.FlowHandler): VERSION = 5 @asyncio.coroutine @@ -167,7 +166,7 @@ def test_saving_and_loading(hass): with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): yield from hass.config_entries.flow.async_init('test') - class Test2Flow(config_entries.ConfigFlowHandler): + class Test2Flow(data_entry_flow.FlowHandler): VERSION = 3 @asyncio.coroutine @@ -212,185 +211,6 @@ def test_saving_and_loading(hass): assert orig.source == loaded.source -####################### -# FLOW MANAGER TESTS # -####################### - -@asyncio.coroutine -def test_configure_reuses_handler_instance(manager): - """Test that we reuse instances.""" - class TestFlow(config_entries.ConfigFlowHandler): - handle_count = 0 - - @asyncio.coroutine - def async_step_init(self, user_input=None): - self.handle_count += 1 - return self.async_show_form( - errors={'base': str(self.handle_count)}, - step_id='init') - - with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): - form = yield from manager.flow.async_init('test') - assert form['errors']['base'] == '1' - form = yield from manager.flow.async_configure(form['flow_id']) - assert form['errors']['base'] == '2' - assert len(manager.flow.async_progress()) == 1 - assert len(manager.async_entries()) == 0 - - -@asyncio.coroutine -def test_configure_two_steps(manager): - """Test that we reuse instances.""" - class TestFlow(config_entries.ConfigFlowHandler): - VERSION = 1 - - @asyncio.coroutine - def async_step_init(self, user_input=None): - if user_input is not None: - self.init_data = user_input - return self.async_step_second() - return self.async_show_form( - step_id='init', - data_schema=vol.Schema([str]) - ) - - @asyncio.coroutine - def async_step_second(self, user_input=None): - if user_input is not None: - return self.async_create_entry( - title='Test Entry', - data=self.init_data + user_input - ) - return self.async_show_form( - step_id='second', - data_schema=vol.Schema([str]) - ) - - with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): - form = yield from manager.flow.async_init('test') - - with pytest.raises(vol.Invalid): - form = yield from manager.flow.async_configure( - form['flow_id'], 'INCORRECT-DATA') - - form = yield from manager.flow.async_configure( - form['flow_id'], ['INIT-DATA']) - form = yield from manager.flow.async_configure( - form['flow_id'], ['SECOND-DATA']) - assert form['type'] == config_entries.RESULT_TYPE_CREATE_ENTRY - assert len(manager.flow.async_progress()) == 0 - assert len(manager.async_entries()) == 1 - entry = manager.async_entries()[0] - assert entry.domain == 'test' - assert entry.data == ['INIT-DATA', 'SECOND-DATA'] - - -@asyncio.coroutine -def test_show_form(manager): - """Test that abort removes the flow from progress.""" - schema = vol.Schema({ - vol.Required('username'): str, - vol.Required('password'): str - }) - - class TestFlow(config_entries.ConfigFlowHandler): - @asyncio.coroutine - def async_step_init(self, user_input=None): - return self.async_show_form( - step_id='init', - data_schema=schema, - errors={ - 'username': 'Should be unique.' - } - ) - - with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): - form = yield from manager.flow.async_init('test') - assert form['type'] == 'form' - assert form['data_schema'] is schema - assert form['errors'] == { - 'username': 'Should be unique.' - } - - -@asyncio.coroutine -def test_abort_removes_instance(manager): - """Test that abort removes the flow from progress.""" - class TestFlow(config_entries.ConfigFlowHandler): - is_new = True - - @asyncio.coroutine - def async_step_init(self, user_input=None): - old = self.is_new - self.is_new = False - return self.async_abort(reason=str(old)) - - with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): - form = yield from manager.flow.async_init('test') - assert form['reason'] == 'True' - assert len(manager.flow.async_progress()) == 0 - assert len(manager.async_entries()) == 0 - form = yield from manager.flow.async_init('test') - assert form['reason'] == 'True' - assert len(manager.flow.async_progress()) == 0 - assert len(manager.async_entries()) == 0 - - -@asyncio.coroutine -def test_create_saves_data(manager): - """Test creating a config entry.""" - class TestFlow(config_entries.ConfigFlowHandler): - VERSION = 5 - - @asyncio.coroutine - def async_step_init(self, user_input=None): - return self.async_create_entry( - title='Test Title', - data='Test Data' - ) - - with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): - yield from manager.flow.async_init('test') - assert len(manager.flow.async_progress()) == 0 - assert len(manager.async_entries()) == 1 - - entry = manager.async_entries()[0] - assert entry.version == 5 - assert entry.domain == 'test' - assert entry.title == 'Test Title' - assert entry.data == 'Test Data' - assert entry.source == config_entries.SOURCE_USER - - -@asyncio.coroutine -def test_discovery_init_flow(manager): - """Test a flow initialized by discovery.""" - class TestFlow(config_entries.ConfigFlowHandler): - VERSION = 5 - - @asyncio.coroutine - def async_step_discovery(self, info): - return self.async_create_entry(title=info['id'], data=info) - - data = { - 'id': 'hello', - 'token': 'secret' - } - - with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): - yield from manager.flow.async_init( - 'test', source=config_entries.SOURCE_DISCOVERY, data=data) - assert len(manager.flow.async_progress()) == 0 - assert len(manager.async_entries()) == 1 - - entry = manager.async_entries()[0] - assert entry.version == 5 - assert entry.domain == 'test' - assert entry.title == 'hello' - assert entry.data == data - assert entry.source == config_entries.SOURCE_DISCOVERY - - async def test_forward_entry_sets_up_component(hass): """Test we setup the component entry is forwarded to.""" entry = MockConfigEntry(domain='original') diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py new file mode 100644 index 00000000000..f7067871174 --- /dev/null +++ b/tests/test_data_entry_flow.py @@ -0,0 +1,186 @@ +"""Test the flow classes.""" +import pytest +import voluptuous as vol + +from homeassistant import data_entry_flow +from homeassistant.util.decorator import Registry + +from tests.common import mock_coro + + +@pytest.fixture +def manager(): + """Return a flow manager.""" + handlers = Registry() + entries = [] + + async def async_add_entry(result): + entries.append(result) + + manager = data_entry_flow.FlowManager( + None, handlers, mock_coro, async_add_entry) + manager.mock_created_entries = entries + manager.mock_reg_handler = handlers.register + return manager + + +async def test_configure_reuses_handler_instance(manager): + """Test that we reuse instances.""" + @manager.mock_reg_handler('test') + class TestFlow(data_entry_flow.FlowHandler): + handle_count = 0 + + async def async_step_init(self, user_input=None): + self.handle_count += 1 + return self.async_show_form( + errors={'base': str(self.handle_count)}, + step_id='init') + + form = await manager.async_init('test') + assert form['errors']['base'] == '1' + form = await manager.async_configure(form['flow_id']) + assert form['errors']['base'] == '2' + assert len(manager.async_progress()) == 1 + assert len(manager.mock_created_entries) == 0 + + +async def test_configure_two_steps(manager): + """Test that we reuse instances.""" + @manager.mock_reg_handler('test') + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 1 + + async def async_step_init(self, user_input=None): + if user_input is not None: + self.init_data = user_input + return await self.async_step_second() + return self.async_show_form( + step_id='init', + data_schema=vol.Schema([str]) + ) + + async def async_step_second(self, user_input=None): + if user_input is not None: + return self.async_create_entry( + title='Test Entry', + data=self.init_data + user_input + ) + return self.async_show_form( + step_id='second', + data_schema=vol.Schema([str]) + ) + + form = await manager.async_init('test') + + with pytest.raises(vol.Invalid): + form = await manager.async_configure( + form['flow_id'], 'INCORRECT-DATA') + + form = await manager.async_configure( + form['flow_id'], ['INIT-DATA']) + form = await manager.async_configure( + form['flow_id'], ['SECOND-DATA']) + assert form['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 1 + result = manager.mock_created_entries[0] + assert result['domain'] == 'test' + assert result['data'] == ['INIT-DATA', 'SECOND-DATA'] + + +async def test_show_form(manager): + """Test that abort removes the flow from progress.""" + schema = vol.Schema({ + vol.Required('username'): str, + vol.Required('password'): str + }) + + @manager.mock_reg_handler('test') + class TestFlow(data_entry_flow.FlowHandler): + async def async_step_init(self, user_input=None): + return self.async_show_form( + step_id='init', + data_schema=schema, + errors={ + 'username': 'Should be unique.' + } + ) + + form = await manager.async_init('test') + assert form['type'] == 'form' + assert form['data_schema'] is schema + assert form['errors'] == { + 'username': 'Should be unique.' + } + + +async def test_abort_removes_instance(manager): + """Test that abort removes the flow from progress.""" + @manager.mock_reg_handler('test') + class TestFlow(data_entry_flow.FlowHandler): + is_new = True + + async def async_step_init(self, user_input=None): + old = self.is_new + self.is_new = False + return self.async_abort(reason=str(old)) + + form = await manager.async_init('test') + assert form['reason'] == 'True' + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 0 + form = await manager.async_init('test') + assert form['reason'] == 'True' + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 0 + + +async def test_create_saves_data(manager): + """Test creating a config entry.""" + @manager.mock_reg_handler('test') + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 5 + + async def async_step_init(self, user_input=None): + return self.async_create_entry( + title='Test Title', + data='Test Data' + ) + + await manager.async_init('test') + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 1 + + entry = manager.mock_created_entries[0] + assert entry['version'] == 5 + assert entry['domain'] == 'test' + assert entry['title'] == 'Test Title' + assert entry['data'] == 'Test Data' + assert entry['source'] == data_entry_flow.SOURCE_USER + + +async def test_discovery_init_flow(manager): + """Test a flow initialized by discovery.""" + @manager.mock_reg_handler('test') + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 5 + + async def async_step_discovery(self, info): + return self.async_create_entry(title=info['id'], data=info) + + data = { + 'id': 'hello', + 'token': 'secret' + } + + await manager.async_init( + 'test', source=data_entry_flow.SOURCE_DISCOVERY, data=data) + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 1 + + entry = manager.mock_created_entries[0] + assert entry['version'] == 5 + assert entry['domain'] == 'test' + assert entry['title'] == 'hello' + assert entry['data'] == data + assert entry['source'] == data_entry_flow.SOURCE_DISCOVERY