Extract config flow to own module (#13840)

* Extract config flow to own module

* Lint

* fix lint

* fix typo

* ConfigFlowHandler -> FlowHandler

* Rename to data_entry_flow
This commit is contained in:
Paulus Schoutsen 2018-04-13 10:14:53 -04:00 committed by GitHub
parent ddd2003629
commit 60508f7215
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 428 additions and 388 deletions

View File

@ -3,7 +3,7 @@ import asyncio
import voluptuous as vol
from homeassistant import config_entries
from homeassistant import config_entries, data_entry_flow
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator
@ -24,7 +24,7 @@ def async_setup(hass):
def _prepare_json(result):
"""Convert result for JSON."""
if result['type'] != config_entries.RESULT_TYPE_FORM:
if result['type'] != data_entry_flow.RESULT_TYPE_FORM:
return result
import voluptuous_serialize
@ -94,8 +94,8 @@ class ConfigManagerFlowIndexView(HomeAssistantView):
hass = request.app['hass']
return self.json([
flow for flow in hass.config_entries.flow.async_progress()
if flow['source'] != config_entries.SOURCE_USER])
flw for flw in hass.config_entries.flow.async_progress()
if flw['source'] != data_entry_flow.SOURCE_USER])
@RequestDataValidator(vol.Schema({
vol.Required('domain'): str,
@ -108,9 +108,9 @@ class ConfigManagerFlowIndexView(HomeAssistantView):
try:
result = yield from hass.config_entries.flow.async_init(
data['domain'])
except config_entries.UnknownHandler:
except data_entry_flow.UnknownHandler:
return self.json_message('Invalid handler specified', 404)
except config_entries.UnknownStep:
except data_entry_flow.UnknownStep:
return self.json_message('Handler does not support init', 400)
result = _prepare_json(result)
@ -126,13 +126,13 @@ class ConfigManagerFlowResourceView(HomeAssistantView):
@asyncio.coroutine
def get(self, request, flow_id):
"""Get the current state of a flow."""
"""Get the current state of a data_entry_flow."""
hass = request.app['hass']
try:
result = yield from hass.config_entries.flow.async_configure(
flow_id)
except config_entries.UnknownFlow:
except data_entry_flow.UnknownFlow:
return self.json_message('Invalid flow specified', 404)
result = _prepare_json(result)
@ -148,7 +148,7 @@ class ConfigManagerFlowResourceView(HomeAssistantView):
try:
result = yield from hass.config_entries.flow.async_configure(
flow_id, data)
except config_entries.UnknownFlow:
except data_entry_flow.UnknownFlow:
return self.json_message('Invalid flow specified', 404)
except vol.Invalid:
return self.json_message('User input malformed', 400)
@ -164,7 +164,7 @@ class ConfigManagerFlowResourceView(HomeAssistantView):
try:
hass.config_entries.flow.async_abort(flow_id)
except config_entries.UnknownFlow:
except data_entry_flow.UnknownFlow:
return self.json_message('Invalid flow specified', 404)
return self.json_message('Flow aborted')

View File

@ -8,7 +8,7 @@ import logging
import voluptuous as vol
from homeassistant import config_entries
from homeassistant import config_entries, data_entry_flow
from homeassistant.components.discovery import SERVICE_DECONZ
from homeassistant.const import (
CONF_API_KEY, CONF_HOST, CONF_PORT, EVENT_HOMEASSISTANT_STOP)
@ -191,7 +191,7 @@ async def async_request_configuration(hass, config, deconz_config):
@config_entries.HANDLERS.register(DOMAIN)
class DeconzFlowHandler(config_entries.ConfigFlowHandler):
class DeconzFlowHandler(data_entry_flow.FlowHandler):
"""Handle a deCONZ config flow."""
VERSION = 1

View File

@ -13,7 +13,7 @@ import os
import voluptuous as vol
from homeassistant import config_entries
from homeassistant import data_entry_flow
from homeassistant.core import callback
from homeassistant.const import EVENT_HOMEASSISTANT_START
import homeassistant.helpers.config_validation as cv
@ -119,7 +119,7 @@ async def async_setup(hass, config):
if service in CONFIG_ENTRY_HANDLERS:
await hass.config_entries.flow.async_init(
CONFIG_ENTRY_HANDLERS[service],
source=config_entries.SOURCE_DISCOVERY,
source=data_entry_flow.SOURCE_DISCOVERY,
data=info
)
return

View File

@ -6,7 +6,7 @@ import os
import async_timeout
import voluptuous as vol
from homeassistant import config_entries
from homeassistant import config_entries, data_entry_flow
from homeassistant.core import callback
from homeassistant.helpers import aiohttp_client
@ -41,7 +41,7 @@ def _find_username_from_config(hass, filename):
@config_entries.HANDLERS.register(DOMAIN)
class HueFlowHandler(config_entries.ConfigFlowHandler):
class HueFlowHandler(data_entry_flow.FlowHandler):
"""Handle a Hue config flow."""
VERSION = 1

View File

@ -27,7 +27,7 @@ At a minimum, each config flow will have to define a version number and the
'init' step.
@config_entries.HANDLERS.register(DOMAIN)
class ExampleConfigFlow(config_entries.ConfigFlowHandler):
class ExampleConfigFlow(config_entries.FlowHandler):
VERSION = 1
@ -117,6 +117,7 @@ import uuid
from .core import callback
from .exceptions import HomeAssistantError
from .data_entry_flow import FlowManager
from .setup import async_setup_component, async_process_deps_reqs
from .util.json import load_json, save_json
from .util.decorator import Registry
@ -130,17 +131,11 @@ FLOWS = [
'hue',
]
SOURCE_USER = 'user'
SOURCE_DISCOVERY = 'discovery'
PATH_CONFIG = '.config_entries.json'
SAVE_DELAY = 1
RESULT_TYPE_FORM = 'form'
RESULT_TYPE_CREATE_ENTRY = 'create_entry'
RESULT_TYPE_ABORT = 'abort'
ENTRY_STATE_LOADED = 'loaded'
ENTRY_STATE_SETUP_ERROR = 'setup_error'
ENTRY_STATE_NOT_LOADED = 'not_loaded'
@ -251,18 +246,6 @@ class UnknownEntry(ConfigError):
"""Unknown entry specified."""
class UnknownHandler(ConfigError):
"""Unknown handler specified."""
class UnknownFlow(ConfigError):
"""Uknown flow specified."""
class UnknownStep(ConfigError):
"""Unknown step specified."""
class ConfigEntries:
"""Manage the configuration entries.
@ -272,7 +255,8 @@ class ConfigEntries:
def __init__(self, hass, hass_config):
"""Initialize the entry manager."""
self.hass = hass
self.flow = FlowManager(hass, hass_config, self._async_add_entry)
self.flow = FlowManager(hass, HANDLERS, self._async_missing_handler,
self._async_save_entry)
self._hass_config = hass_config
self._entries = None
self._sched_save = None
@ -357,8 +341,15 @@ class ConfigEntries:
await entry.async_unload(
self.hass, component=getattr(self.hass.components, component))
async def _async_add_entry(self, entry):
async def _async_save_entry(self, result):
"""Add an entry."""
entry = ConfigEntry(
version=result['version'],
domain=result['domain'],
title=result['title'],
data=result['data'],
source=result['source'],
)
self._entries.append(entry)
self._async_schedule_save()
@ -371,6 +362,18 @@ class ConfigEntries:
await async_setup_component(
self.hass, entry.domain, self._hass_config)
async def _async_missing_handler(self, domain):
"""Called when a flow handler is not loaded."""
# This will load the component and thus register the handler
component = getattr(self.hass.components, domain)
if domain not in HANDLERS:
return
# Make sure requirements and dependencies of component are resolved
await async_process_deps_reqs(
self.hass, self._hass_config, domain, component)
@callback
def _async_schedule_save(self):
"""Schedule saving the entity registry."""
@ -388,157 +391,3 @@ class ConfigEntries:
await self.hass.async_add_job(
save_json, self.hass.config.path(PATH_CONFIG), data)
class FlowManager:
"""Manage all the config flows that are in progress."""
def __init__(self, hass, hass_config, async_add_entry):
"""Initialize the flow manager."""
self.hass = hass
self._hass_config = hass_config
self._progress = {}
self._async_add_entry = async_add_entry
@callback
def async_progress(self):
"""Return the flows in progress."""
return [{
'flow_id': flow.flow_id,
'domain': flow.domain,
'source': flow.source,
} for flow in self._progress.values()]
async def async_init(self, domain, *, source=SOURCE_USER, data=None):
"""Start a configuration flow."""
handler = HANDLERS.get(domain)
if handler is None:
# This will load the component and thus register the handler
component = getattr(self.hass.components, domain)
handler = HANDLERS.get(domain)
if handler is None:
raise UnknownHandler
# Make sure requirements and dependencies of component are resolved
await async_process_deps_reqs(
self.hass, self._hass_config, domain, component)
flow_id = uuid.uuid4().hex
flow = self._progress[flow_id] = handler()
flow.hass = self.hass
flow.domain = domain
flow.flow_id = flow_id
flow.source = source
if source == SOURCE_USER:
step = 'init'
else:
step = source
return await self._async_handle_step(flow, step, data)
async def async_configure(self, flow_id, user_input=None):
"""Start or continue a configuration flow."""
flow = self._progress.get(flow_id)
if flow is None:
raise UnknownFlow
step_id, data_schema = flow.cur_step
if data_schema is not None and user_input is not None:
user_input = data_schema(user_input)
return await self._async_handle_step(
flow, step_id, user_input)
@callback
def async_abort(self, flow_id):
"""Abort a flow."""
if self._progress.pop(flow_id, None) is None:
raise UnknownFlow
async def _async_handle_step(self, flow, step_id, user_input):
"""Handle a step of a flow."""
method = "async_step_{}".format(step_id)
if not hasattr(flow, method):
self._progress.pop(flow.flow_id)
raise UnknownStep("Handler {} doesn't support step {}".format(
flow.__class__.__name__, step_id))
result = await getattr(flow, method)(user_input)
if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY,
RESULT_TYPE_ABORT):
raise ValueError(
'Handler returned incorrect type: {}'.format(result['type']))
if result['type'] == RESULT_TYPE_FORM:
flow.cur_step = (result['step_id'], result['data_schema'])
return result
# Abort and Success results both finish the flow
self._progress.pop(flow.flow_id)
if result['type'] == RESULT_TYPE_ABORT:
return result
entry = ConfigEntry(
version=flow.VERSION,
domain=flow.domain,
title=result['title'],
data=result.pop('data'),
source=flow.source
)
await self._async_add_entry(entry)
return result
class ConfigFlowHandler:
"""Handle the configuration flow of a component."""
# Set by flow manager
flow_id = None
hass = None
domain = None
source = SOURCE_USER
cur_step = None
# Set by dev
# VERSION
@callback
def async_show_form(self, *, step_id, data_schema=None, errors=None):
"""Return the definition of a form to gather user input."""
return {
'type': RESULT_TYPE_FORM,
'flow_id': self.flow_id,
'domain': self.domain,
'step_id': step_id,
'data_schema': data_schema,
'errors': errors,
}
@callback
def async_create_entry(self, *, title, data):
"""Finish config flow and create a config entry."""
return {
'type': RESULT_TYPE_CREATE_ENTRY,
'flow_id': self.flow_id,
'domain': self.domain,
'title': title,
'data': data,
}
@callback
def async_abort(self, *, reason):
"""Abort the config flow."""
return {
'type': RESULT_TYPE_ABORT,
'flow_id': self.flow_id,
'domain': self.domain,
'reason': reason
}

View File

@ -0,0 +1,180 @@
"""Classes to help gather user submissions."""
import logging
import uuid
from .core import callback
from .exceptions import HomeAssistantError
_LOGGER = logging.getLogger(__name__)
SOURCE_USER = 'user'
SOURCE_DISCOVERY = 'discovery'
RESULT_TYPE_FORM = 'form'
RESULT_TYPE_CREATE_ENTRY = 'create_entry'
RESULT_TYPE_ABORT = 'abort'
class FlowError(HomeAssistantError):
"""Error while configuring an account."""
class UnknownHandler(FlowError):
"""Unknown handler specified."""
class UnknownFlow(FlowError):
"""Uknown flow specified."""
class UnknownStep(FlowError):
"""Unknown step specified."""
class FlowManager:
"""Manage all the flows that are in progress."""
def __init__(self, hass, handlers, async_missing_handler,
async_save_entry):
"""Initialize the flow manager."""
self.hass = hass
self._handlers = handlers
self._progress = {}
self._async_missing_handler = async_missing_handler
self._async_save_entry = async_save_entry
@callback
def async_progress(self):
"""Return the flows in progress."""
return [{
'flow_id': flow.flow_id,
'domain': flow.domain,
'source': flow.source,
} for flow in self._progress.values()]
async def async_init(self, domain, *, source=SOURCE_USER, data=None):
"""Start a configuration flow."""
handler = self._handlers.get(domain)
if handler is None:
await self._async_missing_handler(domain)
handler = self._handlers.get(domain)
if handler is None:
raise UnknownHandler
flow_id = uuid.uuid4().hex
flow = self._progress[flow_id] = handler()
flow.hass = self.hass
flow.domain = domain
flow.flow_id = flow_id
flow.source = source
if source == SOURCE_USER:
step = 'init'
else:
step = source
return await self._async_handle_step(flow, step, data)
async def async_configure(self, flow_id, user_input=None):
"""Start or continue a configuration flow."""
flow = self._progress.get(flow_id)
if flow is None:
raise UnknownFlow
step_id, data_schema = flow.cur_step
if data_schema is not None and user_input is not None:
user_input = data_schema(user_input)
return await self._async_handle_step(
flow, step_id, user_input)
@callback
def async_abort(self, flow_id):
"""Abort a flow."""
if self._progress.pop(flow_id, None) is None:
raise UnknownFlow
async def _async_handle_step(self, flow, step_id, user_input):
"""Handle a step of a flow."""
method = "async_step_{}".format(step_id)
if not hasattr(flow, method):
self._progress.pop(flow.flow_id)
raise UnknownStep("Handler {} doesn't support step {}".format(
flow.__class__.__name__, step_id))
result = await getattr(flow, method)(user_input)
if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY,
RESULT_TYPE_ABORT):
raise ValueError(
'Handler returned incorrect type: {}'.format(result['type']))
if result['type'] == RESULT_TYPE_FORM:
flow.cur_step = (result['step_id'], result['data_schema'])
return result
# Abort and Success results both finish the flow
self._progress.pop(flow.flow_id)
if result['type'] == RESULT_TYPE_ABORT:
return result
# We pass a copy of the result because we're going to mutate our
# version afterwards and don't want to cause unexpected bugs.
await self._async_save_entry(dict(result))
result.pop('data')
return result
class FlowHandler:
"""Handle the configuration flow of a component."""
# Set by flow manager
flow_id = None
hass = None
domain = None
source = SOURCE_USER
cur_step = None
# Set by developer
VERSION = 1
@callback
def async_show_form(self, *, step_id, data_schema=None, errors=None):
"""Return the definition of a form to gather user input."""
return {
'type': RESULT_TYPE_FORM,
'flow_id': self.flow_id,
'domain': self.domain,
'step_id': step_id,
'data_schema': data_schema,
'errors': errors,
}
@callback
def async_create_entry(self, *, title, data):
"""Finish config flow and create a config entry."""
return {
'version': self.VERSION,
'type': RESULT_TYPE_CREATE_ENTRY,
'flow_id': self.flow_id,
'domain': self.domain,
'title': title,
'data': data,
'source': self.source,
}
@callback
def async_abort(self, *, reason):
"""Abort the config flow."""
return {
'type': RESULT_TYPE_ABORT,
'flow_id': self.flow_id,
'domain': self.domain,
'reason': reason
}

View File

@ -10,7 +10,7 @@ import logging
import threading
from contextlib import contextmanager
from homeassistant import core as ha, loader, config_entries
from homeassistant import core as ha, loader, data_entry_flow, config_entries
from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config
from homeassistant.helpers import (
@ -455,7 +455,7 @@ class MockConfigEntry(config_entries.ConfigEntry):
"""Helper for creating config entries that adds some defaults."""
def __init__(self, *, domain='test', data=None, version=0, entry_id=None,
source=config_entries.SOURCE_USER, title='Mock Title',
source=data_entry_flow.SOURCE_USER, title='Mock Title',
state=None):
"""Initialize a mock config entry."""
kwargs = {

View File

@ -8,7 +8,8 @@ import pytest
import voluptuous as vol
from homeassistant import config_entries as core_ce
from homeassistant.config_entries import ConfigFlowHandler, HANDLERS
from homeassistant.config_entries import HANDLERS
from homeassistant.data_entry_flow import FlowHandler
from homeassistant.setup import async_setup_component
from homeassistant.components.config import config_entries
from homeassistant.loader import set_component
@ -93,7 +94,7 @@ def test_available_flows(hass, client):
@asyncio.coroutine
def test_initialize_flow(hass, client):
"""Test we can initialize a flow."""
class TestFlow(ConfigFlowHandler):
class TestFlow(FlowHandler):
@asyncio.coroutine
def async_step_init(self, user_input=None):
schema = OrderedDict()
@ -142,7 +143,7 @@ def test_initialize_flow(hass, client):
@asyncio.coroutine
def test_abort(hass, client):
"""Test a flow that aborts."""
class TestFlow(ConfigFlowHandler):
class TestFlow(FlowHandler):
@asyncio.coroutine
def async_step_init(self, user_input=None):
return self.async_abort(reason='bla')
@ -167,7 +168,7 @@ def test_create_account(hass, client):
set_component(
'test', MockModule('test', async_setup_entry=mock_coro_func(True)))
class TestFlow(ConfigFlowHandler):
class TestFlow(FlowHandler):
VERSION = 1
@asyncio.coroutine
@ -187,7 +188,9 @@ def test_create_account(hass, client):
assert data == {
'domain': 'test',
'title': 'Test Entry',
'type': 'create_entry'
'type': 'create_entry',
'source': 'user',
'version': 1,
}
@ -197,7 +200,7 @@ def test_two_step_flow(hass, client):
set_component(
'test', MockModule('test', async_setup_entry=mock_coro_func(True)))
class TestFlow(ConfigFlowHandler):
class TestFlow(FlowHandler):
VERSION = 1
@asyncio.coroutine
@ -245,13 +248,15 @@ def test_two_step_flow(hass, client):
'domain': 'test',
'type': 'create_entry',
'title': 'user-title',
'version': 1,
'source': 'user',
}
@asyncio.coroutine
def test_get_progress_index(hass, client):
"""Test querying for the flows that are in progress."""
class TestFlow(ConfigFlowHandler):
class TestFlow(FlowHandler):
VERSION = 5
@asyncio.coroutine
@ -283,7 +288,7 @@ def test_get_progress_index(hass, client):
@asyncio.coroutine
def test_get_progress_flow(hass, client):
"""Test we can query the API for same result as we get from init a flow."""
class TestFlow(ConfigFlowHandler):
class TestFlow(FlowHandler):
@asyncio.coroutine
def async_step_init(self, user_input=None):
schema = OrderedDict()

View File

@ -5,7 +5,7 @@ from unittest.mock import patch, MagicMock
import pytest
from homeassistant import config_entries
from homeassistant import data_entry_flow
from homeassistant.bootstrap import async_setup_component
from homeassistant.components import discovery
from homeassistant.util.dt import utcnow
@ -174,5 +174,5 @@ async def test_discover_config_flow(hass):
assert len(m_init.mock_calls) == 1
args, kwargs = m_init.mock_calls[0][1:]
assert args == ('mock-component',)
assert kwargs['source'] == config_entries.SOURCE_DISCOVERY
assert kwargs['source'] == data_entry_flow.SOURCE_DISCOVERY
assert kwargs['data'] == discovery_info

View File

@ -3,9 +3,8 @@ import asyncio
from unittest.mock import MagicMock, patch, mock_open
import pytest
import voluptuous as vol
from homeassistant import config_entries, loader
from homeassistant import config_entries, loader, data_entry_flow
from homeassistant.setup import async_setup_component
from tests.common import MockModule, mock_coro, MockConfigEntry
@ -100,7 +99,7 @@ def test_add_entry_calls_setup_entry(hass, manager):
'comp',
MockModule('comp', async_setup_entry=mock_setup_entry))
class TestFlow(config_entries.ConfigFlowHandler):
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 1
@ -112,7 +111,7 @@ def test_add_entry_calls_setup_entry(hass, manager):
'token': 'supersecret'
})
with patch.dict(config_entries.HANDLERS, {'comp': TestFlow}):
with patch.dict(config_entries.HANDLERS, {'comp': TestFlow, 'beer': 5}):
yield from manager.flow.async_init('comp')
yield from hass.async_block_till_done()
@ -152,7 +151,7 @@ def test_domains_gets_uniques(manager):
@asyncio.coroutine
def test_saving_and_loading(hass):
"""Test that we're saving and loading correctly."""
class TestFlow(config_entries.ConfigFlowHandler):
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
@asyncio.coroutine
@ -167,7 +166,7 @@ def test_saving_and_loading(hass):
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
yield from hass.config_entries.flow.async_init('test')
class Test2Flow(config_entries.ConfigFlowHandler):
class Test2Flow(data_entry_flow.FlowHandler):
VERSION = 3
@asyncio.coroutine
@ -212,185 +211,6 @@ def test_saving_and_loading(hass):
assert orig.source == loaded.source
#######################
# FLOW MANAGER TESTS #
#######################
@asyncio.coroutine
def test_configure_reuses_handler_instance(manager):
"""Test that we reuse instances."""
class TestFlow(config_entries.ConfigFlowHandler):
handle_count = 0
@asyncio.coroutine
def async_step_init(self, user_input=None):
self.handle_count += 1
return self.async_show_form(
errors={'base': str(self.handle_count)},
step_id='init')
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
form = yield from manager.flow.async_init('test')
assert form['errors']['base'] == '1'
form = yield from manager.flow.async_configure(form['flow_id'])
assert form['errors']['base'] == '2'
assert len(manager.flow.async_progress()) == 1
assert len(manager.async_entries()) == 0
@asyncio.coroutine
def test_configure_two_steps(manager):
"""Test that we reuse instances."""
class TestFlow(config_entries.ConfigFlowHandler):
VERSION = 1
@asyncio.coroutine
def async_step_init(self, user_input=None):
if user_input is not None:
self.init_data = user_input
return self.async_step_second()
return self.async_show_form(
step_id='init',
data_schema=vol.Schema([str])
)
@asyncio.coroutine
def async_step_second(self, user_input=None):
if user_input is not None:
return self.async_create_entry(
title='Test Entry',
data=self.init_data + user_input
)
return self.async_show_form(
step_id='second',
data_schema=vol.Schema([str])
)
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
form = yield from manager.flow.async_init('test')
with pytest.raises(vol.Invalid):
form = yield from manager.flow.async_configure(
form['flow_id'], 'INCORRECT-DATA')
form = yield from manager.flow.async_configure(
form['flow_id'], ['INIT-DATA'])
form = yield from manager.flow.async_configure(
form['flow_id'], ['SECOND-DATA'])
assert form['type'] == config_entries.RESULT_TYPE_CREATE_ENTRY
assert len(manager.flow.async_progress()) == 0
assert len(manager.async_entries()) == 1
entry = manager.async_entries()[0]
assert entry.domain == 'test'
assert entry.data == ['INIT-DATA', 'SECOND-DATA']
@asyncio.coroutine
def test_show_form(manager):
"""Test that abort removes the flow from progress."""
schema = vol.Schema({
vol.Required('username'): str,
vol.Required('password'): str
})
class TestFlow(config_entries.ConfigFlowHandler):
@asyncio.coroutine
def async_step_init(self, user_input=None):
return self.async_show_form(
step_id='init',
data_schema=schema,
errors={
'username': 'Should be unique.'
}
)
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
form = yield from manager.flow.async_init('test')
assert form['type'] == 'form'
assert form['data_schema'] is schema
assert form['errors'] == {
'username': 'Should be unique.'
}
@asyncio.coroutine
def test_abort_removes_instance(manager):
"""Test that abort removes the flow from progress."""
class TestFlow(config_entries.ConfigFlowHandler):
is_new = True
@asyncio.coroutine
def async_step_init(self, user_input=None):
old = self.is_new
self.is_new = False
return self.async_abort(reason=str(old))
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
form = yield from manager.flow.async_init('test')
assert form['reason'] == 'True'
assert len(manager.flow.async_progress()) == 0
assert len(manager.async_entries()) == 0
form = yield from manager.flow.async_init('test')
assert form['reason'] == 'True'
assert len(manager.flow.async_progress()) == 0
assert len(manager.async_entries()) == 0
@asyncio.coroutine
def test_create_saves_data(manager):
"""Test creating a config entry."""
class TestFlow(config_entries.ConfigFlowHandler):
VERSION = 5
@asyncio.coroutine
def async_step_init(self, user_input=None):
return self.async_create_entry(
title='Test Title',
data='Test Data'
)
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
yield from manager.flow.async_init('test')
assert len(manager.flow.async_progress()) == 0
assert len(manager.async_entries()) == 1
entry = manager.async_entries()[0]
assert entry.version == 5
assert entry.domain == 'test'
assert entry.title == 'Test Title'
assert entry.data == 'Test Data'
assert entry.source == config_entries.SOURCE_USER
@asyncio.coroutine
def test_discovery_init_flow(manager):
"""Test a flow initialized by discovery."""
class TestFlow(config_entries.ConfigFlowHandler):
VERSION = 5
@asyncio.coroutine
def async_step_discovery(self, info):
return self.async_create_entry(title=info['id'], data=info)
data = {
'id': 'hello',
'token': 'secret'
}
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
yield from manager.flow.async_init(
'test', source=config_entries.SOURCE_DISCOVERY, data=data)
assert len(manager.flow.async_progress()) == 0
assert len(manager.async_entries()) == 1
entry = manager.async_entries()[0]
assert entry.version == 5
assert entry.domain == 'test'
assert entry.title == 'hello'
assert entry.data == data
assert entry.source == config_entries.SOURCE_DISCOVERY
async def test_forward_entry_sets_up_component(hass):
"""Test we setup the component entry is forwarded to."""
entry = MockConfigEntry(domain='original')

View File

@ -0,0 +1,186 @@
"""Test the flow classes."""
import pytest
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.util.decorator import Registry
from tests.common import mock_coro
@pytest.fixture
def manager():
"""Return a flow manager."""
handlers = Registry()
entries = []
async def async_add_entry(result):
entries.append(result)
manager = data_entry_flow.FlowManager(
None, handlers, mock_coro, async_add_entry)
manager.mock_created_entries = entries
manager.mock_reg_handler = handlers.register
return manager
async def test_configure_reuses_handler_instance(manager):
"""Test that we reuse instances."""
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
handle_count = 0
async def async_step_init(self, user_input=None):
self.handle_count += 1
return self.async_show_form(
errors={'base': str(self.handle_count)},
step_id='init')
form = await manager.async_init('test')
assert form['errors']['base'] == '1'
form = await manager.async_configure(form['flow_id'])
assert form['errors']['base'] == '2'
assert len(manager.async_progress()) == 1
assert len(manager.mock_created_entries) == 0
async def test_configure_two_steps(manager):
"""Test that we reuse instances."""
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 1
async def async_step_init(self, user_input=None):
if user_input is not None:
self.init_data = user_input
return await self.async_step_second()
return self.async_show_form(
step_id='init',
data_schema=vol.Schema([str])
)
async def async_step_second(self, user_input=None):
if user_input is not None:
return self.async_create_entry(
title='Test Entry',
data=self.init_data + user_input
)
return self.async_show_form(
step_id='second',
data_schema=vol.Schema([str])
)
form = await manager.async_init('test')
with pytest.raises(vol.Invalid):
form = await manager.async_configure(
form['flow_id'], 'INCORRECT-DATA')
form = await manager.async_configure(
form['flow_id'], ['INIT-DATA'])
form = await manager.async_configure(
form['flow_id'], ['SECOND-DATA'])
assert form['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 1
result = manager.mock_created_entries[0]
assert result['domain'] == 'test'
assert result['data'] == ['INIT-DATA', 'SECOND-DATA']
async def test_show_form(manager):
"""Test that abort removes the flow from progress."""
schema = vol.Schema({
vol.Required('username'): str,
vol.Required('password'): str
})
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_show_form(
step_id='init',
data_schema=schema,
errors={
'username': 'Should be unique.'
}
)
form = await manager.async_init('test')
assert form['type'] == 'form'
assert form['data_schema'] is schema
assert form['errors'] == {
'username': 'Should be unique.'
}
async def test_abort_removes_instance(manager):
"""Test that abort removes the flow from progress."""
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
is_new = True
async def async_step_init(self, user_input=None):
old = self.is_new
self.is_new = False
return self.async_abort(reason=str(old))
form = await manager.async_init('test')
assert form['reason'] == 'True'
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0
form = await manager.async_init('test')
assert form['reason'] == 'True'
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0
async def test_create_saves_data(manager):
"""Test creating a config entry."""
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_create_entry(
title='Test Title',
data='Test Data'
)
await manager.async_init('test')
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 1
entry = manager.mock_created_entries[0]
assert entry['version'] == 5
assert entry['domain'] == 'test'
assert entry['title'] == 'Test Title'
assert entry['data'] == 'Test Data'
assert entry['source'] == data_entry_flow.SOURCE_USER
async def test_discovery_init_flow(manager):
"""Test a flow initialized by discovery."""
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
async def async_step_discovery(self, info):
return self.async_create_entry(title=info['id'], data=info)
data = {
'id': 'hello',
'token': 'secret'
}
await manager.async_init(
'test', source=data_entry_flow.SOURCE_DISCOVERY, data=data)
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 1
entry = manager.mock_created_entries[0]
assert entry['version'] == 5
assert entry['domain'] == 'test'
assert entry['title'] == 'hello'
assert entry['data'] == data
assert entry['source'] == data_entry_flow.SOURCE_DISCOVERY