From f58425dd3cc4f8f3c492102aa448c7b93bab7128 Mon Sep 17 00:00:00 2001 From: Jason Hu Date: Thu, 9 Aug 2018 04:24:14 -0700 Subject: [PATCH] Refactor data entry flow (#15883) * Refactoring data_entry_flow and config_entry_flow Move SOURCE_* to config_entries Change data_entry_flow.FlowManager.async_init() source param default to None Change this first step_id as source or init if source is None _BaseFlowManagerView pass in SOURCE_USER as default source * First step of data entry flow decided by _async_create_flow() now * Lint * Change helpers.config_entry_flow.DiscoveryFlowHandler default step * Change FlowManager.async_init source param to context dict param --- homeassistant/auth/__init__.py | 2 +- homeassistant/components/cast/__init__.py | 4 +- .../components/config/config_entries.py | 2 +- homeassistant/components/deconz/__init__.py | 5 +- .../components/deconz/config_flow.py | 4 ++ homeassistant/components/discovery.py | 4 +- .../components/homematicip_cloud/__init__.py | 4 +- .../homematicip_cloud/config_flow.py | 4 ++ homeassistant/components/hue/__init__.py | 5 +- homeassistant/components/hue/bridge.py | 3 +- homeassistant/components/hue/config_flow.py | 4 ++ homeassistant/components/nest/__init__.py | 4 +- homeassistant/components/nest/config_flow.py | 4 ++ homeassistant/components/sonos/__init__.py | 4 +- homeassistant/components/zone/config_flow.py | 4 ++ homeassistant/config_entries.py | 48 +++++++++++++------ homeassistant/data_entry_flow.py | 24 ++++------ homeassistant/helpers/config_entry_flow.py | 2 +- homeassistant/helpers/data_entry_flow.py | 5 +- tests/common.py | 4 +- .../components/config/test_config_entries.py | 18 +++---- tests/components/test_discovery.py | 4 +- tests/helpers/test_config_entry_flow.py | 12 ++--- tests/test_config_entries.py | 10 ++-- tests/test_data_entry_flow.py | 23 +++++---- 25 files changed, 128 insertions(+), 79 deletions(-) diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 35804cd8483..8eaa9cdbb97 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -211,7 +211,7 @@ class AuthManager: return tkn - async def _async_create_login_flow(self, handler, *, source, data): + async def _async_create_login_flow(self, handler, *, context, data): """Create a login flow.""" auth_provider = self._providers[handler] diff --git a/homeassistant/components/cast/__init__.py b/homeassistant/components/cast/__init__.py index aadf0103c5a..6885f24269a 100644 --- a/homeassistant/components/cast/__init__.py +++ b/homeassistant/components/cast/__init__.py @@ -1,5 +1,5 @@ """Component to embed Google Cast.""" -from homeassistant import data_entry_flow +from homeassistant import config_entries from homeassistant.helpers import config_entry_flow @@ -15,7 +15,7 @@ async def async_setup(hass, config): if conf is not None: hass.async_create_task(hass.config_entries.flow.async_init( - DOMAIN, source=data_entry_flow.SOURCE_IMPORT)) + DOMAIN, context={'source': config_entries.SOURCE_IMPORT})) return True diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index 648f6ae9972..57fdbd31d20 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -96,7 +96,7 @@ class ConfigManagerFlowIndexView(FlowManagerIndexView): return self.json([ flw for flw in hass.config_entries.flow.async_progress() - if flw['source'] != data_entry_flow.SOURCE_USER]) + if flw['source'] != config_entries.SOURCE_USER]) class ConfigManagerFlowResourceView(FlowManagerResourceView): diff --git a/homeassistant/components/deconz/__init__.py b/homeassistant/components/deconz/__init__.py index eacb31e3f8b..eacfe22e818 100644 --- a/homeassistant/components/deconz/__init__.py +++ b/homeassistant/components/deconz/__init__.py @@ -6,6 +6,7 @@ https://home-assistant.io/components/deconz/ """ import voluptuous as vol +from homeassistant import config_entries from homeassistant.const import ( CONF_API_KEY, CONF_EVENT, CONF_HOST, CONF_ID, CONF_PORT, EVENT_HOMEASSISTANT_STOP) @@ -60,7 +61,9 @@ async def async_setup(hass, config): deconz_config = config[DOMAIN] if deconz_config and not configured_hosts(hass): hass.async_add_job(hass.config_entries.flow.async_init( - DOMAIN, source='import', data=deconz_config + DOMAIN, + context={'source': config_entries.SOURCE_IMPORT}, + data=deconz_config )) return True diff --git a/homeassistant/components/deconz/config_flow.py b/homeassistant/components/deconz/config_flow.py index a6f67506227..fb2eb54232a 100644 --- a/homeassistant/components/deconz/config_flow.py +++ b/homeassistant/components/deconz/config_flow.py @@ -33,6 +33,10 @@ class DeconzFlowHandler(data_entry_flow.FlowHandler): self.bridges = [] self.deconz_config = {} + async def async_step_user(self, user_input=None): + """Handle a flow initialized by the user.""" + return await self.async_step_init(user_input) + async def async_step_init(self, user_input=None): """Handle a deCONZ config flow start. diff --git a/homeassistant/components/discovery.py b/homeassistant/components/discovery.py index 8272fa9814a..b400d1d8885 100644 --- a/homeassistant/components/discovery.py +++ b/homeassistant/components/discovery.py @@ -13,7 +13,7 @@ import os import voluptuous as vol -from homeassistant import data_entry_flow +from homeassistant import config_entries from homeassistant.core import callback from homeassistant.const import EVENT_HOMEASSISTANT_START import homeassistant.helpers.config_validation as cv @@ -138,7 +138,7 @@ async def async_setup(hass, config): if service in CONFIG_ENTRY_HANDLERS: await hass.config_entries.flow.async_init( CONFIG_ENTRY_HANDLERS[service], - source=data_entry_flow.SOURCE_DISCOVERY, + context={'source': config_entries.SOURCE_DISCOVERY}, data=info ) return diff --git a/homeassistant/components/homematicip_cloud/__init__.py b/homeassistant/components/homematicip_cloud/__init__.py index b9266322978..f2cc8f443ac 100644 --- a/homeassistant/components/homematicip_cloud/__init__.py +++ b/homeassistant/components/homematicip_cloud/__init__.py @@ -10,6 +10,7 @@ import logging import voluptuous as vol import homeassistant.helpers.config_validation as cv +from homeassistant import config_entries from .const import ( DOMAIN, HMIPC_HAPID, HMIPC_AUTHTOKEN, HMIPC_NAME, @@ -41,7 +42,8 @@ async def async_setup(hass, config): for conf in accesspoints: if conf[CONF_ACCESSPOINT] not in configured_haps(hass): hass.async_add_job(hass.config_entries.flow.async_init( - DOMAIN, source='import', data={ + DOMAIN, context={'source': config_entries.SOURCE_IMPORT}, + data={ HMIPC_HAPID: conf[CONF_ACCESSPOINT], HMIPC_AUTHTOKEN: conf[CONF_AUTHTOKEN], HMIPC_NAME: conf[CONF_NAME], diff --git a/homeassistant/components/homematicip_cloud/config_flow.py b/homeassistant/components/homematicip_cloud/config_flow.py index 3be89172e27..78970031d11 100644 --- a/homeassistant/components/homematicip_cloud/config_flow.py +++ b/homeassistant/components/homematicip_cloud/config_flow.py @@ -27,6 +27,10 @@ class HomematicipCloudFlowHandler(data_entry_flow.FlowHandler): """Initialize HomematicIP Cloud config flow.""" self.auth = None + async def async_step_user(self, user_input=None): + """Handle a flow initialized by the user.""" + return await self.async_step_init(user_input) + async def async_step_init(self, user_input=None): """Handle a flow start.""" errors = {} diff --git a/homeassistant/components/hue/__init__.py b/homeassistant/components/hue/__init__.py index dbd86ef31f3..c04380e1303 100644 --- a/homeassistant/components/hue/__init__.py +++ b/homeassistant/components/hue/__init__.py @@ -9,7 +9,7 @@ import logging import voluptuous as vol -from homeassistant import data_entry_flow +from homeassistant import config_entries from homeassistant.const import CONF_FILENAME, CONF_HOST from homeassistant.helpers import aiohttp_client, config_validation as cv @@ -108,7 +108,8 @@ async def async_setup(hass, config): # deadlock: creating a config entry will set up the component but the # setup would block till the entry is created! hass.async_add_job(hass.config_entries.flow.async_init( - DOMAIN, source=data_entry_flow.SOURCE_IMPORT, data={ + DOMAIN, context={'source': config_entries.SOURCE_IMPORT}, + data={ 'host': bridge_conf[CONF_HOST], 'path': bridge_conf[CONF_FILENAME], } diff --git a/homeassistant/components/hue/bridge.py b/homeassistant/components/hue/bridge.py index b7cf0e1de07..874c18aaa7e 100644 --- a/homeassistant/components/hue/bridge.py +++ b/homeassistant/components/hue/bridge.py @@ -51,7 +51,8 @@ class HueBridge: # linking procedure. When linking succeeds, it will remove the # old config entry. hass.async_add_job(hass.config_entries.flow.async_init( - DOMAIN, source='import', data={ + DOMAIN, context={'source': config_entries.SOURCE_IMPORT}, + data={ 'host': host, } )) diff --git a/homeassistant/components/hue/config_flow.py b/homeassistant/components/hue/config_flow.py index a7fe3ff04e0..49ebbdaabf5 100644 --- a/homeassistant/components/hue/config_flow.py +++ b/homeassistant/components/hue/config_flow.py @@ -50,6 +50,10 @@ class HueFlowHandler(data_entry_flow.FlowHandler): """Initialize the Hue flow.""" self.host = None + async def async_step_user(self, user_input=None): + """Handle a flow initialized by the user.""" + return await self.async_step_init(user_input) + async def async_step_init(self, user_input=None): """Handle a flow start.""" from aiohue.discovery import discover_nupnp diff --git a/homeassistant/components/nest/__init__.py b/homeassistant/components/nest/__init__.py index 1adb113bb81..de9783ba931 100644 --- a/homeassistant/components/nest/__init__.py +++ b/homeassistant/components/nest/__init__.py @@ -11,6 +11,7 @@ from datetime import datetime, timedelta import voluptuous as vol +from homeassistant import config_entries from homeassistant.const import ( CONF_STRUCTURE, CONF_FILENAME, CONF_BINARY_SENSORS, CONF_SENSORS, CONF_MONITORED_CONDITIONS, @@ -103,7 +104,8 @@ async def async_setup(hass, config): access_token_cache_file = hass.config.path(filename) hass.async_add_job(hass.config_entries.flow.async_init( - DOMAIN, source='import', data={ + DOMAIN, context={'source': config_entries.SOURCE_IMPORT}, + data={ 'nest_conf_path': access_token_cache_file, } )) diff --git a/homeassistant/components/nest/config_flow.py b/homeassistant/components/nest/config_flow.py index f97e0dc8ff5..c9987693b1a 100644 --- a/homeassistant/components/nest/config_flow.py +++ b/homeassistant/components/nest/config_flow.py @@ -58,6 +58,10 @@ class NestFlowHandler(data_entry_flow.FlowHandler): """Initialize the Nest config flow.""" self.flow_impl = None + async def async_step_user(self, user_input=None): + """Handle a flow initialized by the user.""" + return await self.async_step_init(user_input) + async def async_step_init(self, user_input=None): """Handle a flow start.""" flows = self.hass.data.get(DATA_FLOW_IMPL, {}) diff --git a/homeassistant/components/sonos/__init__.py b/homeassistant/components/sonos/__init__.py index 4c5592c02c2..bbc05a3aa61 100644 --- a/homeassistant/components/sonos/__init__.py +++ b/homeassistant/components/sonos/__init__.py @@ -1,5 +1,5 @@ """Component to embed Sonos.""" -from homeassistant import data_entry_flow +from homeassistant import config_entries from homeassistant.helpers import config_entry_flow @@ -15,7 +15,7 @@ async def async_setup(hass, config): if conf is not None: hass.async_create_task(hass.config_entries.flow.async_init( - DOMAIN, source=data_entry_flow.SOURCE_IMPORT)) + DOMAIN, context={'source': config_entries.SOURCE_IMPORT})) return True diff --git a/homeassistant/components/zone/config_flow.py b/homeassistant/components/zone/config_flow.py index 5ec955a48d9..01577de4c8f 100644 --- a/homeassistant/components/zone/config_flow.py +++ b/homeassistant/components/zone/config_flow.py @@ -29,6 +29,10 @@ class ZoneFlowHandler(data_entry_flow.FlowHandler): """Initialize zone configuration flow.""" pass + async def async_step_user(self, user_input=None): + """Handle a flow initialized by the user.""" + return await self.async_step_init(user_input) + async def async_step_init(self, user_input=None): """Handle a flow start.""" errors = {} diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 12420e989ee..51114a2a416 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -24,20 +24,24 @@ Before instantiating the handler, Home Assistant will make sure to load all dependencies and install the requirements of the component. At a minimum, each config flow will have to define a version number and the -'init' step. +'user' step. @config_entries.HANDLERS.register(DOMAIN) - class ExampleConfigFlow(config_entries.FlowHandler): + class ExampleConfigFlow(data_entry_flow.FlowHandler): VERSION = 1 - async def async_step_init(self, user_input=None): + async def async_step_user(self, user_input=None): … -The 'init' step is the first step of a flow and is called when a user +The 'user' step is the first step of a flow and is called when a user starts a new flow. Each step has three different possible results: "Show Form", "Abort" and "Create Entry". +> Note: prior 0.76, the default step is 'init' step, some config flows still +keep 'init' step to avoid break localization. All new config flow should use +'user' step. + ### Show Form This will show a form to the user to fill in. You define the current step, @@ -50,7 +54,7 @@ a title, a description and the schema of the data that needs to be returned. data_schema[vol.Required('password')] = str return self.async_show_form( - step_id='init', + step_id='user', title='Account Info', data_schema=vol.Schema(data_schema) ) @@ -97,10 +101,10 @@ Assistant, a success message is shown to the user and the flow is finished. You might want to initialize a config flow programmatically. For example, if we discover a device on the network that requires user interaction to finish setup. To do so, pass a source parameter and optional user input to the init -step: +method: await hass.config_entries.flow.async_init( - 'hue', source='discovery', data=discovery_info) + 'hue', context={'source': 'discovery'}, data=discovery_info) The config flow handler will need to add a step to support the source. The step should follow the same return values as a normal step. @@ -123,6 +127,11 @@ from homeassistant.util.decorator import Registry _LOGGER = logging.getLogger(__name__) + +SOURCE_USER = 'user' +SOURCE_DISCOVERY = 'discovery' +SOURCE_IMPORT = 'import' + HANDLERS = Registry() # Components that have config flows. In future we will auto-generate this list. FLOWS = [ @@ -151,8 +160,8 @@ ENTRY_STATE_FAILED_UNLOAD = 'failed_unload' DISCOVERY_NOTIFICATION_ID = 'config_entry_discovery' DISCOVERY_SOURCES = ( - data_entry_flow.SOURCE_DISCOVERY, - data_entry_flow.SOURCE_IMPORT, + SOURCE_DISCOVERY, + SOURCE_IMPORT, ) EVENT_FLOW_DISCOVERED = 'config_entry_discovered' @@ -374,12 +383,15 @@ class ConfigEntries: if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: return None + source = result['source'] + if source is None: + source = SOURCE_USER entry = ConfigEntry( version=result['version'], domain=result['handler'], title=result['title'], data=result['data'], - source=result['source'], + source=source, ) self._entries.append(entry) await self._async_schedule_save() @@ -399,17 +411,22 @@ class ConfigEntries: return entry - async def _async_create_flow(self, handler, *, source, data): + async def _async_create_flow(self, handler_key, *, context, data): """Create a flow for specified handler. Handler key is the domain of the component that we want to setup. """ - component = getattr(self.hass.components, handler) - handler = HANDLERS.get(handler) + component = getattr(self.hass.components, handler_key) + handler = HANDLERS.get(handler_key) if handler is None: raise data_entry_flow.UnknownHandler + if context is not None: + source = context.get('source', SOURCE_USER) + else: + source = SOURCE_USER + # Make sure requirements and dependencies of component are resolved await async_process_deps_reqs( self.hass, self._hass_config, handler, component) @@ -424,7 +441,10 @@ class ConfigEntries: notification_id=DISCOVERY_NOTIFICATION_ID ) - return handler() + flow = handler() + flow.source = source + flow.init_step = source + return flow async def _async_schedule_save(self): """Save the entity registry to a file.""" diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index f010ada02f3..aee215dff80 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -8,10 +8,6 @@ from .exceptions import HomeAssistantError _LOGGER = logging.getLogger(__name__) -SOURCE_USER = 'user' -SOURCE_DISCOVERY = 'discovery' -SOURCE_IMPORT = 'import' - RESULT_TYPE_FORM = 'form' RESULT_TYPE_CREATE_ENTRY = 'create_entry' RESULT_TYPE_ABORT = 'abort' @@ -53,22 +49,17 @@ class FlowManager: 'source': flow.source, } for flow in self._progress.values()] - async def async_init(self, handler: Callable, *, source: str = SOURCE_USER, - data: str = None) -> Any: + async def async_init(self, handler: Callable, *, context: Dict = None, + data: Any = None) -> Any: """Start a configuration flow.""" - flow = await self._async_create_flow(handler, source=source, data=data) + flow = await self._async_create_flow( + handler, context=context, data=data) flow.hass = self.hass flow.handler = handler flow.flow_id = uuid.uuid4().hex - flow.source = source self._progress[flow.flow_id] = flow - if source == SOURCE_USER: - step = 'init' - else: - step = source - - return await self._async_handle_step(flow, step, data) + return await self._async_handle_step(flow, flow.init_step, data) async def async_configure( self, flow_id: str, user_input: str = None) -> Any: @@ -131,9 +122,12 @@ class FlowHandler: flow_id = None hass = None handler = None - source = SOURCE_USER + source = None cur_step = None + # Set by _async_create_flow callback + init_step = 'init' + # Set by developer VERSION = 1 diff --git a/homeassistant/helpers/config_entry_flow.py b/homeassistant/helpers/config_entry_flow.py index 6f51d9aca2c..e17d5071c6a 100644 --- a/homeassistant/helpers/config_entry_flow.py +++ b/homeassistant/helpers/config_entry_flow.py @@ -22,7 +22,7 @@ class DiscoveryFlowHandler(data_entry_flow.FlowHandler): self._title = title self._discovery_function = discovery_function - async def async_step_init(self, user_input=None): + async def async_step_user(self, user_input=None): """Handle a flow initialized by the user.""" if self._async_current_entries(): return self.async_abort( diff --git a/homeassistant/helpers/data_entry_flow.py b/homeassistant/helpers/data_entry_flow.py index 4f412eb58e7..378febf8f6d 100644 --- a/homeassistant/helpers/data_entry_flow.py +++ b/homeassistant/helpers/data_entry_flow.py @@ -2,7 +2,7 @@ import voluptuous as vol -from homeassistant import data_entry_flow +from homeassistant import data_entry_flow, config_entries from homeassistant.components.http import HomeAssistantView from homeassistant.components.http.data_validator import RequestDataValidator @@ -53,7 +53,8 @@ class FlowManagerIndexView(_BaseFlowManagerView): handler = data['handler'] try: - result = await self._flow_mgr.async_init(handler) + result = await self._flow_mgr.async_init( + handler, context={'source': config_entries.SOURCE_USER}) except data_entry_flow.UnknownHandler: return self.json_message('Invalid handler specified', 404) except data_entry_flow.UnknownStep: diff --git a/tests/common.py b/tests/common.py index 5567a431e58..3a2248d0d50 100644 --- a/tests/common.py +++ b/tests/common.py @@ -12,7 +12,7 @@ import logging import threading from contextlib import contextmanager -from homeassistant import auth, core as ha, data_entry_flow, config_entries +from homeassistant import auth, core as ha, config_entries from homeassistant.auth import ( models as auth_models, auth_store, providers as auth_providers) from homeassistant.setup import setup_component, async_setup_component @@ -509,7 +509,7 @@ class MockConfigEntry(config_entries.ConfigEntry): """Helper for creating config entries that adds some defaults.""" def __init__(self, *, domain='test', data=None, version=0, entry_id=None, - source=data_entry_flow.SOURCE_USER, title='Mock Title', + source=config_entries.SOURCE_USER, title='Mock Title', state=None): """Initialize a mock config entry.""" kwargs = { diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py index 82c747da01c..f85d7df1a86 100644 --- a/tests/components/config/test_config_entries.py +++ b/tests/components/config/test_config_entries.py @@ -102,13 +102,13 @@ def test_initialize_flow(hass, client): """Test we can initialize a flow.""" class TestFlow(FlowHandler): @asyncio.coroutine - def async_step_init(self, user_input=None): + def async_step_user(self, user_input=None): schema = OrderedDict() schema[vol.Required('username')] = str schema[vol.Required('password')] = str return self.async_show_form( - step_id='init', + step_id='user', data_schema=schema, description_placeholders={ 'url': 'https://example.com', @@ -130,7 +130,7 @@ def test_initialize_flow(hass, client): assert data == { 'type': 'form', 'handler': 'test', - 'step_id': 'init', + 'step_id': 'user', 'data_schema': [ { 'name': 'username', @@ -157,7 +157,7 @@ def test_abort(hass, client): """Test a flow that aborts.""" class TestFlow(FlowHandler): @asyncio.coroutine - def async_step_init(self, user_input=None): + def async_step_user(self, user_input=None): return self.async_abort(reason='bla') with patch.dict(HANDLERS, {'test': TestFlow}): @@ -185,7 +185,7 @@ def test_create_account(hass, client): VERSION = 1 @asyncio.coroutine - def async_step_init(self, user_input=None): + def async_step_user(self, user_input=None): return self.async_create_entry( title='Test Entry', data={'secret': 'account_token'} @@ -218,7 +218,7 @@ def test_two_step_flow(hass, client): VERSION = 1 @asyncio.coroutine - def async_step_init(self, user_input=None): + def async_step_user(self, user_input=None): return self.async_show_form( step_id='account', data_schema=vol.Schema({ @@ -286,7 +286,7 @@ def test_get_progress_index(hass, client): with patch.dict(HANDLERS, {'test': TestFlow}): form = yield from hass.config_entries.flow.async_init( - 'test', source='hassio') + 'test', context={'source': 'hassio'}) resp = yield from client.get('/api/config/config_entries/flow') assert resp.status == 200 @@ -305,13 +305,13 @@ def test_get_progress_flow(hass, client): """Test we can query the API for same result as we get from init a flow.""" class TestFlow(FlowHandler): @asyncio.coroutine - def async_step_init(self, user_input=None): + def async_step_user(self, user_input=None): schema = OrderedDict() schema[vol.Required('username')] = str schema[vol.Required('password')] = str return self.async_show_form( - step_id='init', + step_id='user', data_schema=schema, errors={ 'username': 'Should be unique.' diff --git a/tests/components/test_discovery.py b/tests/components/test_discovery.py index dd22c87cb18..8b997cb911c 100644 --- a/tests/components/test_discovery.py +++ b/tests/components/test_discovery.py @@ -5,7 +5,7 @@ from unittest.mock import patch, MagicMock import pytest -from homeassistant import data_entry_flow +from homeassistant import config_entries from homeassistant.bootstrap import async_setup_component from homeassistant.components import discovery from homeassistant.util.dt import utcnow @@ -175,5 +175,5 @@ async def test_discover_config_flow(hass): assert len(m_init.mock_calls) == 1 args, kwargs = m_init.mock_calls[0][1:] assert args == ('mock-component',) - assert kwargs['source'] == data_entry_flow.SOURCE_DISCOVERY + assert kwargs['context']['source'] == config_entries.SOURCE_DISCOVERY assert kwargs['data'] == discovery_info diff --git a/tests/helpers/test_config_entry_flow.py b/tests/helpers/test_config_entry_flow.py index 19185e165bc..46c58320d50 100644 --- a/tests/helpers/test_config_entry_flow.py +++ b/tests/helpers/test_config_entry_flow.py @@ -31,7 +31,7 @@ async def test_single_entry_allowed(hass, flow_conf): flow.hass = hass MockConfigEntry(domain='test').add_to_hass(hass) - result = await flow.async_step_init() + result = await flow.async_step_user() assert result['type'] == data_entry_flow.RESULT_TYPE_ABORT assert result['reason'] == 'single_instance_allowed' @@ -42,7 +42,7 @@ async def test_user_no_devices_found(hass, flow_conf): flow = config_entries.HANDLERS['test']() flow.hass = hass - result = await flow.async_step_init() + result = await flow.async_step_user() assert result['type'] == data_entry_flow.RESULT_TYPE_ABORT assert result['reason'] == 'no_devices_found' @@ -54,7 +54,7 @@ async def test_user_no_confirmation(hass, flow_conf): flow.hass = hass flow_conf['discovered'] = True - result = await flow.async_step_init() + result = await flow.async_step_user() assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY @@ -90,12 +90,12 @@ async def test_multiple_discoveries(hass, flow_conf): loader.set_component(hass, 'test', MockModule('test')) result = await hass.config_entries.flow.async_init( - 'test', source=data_entry_flow.SOURCE_DISCOVERY, data={}) + 'test', context={'source': config_entries.SOURCE_DISCOVERY}, data={}) assert result['type'] == data_entry_flow.RESULT_TYPE_FORM # Second discovery result = await hass.config_entries.flow.async_init( - 'test', source=data_entry_flow.SOURCE_DISCOVERY, data={}) + 'test', context={'source': config_entries.SOURCE_DISCOVERY}, data={}) assert result['type'] == data_entry_flow.RESULT_TYPE_ABORT @@ -105,7 +105,7 @@ async def test_user_init_trumps_discovery(hass, flow_conf): # Discovery starts flow result = await hass.config_entries.flow.async_init( - 'test', source=data_entry_flow.SOURCE_DISCOVERY, data={}) + 'test', context={'source': config_entries.SOURCE_DISCOVERY}, data={}) assert result['type'] == data_entry_flow.RESULT_TYPE_FORM # User starts flow diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index d7a7ec4b82b..8ac4c642b0a 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -108,7 +108,7 @@ def test_add_entry_calls_setup_entry(hass, manager): VERSION = 1 @asyncio.coroutine - def async_step_init(self, user_input=None): + def async_step_user(self, user_input=None): return self.async_create_entry( title='title', data={ @@ -162,7 +162,7 @@ async def test_saving_and_loading(hass): VERSION = 5 @asyncio.coroutine - def async_step_init(self, user_input=None): + def async_step_user(self, user_input=None): return self.async_create_entry( title='Test Title', data={ @@ -177,7 +177,7 @@ async def test_saving_and_loading(hass): VERSION = 3 @asyncio.coroutine - def async_step_init(self, user_input=None): + def async_step_user(self, user_input=None): return self.async_create_entry( title='Test 2 Title', data={ @@ -266,7 +266,7 @@ async def test_discovery_notification(hass): with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): result = await hass.config_entries.flow.async_init( - 'test', source=data_entry_flow.SOURCE_DISCOVERY) + 'test', context={'source': config_entries.SOURCE_DISCOVERY}) await hass.async_block_till_done() state = hass.states.get('persistent_notification.config_entry_discovery') @@ -294,7 +294,7 @@ async def test_discovery_notification_not_created(hass): with patch.dict(config_entries.HANDLERS, {'test': TestFlow}): await hass.config_entries.flow.async_init( - 'test', source=data_entry_flow.SOURCE_DISCOVERY) + 'test', context={'source': config_entries.SOURCE_DISCOVERY}) await hass.async_block_till_done() state = hass.states.get('persistent_notification.config_entry_discovery') diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 894fd4d7194..dc10f3d8d1a 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -12,13 +12,18 @@ def manager(): handlers = Registry() entries = [] - async def async_create_flow(handler_name, *, source, data): + async def async_create_flow(handler_name, *, context, data): handler = handlers.get(handler_name) if handler is None: raise data_entry_flow.UnknownHandler - return handler() + flow = handler() + flow.init_step = context.get('init_step', 'init') \ + if context is not None else 'init' + flow.source = context.get('source') \ + if context is not None else 'user_input' + return flow async def async_add_entry(result): if (result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY): @@ -57,12 +62,12 @@ async def test_configure_two_steps(manager): class TestFlow(data_entry_flow.FlowHandler): VERSION = 1 - async def async_step_init(self, user_input=None): + async def async_step_first(self, user_input=None): if user_input is not None: self.init_data = user_input return await self.async_step_second() return self.async_show_form( - step_id='init', + step_id='first', data_schema=vol.Schema([str]) ) @@ -77,7 +82,7 @@ async def test_configure_two_steps(manager): data_schema=vol.Schema([str]) ) - form = await manager.async_init('test') + form = await manager.async_init('test', context={'init_step': 'first'}) with pytest.raises(vol.Invalid): form = await manager.async_configure( @@ -163,7 +168,7 @@ async def test_create_saves_data(manager): assert entry['handler'] == 'test' assert entry['title'] == 'Test Title' assert entry['data'] == 'Test Data' - assert entry['source'] == data_entry_flow.SOURCE_USER + assert entry['source'] == 'user_input' async def test_discovery_init_flow(manager): @@ -172,7 +177,7 @@ async def test_discovery_init_flow(manager): class TestFlow(data_entry_flow.FlowHandler): VERSION = 5 - async def async_step_discovery(self, info): + async def async_step_init(self, info): return self.async_create_entry(title=info['id'], data=info) data = { @@ -181,7 +186,7 @@ async def test_discovery_init_flow(manager): } await manager.async_init( - 'test', source=data_entry_flow.SOURCE_DISCOVERY, data=data) + 'test', context={'source': 'discovery'}, data=data) assert len(manager.async_progress()) == 0 assert len(manager.mock_created_entries) == 1 @@ -190,4 +195,4 @@ async def test_discovery_init_flow(manager): assert entry['handler'] == 'test' assert entry['title'] == 'hello' assert entry['data'] == data - assert entry['source'] == data_entry_flow.SOURCE_DISCOVERY + assert entry['source'] == 'discovery'