mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Extract config flow to own module (#13840)
* Extract config flow to own module * Lint * fix lint * fix typo * ConfigFlowHandler -> FlowHandler * Rename to data_entry_flow
This commit is contained in:
parent
ddd2003629
commit
60508f7215
@ -3,7 +3,7 @@ import asyncio
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
|
||||
@ -24,7 +24,7 @@ def async_setup(hass):
|
||||
|
||||
def _prepare_json(result):
|
||||
"""Convert result for JSON."""
|
||||
if result['type'] != config_entries.RESULT_TYPE_FORM:
|
||||
if result['type'] != data_entry_flow.RESULT_TYPE_FORM:
|
||||
return result
|
||||
|
||||
import voluptuous_serialize
|
||||
@ -94,8 +94,8 @@ class ConfigManagerFlowIndexView(HomeAssistantView):
|
||||
hass = request.app['hass']
|
||||
|
||||
return self.json([
|
||||
flow for flow in hass.config_entries.flow.async_progress()
|
||||
if flow['source'] != config_entries.SOURCE_USER])
|
||||
flw for flw in hass.config_entries.flow.async_progress()
|
||||
if flw['source'] != data_entry_flow.SOURCE_USER])
|
||||
|
||||
@RequestDataValidator(vol.Schema({
|
||||
vol.Required('domain'): str,
|
||||
@ -108,9 +108,9 @@ class ConfigManagerFlowIndexView(HomeAssistantView):
|
||||
try:
|
||||
result = yield from hass.config_entries.flow.async_init(
|
||||
data['domain'])
|
||||
except config_entries.UnknownHandler:
|
||||
except data_entry_flow.UnknownHandler:
|
||||
return self.json_message('Invalid handler specified', 404)
|
||||
except config_entries.UnknownStep:
|
||||
except data_entry_flow.UnknownStep:
|
||||
return self.json_message('Handler does not support init', 400)
|
||||
|
||||
result = _prepare_json(result)
|
||||
@ -126,13 +126,13 @@ class ConfigManagerFlowResourceView(HomeAssistantView):
|
||||
|
||||
@asyncio.coroutine
|
||||
def get(self, request, flow_id):
|
||||
"""Get the current state of a flow."""
|
||||
"""Get the current state of a data_entry_flow."""
|
||||
hass = request.app['hass']
|
||||
|
||||
try:
|
||||
result = yield from hass.config_entries.flow.async_configure(
|
||||
flow_id)
|
||||
except config_entries.UnknownFlow:
|
||||
except data_entry_flow.UnknownFlow:
|
||||
return self.json_message('Invalid flow specified', 404)
|
||||
|
||||
result = _prepare_json(result)
|
||||
@ -148,7 +148,7 @@ class ConfigManagerFlowResourceView(HomeAssistantView):
|
||||
try:
|
||||
result = yield from hass.config_entries.flow.async_configure(
|
||||
flow_id, data)
|
||||
except config_entries.UnknownFlow:
|
||||
except data_entry_flow.UnknownFlow:
|
||||
return self.json_message('Invalid flow specified', 404)
|
||||
except vol.Invalid:
|
||||
return self.json_message('User input malformed', 400)
|
||||
@ -164,7 +164,7 @@ class ConfigManagerFlowResourceView(HomeAssistantView):
|
||||
|
||||
try:
|
||||
hass.config_entries.flow.async_abort(flow_id)
|
||||
except config_entries.UnknownFlow:
|
||||
except data_entry_flow.UnknownFlow:
|
||||
return self.json_message('Invalid flow specified', 404)
|
||||
|
||||
return self.json_message('Flow aborted')
|
||||
|
@ -8,7 +8,7 @@ import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.components.discovery import SERVICE_DECONZ
|
||||
from homeassistant.const import (
|
||||
CONF_API_KEY, CONF_HOST, CONF_PORT, EVENT_HOMEASSISTANT_STOP)
|
||||
@ -191,7 +191,7 @@ async def async_request_configuration(hass, config, deconz_config):
|
||||
|
||||
|
||||
@config_entries.HANDLERS.register(DOMAIN)
|
||||
class DeconzFlowHandler(config_entries.ConfigFlowHandler):
|
||||
class DeconzFlowHandler(data_entry_flow.FlowHandler):
|
||||
"""Handle a deCONZ config flow."""
|
||||
|
||||
VERSION = 1
|
||||
|
@ -13,7 +13,7 @@ import os
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_START
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
@ -119,7 +119,7 @@ async def async_setup(hass, config):
|
||||
if service in CONFIG_ENTRY_HANDLERS:
|
||||
await hass.config_entries.flow.async_init(
|
||||
CONFIG_ENTRY_HANDLERS[service],
|
||||
source=config_entries.SOURCE_DISCOVERY,
|
||||
source=data_entry_flow.SOURCE_DISCOVERY,
|
||||
data=info
|
||||
)
|
||||
return
|
||||
|
@ -6,7 +6,7 @@ import os
|
||||
import async_timeout
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers import aiohttp_client
|
||||
|
||||
@ -41,7 +41,7 @@ def _find_username_from_config(hass, filename):
|
||||
|
||||
|
||||
@config_entries.HANDLERS.register(DOMAIN)
|
||||
class HueFlowHandler(config_entries.ConfigFlowHandler):
|
||||
class HueFlowHandler(data_entry_flow.FlowHandler):
|
||||
"""Handle a Hue config flow."""
|
||||
|
||||
VERSION = 1
|
||||
|
@ -27,7 +27,7 @@ At a minimum, each config flow will have to define a version number and the
|
||||
'init' step.
|
||||
|
||||
@config_entries.HANDLERS.register(DOMAIN)
|
||||
class ExampleConfigFlow(config_entries.ConfigFlowHandler):
|
||||
class ExampleConfigFlow(config_entries.FlowHandler):
|
||||
|
||||
VERSION = 1
|
||||
|
||||
@ -117,6 +117,7 @@ import uuid
|
||||
|
||||
from .core import callback
|
||||
from .exceptions import HomeAssistantError
|
||||
from .data_entry_flow import FlowManager
|
||||
from .setup import async_setup_component, async_process_deps_reqs
|
||||
from .util.json import load_json, save_json
|
||||
from .util.decorator import Registry
|
||||
@ -130,17 +131,11 @@ FLOWS = [
|
||||
'hue',
|
||||
]
|
||||
|
||||
SOURCE_USER = 'user'
|
||||
SOURCE_DISCOVERY = 'discovery'
|
||||
|
||||
PATH_CONFIG = '.config_entries.json'
|
||||
|
||||
SAVE_DELAY = 1
|
||||
|
||||
RESULT_TYPE_FORM = 'form'
|
||||
RESULT_TYPE_CREATE_ENTRY = 'create_entry'
|
||||
RESULT_TYPE_ABORT = 'abort'
|
||||
|
||||
ENTRY_STATE_LOADED = 'loaded'
|
||||
ENTRY_STATE_SETUP_ERROR = 'setup_error'
|
||||
ENTRY_STATE_NOT_LOADED = 'not_loaded'
|
||||
@ -251,18 +246,6 @@ class UnknownEntry(ConfigError):
|
||||
"""Unknown entry specified."""
|
||||
|
||||
|
||||
class UnknownHandler(ConfigError):
|
||||
"""Unknown handler specified."""
|
||||
|
||||
|
||||
class UnknownFlow(ConfigError):
|
||||
"""Uknown flow specified."""
|
||||
|
||||
|
||||
class UnknownStep(ConfigError):
|
||||
"""Unknown step specified."""
|
||||
|
||||
|
||||
class ConfigEntries:
|
||||
"""Manage the configuration entries.
|
||||
|
||||
@ -272,7 +255,8 @@ class ConfigEntries:
|
||||
def __init__(self, hass, hass_config):
|
||||
"""Initialize the entry manager."""
|
||||
self.hass = hass
|
||||
self.flow = FlowManager(hass, hass_config, self._async_add_entry)
|
||||
self.flow = FlowManager(hass, HANDLERS, self._async_missing_handler,
|
||||
self._async_save_entry)
|
||||
self._hass_config = hass_config
|
||||
self._entries = None
|
||||
self._sched_save = None
|
||||
@ -357,8 +341,15 @@ class ConfigEntries:
|
||||
await entry.async_unload(
|
||||
self.hass, component=getattr(self.hass.components, component))
|
||||
|
||||
async def _async_add_entry(self, entry):
|
||||
async def _async_save_entry(self, result):
|
||||
"""Add an entry."""
|
||||
entry = ConfigEntry(
|
||||
version=result['version'],
|
||||
domain=result['domain'],
|
||||
title=result['title'],
|
||||
data=result['data'],
|
||||
source=result['source'],
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._async_schedule_save()
|
||||
|
||||
@ -371,6 +362,18 @@ class ConfigEntries:
|
||||
await async_setup_component(
|
||||
self.hass, entry.domain, self._hass_config)
|
||||
|
||||
async def _async_missing_handler(self, domain):
|
||||
"""Called when a flow handler is not loaded."""
|
||||
# This will load the component and thus register the handler
|
||||
component = getattr(self.hass.components, domain)
|
||||
|
||||
if domain not in HANDLERS:
|
||||
return
|
||||
|
||||
# Make sure requirements and dependencies of component are resolved
|
||||
await async_process_deps_reqs(
|
||||
self.hass, self._hass_config, domain, component)
|
||||
|
||||
@callback
|
||||
def _async_schedule_save(self):
|
||||
"""Schedule saving the entity registry."""
|
||||
@ -388,157 +391,3 @@ class ConfigEntries:
|
||||
|
||||
await self.hass.async_add_job(
|
||||
save_json, self.hass.config.path(PATH_CONFIG), data)
|
||||
|
||||
|
||||
class FlowManager:
|
||||
"""Manage all the config flows that are in progress."""
|
||||
|
||||
def __init__(self, hass, hass_config, async_add_entry):
|
||||
"""Initialize the flow manager."""
|
||||
self.hass = hass
|
||||
self._hass_config = hass_config
|
||||
self._progress = {}
|
||||
self._async_add_entry = async_add_entry
|
||||
|
||||
@callback
|
||||
def async_progress(self):
|
||||
"""Return the flows in progress."""
|
||||
return [{
|
||||
'flow_id': flow.flow_id,
|
||||
'domain': flow.domain,
|
||||
'source': flow.source,
|
||||
} for flow in self._progress.values()]
|
||||
|
||||
async def async_init(self, domain, *, source=SOURCE_USER, data=None):
|
||||
"""Start a configuration flow."""
|
||||
handler = HANDLERS.get(domain)
|
||||
|
||||
if handler is None:
|
||||
# This will load the component and thus register the handler
|
||||
component = getattr(self.hass.components, domain)
|
||||
handler = HANDLERS.get(domain)
|
||||
|
||||
if handler is None:
|
||||
raise UnknownHandler
|
||||
|
||||
# Make sure requirements and dependencies of component are resolved
|
||||
await async_process_deps_reqs(
|
||||
self.hass, self._hass_config, domain, component)
|
||||
|
||||
flow_id = uuid.uuid4().hex
|
||||
flow = self._progress[flow_id] = handler()
|
||||
flow.hass = self.hass
|
||||
flow.domain = domain
|
||||
flow.flow_id = flow_id
|
||||
flow.source = source
|
||||
|
||||
if source == SOURCE_USER:
|
||||
step = 'init'
|
||||
else:
|
||||
step = source
|
||||
|
||||
return await self._async_handle_step(flow, step, data)
|
||||
|
||||
async def async_configure(self, flow_id, user_input=None):
|
||||
"""Start or continue a configuration flow."""
|
||||
flow = self._progress.get(flow_id)
|
||||
|
||||
if flow is None:
|
||||
raise UnknownFlow
|
||||
|
||||
step_id, data_schema = flow.cur_step
|
||||
|
||||
if data_schema is not None and user_input is not None:
|
||||
user_input = data_schema(user_input)
|
||||
|
||||
return await self._async_handle_step(
|
||||
flow, step_id, user_input)
|
||||
|
||||
@callback
|
||||
def async_abort(self, flow_id):
|
||||
"""Abort a flow."""
|
||||
if self._progress.pop(flow_id, None) is None:
|
||||
raise UnknownFlow
|
||||
|
||||
async def _async_handle_step(self, flow, step_id, user_input):
|
||||
"""Handle a step of a flow."""
|
||||
method = "async_step_{}".format(step_id)
|
||||
|
||||
if not hasattr(flow, method):
|
||||
self._progress.pop(flow.flow_id)
|
||||
raise UnknownStep("Handler {} doesn't support step {}".format(
|
||||
flow.__class__.__name__, step_id))
|
||||
|
||||
result = await getattr(flow, method)(user_input)
|
||||
|
||||
if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY,
|
||||
RESULT_TYPE_ABORT):
|
||||
raise ValueError(
|
||||
'Handler returned incorrect type: {}'.format(result['type']))
|
||||
|
||||
if result['type'] == RESULT_TYPE_FORM:
|
||||
flow.cur_step = (result['step_id'], result['data_schema'])
|
||||
return result
|
||||
|
||||
# Abort and Success results both finish the flow
|
||||
self._progress.pop(flow.flow_id)
|
||||
|
||||
if result['type'] == RESULT_TYPE_ABORT:
|
||||
return result
|
||||
|
||||
entry = ConfigEntry(
|
||||
version=flow.VERSION,
|
||||
domain=flow.domain,
|
||||
title=result['title'],
|
||||
data=result.pop('data'),
|
||||
source=flow.source
|
||||
)
|
||||
await self._async_add_entry(entry)
|
||||
return result
|
||||
|
||||
|
||||
class ConfigFlowHandler:
|
||||
"""Handle the configuration flow of a component."""
|
||||
|
||||
# Set by flow manager
|
||||
flow_id = None
|
||||
hass = None
|
||||
domain = None
|
||||
source = SOURCE_USER
|
||||
cur_step = None
|
||||
|
||||
# Set by dev
|
||||
# VERSION
|
||||
|
||||
@callback
|
||||
def async_show_form(self, *, step_id, data_schema=None, errors=None):
|
||||
"""Return the definition of a form to gather user input."""
|
||||
return {
|
||||
'type': RESULT_TYPE_FORM,
|
||||
'flow_id': self.flow_id,
|
||||
'domain': self.domain,
|
||||
'step_id': step_id,
|
||||
'data_schema': data_schema,
|
||||
'errors': errors,
|
||||
}
|
||||
|
||||
@callback
|
||||
def async_create_entry(self, *, title, data):
|
||||
"""Finish config flow and create a config entry."""
|
||||
return {
|
||||
'type': RESULT_TYPE_CREATE_ENTRY,
|
||||
'flow_id': self.flow_id,
|
||||
'domain': self.domain,
|
||||
'title': title,
|
||||
'data': data,
|
||||
}
|
||||
|
||||
@callback
|
||||
def async_abort(self, *, reason):
|
||||
"""Abort the config flow."""
|
||||
return {
|
||||
'type': RESULT_TYPE_ABORT,
|
||||
'flow_id': self.flow_id,
|
||||
'domain': self.domain,
|
||||
'reason': reason
|
||||
}
|
||||
|
180
homeassistant/data_entry_flow.py
Normal file
180
homeassistant/data_entry_flow.py
Normal file
@ -0,0 +1,180 @@
|
||||
"""Classes to help gather user submissions."""
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from .core import callback
|
||||
from .exceptions import HomeAssistantError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SOURCE_USER = 'user'
|
||||
SOURCE_DISCOVERY = 'discovery'
|
||||
|
||||
RESULT_TYPE_FORM = 'form'
|
||||
RESULT_TYPE_CREATE_ENTRY = 'create_entry'
|
||||
RESULT_TYPE_ABORT = 'abort'
|
||||
|
||||
|
||||
class FlowError(HomeAssistantError):
|
||||
"""Error while configuring an account."""
|
||||
|
||||
|
||||
class UnknownHandler(FlowError):
|
||||
"""Unknown handler specified."""
|
||||
|
||||
|
||||
class UnknownFlow(FlowError):
|
||||
"""Uknown flow specified."""
|
||||
|
||||
|
||||
class UnknownStep(FlowError):
|
||||
"""Unknown step specified."""
|
||||
|
||||
|
||||
class FlowManager:
|
||||
"""Manage all the flows that are in progress."""
|
||||
|
||||
def __init__(self, hass, handlers, async_missing_handler,
|
||||
async_save_entry):
|
||||
"""Initialize the flow manager."""
|
||||
self.hass = hass
|
||||
self._handlers = handlers
|
||||
self._progress = {}
|
||||
self._async_missing_handler = async_missing_handler
|
||||
self._async_save_entry = async_save_entry
|
||||
|
||||
@callback
|
||||
def async_progress(self):
|
||||
"""Return the flows in progress."""
|
||||
return [{
|
||||
'flow_id': flow.flow_id,
|
||||
'domain': flow.domain,
|
||||
'source': flow.source,
|
||||
} for flow in self._progress.values()]
|
||||
|
||||
async def async_init(self, domain, *, source=SOURCE_USER, data=None):
|
||||
"""Start a configuration flow."""
|
||||
handler = self._handlers.get(domain)
|
||||
|
||||
if handler is None:
|
||||
await self._async_missing_handler(domain)
|
||||
handler = self._handlers.get(domain)
|
||||
|
||||
if handler is None:
|
||||
raise UnknownHandler
|
||||
|
||||
flow_id = uuid.uuid4().hex
|
||||
flow = self._progress[flow_id] = handler()
|
||||
flow.hass = self.hass
|
||||
flow.domain = domain
|
||||
flow.flow_id = flow_id
|
||||
flow.source = source
|
||||
|
||||
if source == SOURCE_USER:
|
||||
step = 'init'
|
||||
else:
|
||||
step = source
|
||||
|
||||
return await self._async_handle_step(flow, step, data)
|
||||
|
||||
async def async_configure(self, flow_id, user_input=None):
|
||||
"""Start or continue a configuration flow."""
|
||||
flow = self._progress.get(flow_id)
|
||||
|
||||
if flow is None:
|
||||
raise UnknownFlow
|
||||
|
||||
step_id, data_schema = flow.cur_step
|
||||
|
||||
if data_schema is not None and user_input is not None:
|
||||
user_input = data_schema(user_input)
|
||||
|
||||
return await self._async_handle_step(
|
||||
flow, step_id, user_input)
|
||||
|
||||
@callback
|
||||
def async_abort(self, flow_id):
|
||||
"""Abort a flow."""
|
||||
if self._progress.pop(flow_id, None) is None:
|
||||
raise UnknownFlow
|
||||
|
||||
async def _async_handle_step(self, flow, step_id, user_input):
|
||||
"""Handle a step of a flow."""
|
||||
method = "async_step_{}".format(step_id)
|
||||
|
||||
if not hasattr(flow, method):
|
||||
self._progress.pop(flow.flow_id)
|
||||
raise UnknownStep("Handler {} doesn't support step {}".format(
|
||||
flow.__class__.__name__, step_id))
|
||||
|
||||
result = await getattr(flow, method)(user_input)
|
||||
|
||||
if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY,
|
||||
RESULT_TYPE_ABORT):
|
||||
raise ValueError(
|
||||
'Handler returned incorrect type: {}'.format(result['type']))
|
||||
|
||||
if result['type'] == RESULT_TYPE_FORM:
|
||||
flow.cur_step = (result['step_id'], result['data_schema'])
|
||||
return result
|
||||
|
||||
# Abort and Success results both finish the flow
|
||||
self._progress.pop(flow.flow_id)
|
||||
|
||||
if result['type'] == RESULT_TYPE_ABORT:
|
||||
return result
|
||||
|
||||
# We pass a copy of the result because we're going to mutate our
|
||||
# version afterwards and don't want to cause unexpected bugs.
|
||||
await self._async_save_entry(dict(result))
|
||||
result.pop('data')
|
||||
return result
|
||||
|
||||
|
||||
class FlowHandler:
|
||||
"""Handle the configuration flow of a component."""
|
||||
|
||||
# Set by flow manager
|
||||
flow_id = None
|
||||
hass = None
|
||||
domain = None
|
||||
source = SOURCE_USER
|
||||
cur_step = None
|
||||
|
||||
# Set by developer
|
||||
VERSION = 1
|
||||
|
||||
@callback
|
||||
def async_show_form(self, *, step_id, data_schema=None, errors=None):
|
||||
"""Return the definition of a form to gather user input."""
|
||||
return {
|
||||
'type': RESULT_TYPE_FORM,
|
||||
'flow_id': self.flow_id,
|
||||
'domain': self.domain,
|
||||
'step_id': step_id,
|
||||
'data_schema': data_schema,
|
||||
'errors': errors,
|
||||
}
|
||||
|
||||
@callback
|
||||
def async_create_entry(self, *, title, data):
|
||||
"""Finish config flow and create a config entry."""
|
||||
return {
|
||||
'version': self.VERSION,
|
||||
'type': RESULT_TYPE_CREATE_ENTRY,
|
||||
'flow_id': self.flow_id,
|
||||
'domain': self.domain,
|
||||
'title': title,
|
||||
'data': data,
|
||||
'source': self.source,
|
||||
}
|
||||
|
||||
@callback
|
||||
def async_abort(self, *, reason):
|
||||
"""Abort the config flow."""
|
||||
return {
|
||||
'type': RESULT_TYPE_ABORT,
|
||||
'flow_id': self.flow_id,
|
||||
'domain': self.domain,
|
||||
'reason': reason
|
||||
}
|
@ -10,7 +10,7 @@ import logging
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
|
||||
from homeassistant import core as ha, loader, config_entries
|
||||
from homeassistant import core as ha, loader, data_entry_flow, config_entries
|
||||
from homeassistant.setup import setup_component, async_setup_component
|
||||
from homeassistant.config import async_process_component_config
|
||||
from homeassistant.helpers import (
|
||||
@ -455,7 +455,7 @@ class MockConfigEntry(config_entries.ConfigEntry):
|
||||
"""Helper for creating config entries that adds some defaults."""
|
||||
|
||||
def __init__(self, *, domain='test', data=None, version=0, entry_id=None,
|
||||
source=config_entries.SOURCE_USER, title='Mock Title',
|
||||
source=data_entry_flow.SOURCE_USER, title='Mock Title',
|
||||
state=None):
|
||||
"""Initialize a mock config entry."""
|
||||
kwargs = {
|
||||
|
@ -8,7 +8,8 @@ import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries as core_ce
|
||||
from homeassistant.config_entries import ConfigFlowHandler, HANDLERS
|
||||
from homeassistant.config_entries import HANDLERS
|
||||
from homeassistant.data_entry_flow import FlowHandler
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components.config import config_entries
|
||||
from homeassistant.loader import set_component
|
||||
@ -93,7 +94,7 @@ def test_available_flows(hass, client):
|
||||
@asyncio.coroutine
|
||||
def test_initialize_flow(hass, client):
|
||||
"""Test we can initialize a flow."""
|
||||
class TestFlow(ConfigFlowHandler):
|
||||
class TestFlow(FlowHandler):
|
||||
@asyncio.coroutine
|
||||
def async_step_init(self, user_input=None):
|
||||
schema = OrderedDict()
|
||||
@ -142,7 +143,7 @@ def test_initialize_flow(hass, client):
|
||||
@asyncio.coroutine
|
||||
def test_abort(hass, client):
|
||||
"""Test a flow that aborts."""
|
||||
class TestFlow(ConfigFlowHandler):
|
||||
class TestFlow(FlowHandler):
|
||||
@asyncio.coroutine
|
||||
def async_step_init(self, user_input=None):
|
||||
return self.async_abort(reason='bla')
|
||||
@ -167,7 +168,7 @@ def test_create_account(hass, client):
|
||||
set_component(
|
||||
'test', MockModule('test', async_setup_entry=mock_coro_func(True)))
|
||||
|
||||
class TestFlow(ConfigFlowHandler):
|
||||
class TestFlow(FlowHandler):
|
||||
VERSION = 1
|
||||
|
||||
@asyncio.coroutine
|
||||
@ -187,7 +188,9 @@ def test_create_account(hass, client):
|
||||
assert data == {
|
||||
'domain': 'test',
|
||||
'title': 'Test Entry',
|
||||
'type': 'create_entry'
|
||||
'type': 'create_entry',
|
||||
'source': 'user',
|
||||
'version': 1,
|
||||
}
|
||||
|
||||
|
||||
@ -197,7 +200,7 @@ def test_two_step_flow(hass, client):
|
||||
set_component(
|
||||
'test', MockModule('test', async_setup_entry=mock_coro_func(True)))
|
||||
|
||||
class TestFlow(ConfigFlowHandler):
|
||||
class TestFlow(FlowHandler):
|
||||
VERSION = 1
|
||||
|
||||
@asyncio.coroutine
|
||||
@ -245,13 +248,15 @@ def test_two_step_flow(hass, client):
|
||||
'domain': 'test',
|
||||
'type': 'create_entry',
|
||||
'title': 'user-title',
|
||||
'version': 1,
|
||||
'source': 'user',
|
||||
}
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_progress_index(hass, client):
|
||||
"""Test querying for the flows that are in progress."""
|
||||
class TestFlow(ConfigFlowHandler):
|
||||
class TestFlow(FlowHandler):
|
||||
VERSION = 5
|
||||
|
||||
@asyncio.coroutine
|
||||
@ -283,7 +288,7 @@ def test_get_progress_index(hass, client):
|
||||
@asyncio.coroutine
|
||||
def test_get_progress_flow(hass, client):
|
||||
"""Test we can query the API for same result as we get from init a flow."""
|
||||
class TestFlow(ConfigFlowHandler):
|
||||
class TestFlow(FlowHandler):
|
||||
@asyncio.coroutine
|
||||
def async_step_init(self, user_input=None):
|
||||
schema = OrderedDict()
|
||||
|
@ -5,7 +5,7 @@ from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.bootstrap import async_setup_component
|
||||
from homeassistant.components import discovery
|
||||
from homeassistant.util.dt import utcnow
|
||||
@ -174,5 +174,5 @@ async def test_discover_config_flow(hass):
|
||||
assert len(m_init.mock_calls) == 1
|
||||
args, kwargs = m_init.mock_calls[0][1:]
|
||||
assert args == ('mock-component',)
|
||||
assert kwargs['source'] == config_entries.SOURCE_DISCOVERY
|
||||
assert kwargs['source'] == data_entry_flow.SOURCE_DISCOVERY
|
||||
assert kwargs['data'] == discovery_info
|
||||
|
@ -3,9 +3,8 @@ import asyncio
|
||||
from unittest.mock import MagicMock, patch, mock_open
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries, loader
|
||||
from homeassistant import config_entries, loader, data_entry_flow
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockModule, mock_coro, MockConfigEntry
|
||||
@ -100,7 +99,7 @@ def test_add_entry_calls_setup_entry(hass, manager):
|
||||
'comp',
|
||||
MockModule('comp', async_setup_entry=mock_setup_entry))
|
||||
|
||||
class TestFlow(config_entries.ConfigFlowHandler):
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
|
||||
VERSION = 1
|
||||
|
||||
@ -112,7 +111,7 @@ def test_add_entry_calls_setup_entry(hass, manager):
|
||||
'token': 'supersecret'
|
||||
})
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {'comp': TestFlow}):
|
||||
with patch.dict(config_entries.HANDLERS, {'comp': TestFlow, 'beer': 5}):
|
||||
yield from manager.flow.async_init('comp')
|
||||
yield from hass.async_block_till_done()
|
||||
|
||||
@ -152,7 +151,7 @@ def test_domains_gets_uniques(manager):
|
||||
@asyncio.coroutine
|
||||
def test_saving_and_loading(hass):
|
||||
"""Test that we're saving and loading correctly."""
|
||||
class TestFlow(config_entries.ConfigFlowHandler):
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
VERSION = 5
|
||||
|
||||
@asyncio.coroutine
|
||||
@ -167,7 +166,7 @@ def test_saving_and_loading(hass):
|
||||
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
|
||||
yield from hass.config_entries.flow.async_init('test')
|
||||
|
||||
class Test2Flow(config_entries.ConfigFlowHandler):
|
||||
class Test2Flow(data_entry_flow.FlowHandler):
|
||||
VERSION = 3
|
||||
|
||||
@asyncio.coroutine
|
||||
@ -212,185 +211,6 @@ def test_saving_and_loading(hass):
|
||||
assert orig.source == loaded.source
|
||||
|
||||
|
||||
#######################
|
||||
# FLOW MANAGER TESTS #
|
||||
#######################
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_configure_reuses_handler_instance(manager):
|
||||
"""Test that we reuse instances."""
|
||||
class TestFlow(config_entries.ConfigFlowHandler):
|
||||
handle_count = 0
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_step_init(self, user_input=None):
|
||||
self.handle_count += 1
|
||||
return self.async_show_form(
|
||||
errors={'base': str(self.handle_count)},
|
||||
step_id='init')
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
|
||||
form = yield from manager.flow.async_init('test')
|
||||
assert form['errors']['base'] == '1'
|
||||
form = yield from manager.flow.async_configure(form['flow_id'])
|
||||
assert form['errors']['base'] == '2'
|
||||
assert len(manager.flow.async_progress()) == 1
|
||||
assert len(manager.async_entries()) == 0
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_configure_two_steps(manager):
|
||||
"""Test that we reuse instances."""
|
||||
class TestFlow(config_entries.ConfigFlowHandler):
|
||||
VERSION = 1
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_step_init(self, user_input=None):
|
||||
if user_input is not None:
|
||||
self.init_data = user_input
|
||||
return self.async_step_second()
|
||||
return self.async_show_form(
|
||||
step_id='init',
|
||||
data_schema=vol.Schema([str])
|
||||
)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_step_second(self, user_input=None):
|
||||
if user_input is not None:
|
||||
return self.async_create_entry(
|
||||
title='Test Entry',
|
||||
data=self.init_data + user_input
|
||||
)
|
||||
return self.async_show_form(
|
||||
step_id='second',
|
||||
data_schema=vol.Schema([str])
|
||||
)
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
|
||||
form = yield from manager.flow.async_init('test')
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
form = yield from manager.flow.async_configure(
|
||||
form['flow_id'], 'INCORRECT-DATA')
|
||||
|
||||
form = yield from manager.flow.async_configure(
|
||||
form['flow_id'], ['INIT-DATA'])
|
||||
form = yield from manager.flow.async_configure(
|
||||
form['flow_id'], ['SECOND-DATA'])
|
||||
assert form['type'] == config_entries.RESULT_TYPE_CREATE_ENTRY
|
||||
assert len(manager.flow.async_progress()) == 0
|
||||
assert len(manager.async_entries()) == 1
|
||||
entry = manager.async_entries()[0]
|
||||
assert entry.domain == 'test'
|
||||
assert entry.data == ['INIT-DATA', 'SECOND-DATA']
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_show_form(manager):
|
||||
"""Test that abort removes the flow from progress."""
|
||||
schema = vol.Schema({
|
||||
vol.Required('username'): str,
|
||||
vol.Required('password'): str
|
||||
})
|
||||
|
||||
class TestFlow(config_entries.ConfigFlowHandler):
|
||||
@asyncio.coroutine
|
||||
def async_step_init(self, user_input=None):
|
||||
return self.async_show_form(
|
||||
step_id='init',
|
||||
data_schema=schema,
|
||||
errors={
|
||||
'username': 'Should be unique.'
|
||||
}
|
||||
)
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
|
||||
form = yield from manager.flow.async_init('test')
|
||||
assert form['type'] == 'form'
|
||||
assert form['data_schema'] is schema
|
||||
assert form['errors'] == {
|
||||
'username': 'Should be unique.'
|
||||
}
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_abort_removes_instance(manager):
|
||||
"""Test that abort removes the flow from progress."""
|
||||
class TestFlow(config_entries.ConfigFlowHandler):
|
||||
is_new = True
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_step_init(self, user_input=None):
|
||||
old = self.is_new
|
||||
self.is_new = False
|
||||
return self.async_abort(reason=str(old))
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
|
||||
form = yield from manager.flow.async_init('test')
|
||||
assert form['reason'] == 'True'
|
||||
assert len(manager.flow.async_progress()) == 0
|
||||
assert len(manager.async_entries()) == 0
|
||||
form = yield from manager.flow.async_init('test')
|
||||
assert form['reason'] == 'True'
|
||||
assert len(manager.flow.async_progress()) == 0
|
||||
assert len(manager.async_entries()) == 0
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_create_saves_data(manager):
|
||||
"""Test creating a config entry."""
|
||||
class TestFlow(config_entries.ConfigFlowHandler):
|
||||
VERSION = 5
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_step_init(self, user_input=None):
|
||||
return self.async_create_entry(
|
||||
title='Test Title',
|
||||
data='Test Data'
|
||||
)
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
|
||||
yield from manager.flow.async_init('test')
|
||||
assert len(manager.flow.async_progress()) == 0
|
||||
assert len(manager.async_entries()) == 1
|
||||
|
||||
entry = manager.async_entries()[0]
|
||||
assert entry.version == 5
|
||||
assert entry.domain == 'test'
|
||||
assert entry.title == 'Test Title'
|
||||
assert entry.data == 'Test Data'
|
||||
assert entry.source == config_entries.SOURCE_USER
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_discovery_init_flow(manager):
|
||||
"""Test a flow initialized by discovery."""
|
||||
class TestFlow(config_entries.ConfigFlowHandler):
|
||||
VERSION = 5
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_step_discovery(self, info):
|
||||
return self.async_create_entry(title=info['id'], data=info)
|
||||
|
||||
data = {
|
||||
'id': 'hello',
|
||||
'token': 'secret'
|
||||
}
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
|
||||
yield from manager.flow.async_init(
|
||||
'test', source=config_entries.SOURCE_DISCOVERY, data=data)
|
||||
assert len(manager.flow.async_progress()) == 0
|
||||
assert len(manager.async_entries()) == 1
|
||||
|
||||
entry = manager.async_entries()[0]
|
||||
assert entry.version == 5
|
||||
assert entry.domain == 'test'
|
||||
assert entry.title == 'hello'
|
||||
assert entry.data == data
|
||||
assert entry.source == config_entries.SOURCE_DISCOVERY
|
||||
|
||||
|
||||
async def test_forward_entry_sets_up_component(hass):
|
||||
"""Test we setup the component entry is forwarded to."""
|
||||
entry = MockConfigEntry(domain='original')
|
||||
|
186
tests/test_data_entry_flow.py
Normal file
186
tests/test_data_entry_flow.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""Test the flow classes."""
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
from tests.common import mock_coro
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager():
|
||||
"""Return a flow manager."""
|
||||
handlers = Registry()
|
||||
entries = []
|
||||
|
||||
async def async_add_entry(result):
|
||||
entries.append(result)
|
||||
|
||||
manager = data_entry_flow.FlowManager(
|
||||
None, handlers, mock_coro, async_add_entry)
|
||||
manager.mock_created_entries = entries
|
||||
manager.mock_reg_handler = handlers.register
|
||||
return manager
|
||||
|
||||
|
||||
async def test_configure_reuses_handler_instance(manager):
|
||||
"""Test that we reuse instances."""
|
||||
@manager.mock_reg_handler('test')
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
handle_count = 0
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
self.handle_count += 1
|
||||
return self.async_show_form(
|
||||
errors={'base': str(self.handle_count)},
|
||||
step_id='init')
|
||||
|
||||
form = await manager.async_init('test')
|
||||
assert form['errors']['base'] == '1'
|
||||
form = await manager.async_configure(form['flow_id'])
|
||||
assert form['errors']['base'] == '2'
|
||||
assert len(manager.async_progress()) == 1
|
||||
assert len(manager.mock_created_entries) == 0
|
||||
|
||||
|
||||
async def test_configure_two_steps(manager):
|
||||
"""Test that we reuse instances."""
|
||||
@manager.mock_reg_handler('test')
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
if user_input is not None:
|
||||
self.init_data = user_input
|
||||
return await self.async_step_second()
|
||||
return self.async_show_form(
|
||||
step_id='init',
|
||||
data_schema=vol.Schema([str])
|
||||
)
|
||||
|
||||
async def async_step_second(self, user_input=None):
|
||||
if user_input is not None:
|
||||
return self.async_create_entry(
|
||||
title='Test Entry',
|
||||
data=self.init_data + user_input
|
||||
)
|
||||
return self.async_show_form(
|
||||
step_id='second',
|
||||
data_schema=vol.Schema([str])
|
||||
)
|
||||
|
||||
form = await manager.async_init('test')
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
form = await manager.async_configure(
|
||||
form['flow_id'], 'INCORRECT-DATA')
|
||||
|
||||
form = await manager.async_configure(
|
||||
form['flow_id'], ['INIT-DATA'])
|
||||
form = await manager.async_configure(
|
||||
form['flow_id'], ['SECOND-DATA'])
|
||||
assert form['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert len(manager.async_progress()) == 0
|
||||
assert len(manager.mock_created_entries) == 1
|
||||
result = manager.mock_created_entries[0]
|
||||
assert result['domain'] == 'test'
|
||||
assert result['data'] == ['INIT-DATA', 'SECOND-DATA']
|
||||
|
||||
|
||||
async def test_show_form(manager):
|
||||
"""Test that abort removes the flow from progress."""
|
||||
schema = vol.Schema({
|
||||
vol.Required('username'): str,
|
||||
vol.Required('password'): str
|
||||
})
|
||||
|
||||
@manager.mock_reg_handler('test')
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
async def async_step_init(self, user_input=None):
|
||||
return self.async_show_form(
|
||||
step_id='init',
|
||||
data_schema=schema,
|
||||
errors={
|
||||
'username': 'Should be unique.'
|
||||
}
|
||||
)
|
||||
|
||||
form = await manager.async_init('test')
|
||||
assert form['type'] == 'form'
|
||||
assert form['data_schema'] is schema
|
||||
assert form['errors'] == {
|
||||
'username': 'Should be unique.'
|
||||
}
|
||||
|
||||
|
||||
async def test_abort_removes_instance(manager):
|
||||
"""Test that abort removes the flow from progress."""
|
||||
@manager.mock_reg_handler('test')
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
is_new = True
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
old = self.is_new
|
||||
self.is_new = False
|
||||
return self.async_abort(reason=str(old))
|
||||
|
||||
form = await manager.async_init('test')
|
||||
assert form['reason'] == 'True'
|
||||
assert len(manager.async_progress()) == 0
|
||||
assert len(manager.mock_created_entries) == 0
|
||||
form = await manager.async_init('test')
|
||||
assert form['reason'] == 'True'
|
||||
assert len(manager.async_progress()) == 0
|
||||
assert len(manager.mock_created_entries) == 0
|
||||
|
||||
|
||||
async def test_create_saves_data(manager):
|
||||
"""Test creating a config entry."""
|
||||
@manager.mock_reg_handler('test')
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
VERSION = 5
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
return self.async_create_entry(
|
||||
title='Test Title',
|
||||
data='Test Data'
|
||||
)
|
||||
|
||||
await manager.async_init('test')
|
||||
assert len(manager.async_progress()) == 0
|
||||
assert len(manager.mock_created_entries) == 1
|
||||
|
||||
entry = manager.mock_created_entries[0]
|
||||
assert entry['version'] == 5
|
||||
assert entry['domain'] == 'test'
|
||||
assert entry['title'] == 'Test Title'
|
||||
assert entry['data'] == 'Test Data'
|
||||
assert entry['source'] == data_entry_flow.SOURCE_USER
|
||||
|
||||
|
||||
async def test_discovery_init_flow(manager):
|
||||
"""Test a flow initialized by discovery."""
|
||||
@manager.mock_reg_handler('test')
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
VERSION = 5
|
||||
|
||||
async def async_step_discovery(self, info):
|
||||
return self.async_create_entry(title=info['id'], data=info)
|
||||
|
||||
data = {
|
||||
'id': 'hello',
|
||||
'token': 'secret'
|
||||
}
|
||||
|
||||
await manager.async_init(
|
||||
'test', source=data_entry_flow.SOURCE_DISCOVERY, data=data)
|
||||
assert len(manager.async_progress()) == 0
|
||||
assert len(manager.mock_created_entries) == 1
|
||||
|
||||
entry = manager.mock_created_entries[0]
|
||||
assert entry['version'] == 5
|
||||
assert entry['domain'] == 'test'
|
||||
assert entry['title'] == 'hello'
|
||||
assert entry['data'] == data
|
||||
assert entry['source'] == data_entry_flow.SOURCE_DISCOVERY
|
Loading…
x
Reference in New Issue
Block a user