Rework FlowManager to use inheritance (#30133)

* Pull async_finish_flow/async_create_flow out of ConfigEntries

* Towards refactoring

* mypy fixes

* Mark Flow manager with abc.* annotations

* Flake8 fixes

* Mypy fixes

* Blacken data_entry_flow

* Blacken longer signatures caused by mypy changes

* test fixes

* Test fixes

* Fix typo

* Avoid protected member lint (W0212) in config_entries

* More protected member fixes

* Missing await
This commit is contained in:
Jc2k 2020-01-03 10:52:01 +00:00 committed by Paulus Schoutsen
parent 0a4f3ec1ec
commit fdfedd086b
12 changed files with 313 additions and 258 deletions

View File

@ -67,6 +67,69 @@ async def auth_manager_from_config(
return manager return manager
class AuthManagerFlowManager(data_entry_flow.FlowManager):
"""Manage authentication flows."""
def __init__(self, hass: HomeAssistant, auth_manager: "AuthManager"):
"""Init auth manager flows."""
super().__init__(hass)
self.auth_manager = auth_manager
async def async_create_flow(
self,
handler_key: Any,
*,
context: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> data_entry_flow.FlowHandler:
"""Create a login flow."""
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
if not auth_provider:
raise KeyError(f"Unknown auth provider {handler_key}")
return await auth_provider.async_login_flow(context)
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
) -> Dict[str, Any]:
"""Return a user as result of login flow."""
flow = cast(LoginFlow, flow)
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return result
# we got final result
if isinstance(result["data"], models.User):
result["result"] = result["data"]
return result
auth_provider = self.auth_manager.get_auth_provider(*result["handler"])
if not auth_provider:
raise KeyError(f"Unknown auth provider {result['handler']}")
credentials = await auth_provider.async_get_or_create_credentials(
result["data"]
)
if flow.context.get("credential_only"):
result["result"] = credentials
return result
# multi-factor module cannot enabled for new credential
# which has not linked to a user yet
if auth_provider.support_mfa and not credentials.is_new:
user = await self.auth_manager.async_get_user_by_credentials(credentials)
if user is not None:
modules = await self.auth_manager.async_get_enabled_mfa(user)
if modules:
flow.user = user
flow.available_mfa_modules = modules
return await flow.async_step_select_mfa_module()
result["result"] = await self.auth_manager.async_get_or_create_user(credentials)
return result
class AuthManager: class AuthManager:
"""Manage the authentication for Home Assistant.""" """Manage the authentication for Home Assistant."""
@ -82,9 +145,7 @@ class AuthManager:
self._store = store self._store = store
self._providers = providers self._providers = providers
self._mfa_modules = mfa_modules self._mfa_modules = mfa_modules
self.login_flow = data_entry_flow.FlowManager( self.login_flow = AuthManagerFlowManager(hass, self)
hass, self._async_create_login_flow, self._async_finish_login_flow
)
@property @property
def auth_providers(self) -> List[AuthProvider]: def auth_providers(self) -> List[AuthProvider]:
@ -417,50 +478,6 @@ class AuthManager:
return refresh_token return refresh_token
async def _async_create_login_flow(
self, handler: _ProviderKey, *, context: Optional[Dict], data: Optional[Any]
) -> data_entry_flow.FlowHandler:
"""Create a login flow."""
auth_provider = self._providers[handler]
return await auth_provider.async_login_flow(context)
async def _async_finish_login_flow(
self, flow: LoginFlow, result: Dict[str, Any]
) -> Dict[str, Any]:
"""Return a user as result of login flow."""
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return result
# we got final result
if isinstance(result["data"], models.User):
result["result"] = result["data"]
return result
auth_provider = self._providers[result["handler"]]
credentials = await auth_provider.async_get_or_create_credentials(
result["data"]
)
if flow.context.get("credential_only"):
result["result"] = credentials
return result
# multi-factor module cannot enabled for new credential
# which has not linked to a user yet
if auth_provider.support_mfa and not credentials.is_new:
user = await self.async_get_user_by_credentials(credentials)
if user is not None:
modules = await self.async_get_enabled_mfa(user)
if modules:
flow.user = user
flow.available_mfa_modules = modules
return await flow.async_step_select_mfa_module()
result["result"] = await self.async_get_or_create_user(credentials)
return result
@callback @callback
def _async_get_auth_provider( def _async_get_auth_provider(
self, credentials: models.Credentials self, credentials: models.Credentials

View File

@ -28,25 +28,27 @@ DATA_SETUP_FLOW_MGR = "auth_mfa_setup_flow_manager"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_setup(hass): class MfaFlowManager(data_entry_flow.FlowManager):
"""Init mfa setup flow manager.""" """Manage multi factor authentication flows."""
async def _async_create_setup_flow(handler, context, data): async def async_create_flow(self, handler_key, *, context, data):
"""Create a setup flow. handler is a mfa module.""" """Create a setup flow. handler is a mfa module."""
mfa_module = hass.auth.get_auth_mfa_module(handler) mfa_module = self.hass.auth.get_auth_mfa_module(handler_key)
if mfa_module is None: if mfa_module is None:
raise ValueError(f"Mfa module {handler} is not found") raise ValueError(f"Mfa module {handler_key} is not found")
user_id = data.pop("user_id") user_id = data.pop("user_id")
return await mfa_module.async_setup_flow(user_id) return await mfa_module.async_setup_flow(user_id)
async def _async_finish_setup_flow(flow, flow_result): async def async_finish_flow(self, flow, result):
_LOGGER.debug("flow_result: %s", flow_result) """Complete an mfs setup flow."""
return flow_result _LOGGER.debug("flow_result: %s", result)
return result
hass.data[DATA_SETUP_FLOW_MGR] = data_entry_flow.FlowManager(
hass, _async_create_setup_flow, _async_finish_setup_flow async def async_setup(hass):
) """Init mfa setup flow manager."""
hass.data[DATA_SETUP_FLOW_MGR] = MfaFlowManager(hass)
hass.components.websocket_api.async_register_command( hass.components.websocket_api.async_register_command(
WS_TYPE_SETUP_MFA, websocket_setup_mfa, SCHEMA_WS_SETUP_MFA WS_TYPE_SETUP_MFA, websocket_setup_mfa, SCHEMA_WS_SETUP_MFA

View File

@ -23,12 +23,8 @@ async def async_setup(hass):
hass.http.register_view(ConfigManagerFlowResourceView(hass.config_entries.flow)) hass.http.register_view(ConfigManagerFlowResourceView(hass.config_entries.flow))
hass.http.register_view(ConfigManagerAvailableFlowView) hass.http.register_view(ConfigManagerAvailableFlowView)
hass.http.register_view( hass.http.register_view(OptionManagerFlowIndexView(hass.config_entries.options))
OptionManagerFlowIndexView(hass.config_entries.options.flow) hass.http.register_view(OptionManagerFlowResourceView(hass.config_entries.options))
)
hass.http.register_view(
OptionManagerFlowResourceView(hass.config_entries.options.flow)
)
hass.components.websocket_api.async_register_command(config_entries_progress) hass.components.websocket_api.async_register_command(config_entries_progress)
hass.components.websocket_api.async_register_command(system_options_list) hass.components.websocket_api.async_register_command(system_options_list)

View File

@ -399,6 +399,137 @@ class ConfigEntry:
} }
class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
"""Manage all the config entry flows that are in progress."""
def __init__(
self, hass: HomeAssistant, config_entries: "ConfigEntries", hass_config: dict
):
"""Initialize the config entry flow manager."""
super().__init__(hass)
self.config_entries = config_entries
self._hass_config = hass_config
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
) -> Dict[str, Any]:
"""Finish a config flow and add an entry."""
flow = cast(ConfigFlow, flow)
# Remove notification if no other discovery config entries in progress
if not any(
ent["context"]["source"] in DISCOVERY_SOURCES
for ent in self.hass.config_entries.flow.async_progress()
if ent["flow_id"] != flow.flow_id
):
self.hass.components.persistent_notification.async_dismiss(
DISCOVERY_NOTIFICATION_ID
)
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return result
# Check if config entry exists with unique ID. Unload it.
existing_entry = None
if flow.unique_id is not None:
# Abort all flows in progress with same unique ID.
for progress_flow in self.async_progress():
if (
progress_flow["handler"] == flow.handler
and progress_flow["flow_id"] != flow.flow_id
and progress_flow["context"].get("unique_id") == flow.unique_id
):
self.async_abort(progress_flow["flow_id"])
# Find existing entry.
for check_entry in self.config_entries.async_entries(result["handler"]):
if check_entry.unique_id == flow.unique_id:
existing_entry = check_entry
break
# Unload the entry before setting up the new one.
# We will remove it only after the other one is set up,
# so that device customizations are not getting lost.
if (
existing_entry is not None
and existing_entry.state not in UNRECOVERABLE_STATES
):
await self.config_entries.async_unload(existing_entry.entry_id)
entry = ConfigEntry(
version=result["version"],
domain=result["handler"],
title=result["title"],
data=result["data"],
options={},
system_options={},
source=flow.context["source"],
connection_class=flow.CONNECTION_CLASS,
unique_id=flow.unique_id,
)
await self.config_entries.async_add(entry)
if existing_entry is not None:
await self.config_entries.async_remove(existing_entry.entry_id)
result["result"] = entry
return result
async def async_create_flow(
self, handler_key: Any, *, context: Optional[Dict] = None, data: Any = None
) -> "ConfigFlow":
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
"""
try:
integration = await loader.async_get_integration(self.hass, handler_key)
except loader.IntegrationNotFound:
_LOGGER.error("Cannot find integration %s", handler_key)
raise data_entry_flow.UnknownHandler
# Make sure requirements and dependencies of component are resolved
await async_process_deps_reqs(self.hass, self._hass_config, integration)
try:
integration.get_platform("config_flow")
except ImportError as err:
_LOGGER.error(
"Error occurred loading config flow for integration %s: %s",
handler_key,
err,
)
raise data_entry_flow.UnknownHandler
handler = HANDLERS.get(handler_key)
if handler is None:
raise data_entry_flow.UnknownHandler
if not context or "source" not in context:
raise KeyError("Context not set or doesn't have a source set")
source = context["source"]
# Create notification.
if source in DISCOVERY_SOURCES:
self.hass.bus.async_fire(EVENT_FLOW_DISCOVERED)
self.hass.components.persistent_notification.async_create(
title="New devices discovered",
message=(
"We have discovered new devices on your network. "
"[Check it out](/config/integrations)"
),
notification_id=DISCOVERY_NOTIFICATION_ID,
)
flow = cast(ConfigFlow, handler())
flow.init_step = source
return flow
class ConfigEntries: class ConfigEntries:
"""Manage the configuration entries. """Manage the configuration entries.
@ -408,9 +539,7 @@ class ConfigEntries:
def __init__(self, hass: HomeAssistant, hass_config: dict) -> None: def __init__(self, hass: HomeAssistant, hass_config: dict) -> None:
"""Initialize the entry manager.""" """Initialize the entry manager."""
self.hass = hass self.hass = hass
self.flow = data_entry_flow.FlowManager( self.flow = ConfigEntriesFlowManager(hass, self, hass_config)
hass, self._async_create_flow, self._async_finish_flow
)
self.options = OptionsFlowManager(hass) self.options = OptionsFlowManager(hass)
self._hass_config = hass_config self._hass_config = hass_config
self._entries: List[ConfigEntry] = [] self._entries: List[ConfigEntry] = []
@ -445,6 +574,12 @@ class ConfigEntries:
return list(self._entries) return list(self._entries)
return [entry for entry in self._entries if entry.domain == domain] return [entry for entry in self._entries if entry.domain == domain]
async def async_add(self, entry: ConfigEntry) -> None:
"""Add and setup an entry."""
self._entries.append(entry)
await self.async_setup(entry.entry_id)
self._async_schedule_save()
async def async_remove(self, entry_id: str) -> Dict[str, Any]: async def async_remove(self, entry_id: str) -> Dict[str, Any]:
"""Remove an entry.""" """Remove an entry."""
entry = self.async_get_entry(entry_id) entry = self.async_get_entry(entry_id)
@ -630,123 +765,6 @@ class ConfigEntries:
return await entry.async_unload(self.hass, integration=integration) return await entry.async_unload(self.hass, integration=integration)
async def _async_finish_flow(
self, flow: "ConfigFlow", result: Dict[str, Any]
) -> Dict[str, Any]:
"""Finish a config flow and add an entry."""
# Remove notification if no other discovery config entries in progress
if not any(
ent["context"]["source"] in DISCOVERY_SOURCES
for ent in self.hass.config_entries.flow.async_progress()
if ent["flow_id"] != flow.flow_id
):
self.hass.components.persistent_notification.async_dismiss(
DISCOVERY_NOTIFICATION_ID
)
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return result
# Check if config entry exists with unique ID. Unload it.
existing_entry = None
if flow.unique_id is not None:
# Abort all flows in progress with same unique ID.
for progress_flow in self.flow.async_progress():
if (
progress_flow["handler"] == flow.handler
and progress_flow["flow_id"] != flow.flow_id
and progress_flow["context"].get("unique_id") == flow.unique_id
):
self.flow.async_abort(progress_flow["flow_id"])
# Find existing entry.
for check_entry in self.async_entries(result["handler"]):
if check_entry.unique_id == flow.unique_id:
existing_entry = check_entry
break
# Unload the entry before setting up the new one.
# We will remove it only after the other one is set up,
# so that device customizations are not getting lost.
if (
existing_entry is not None
and existing_entry.state not in UNRECOVERABLE_STATES
):
await self.async_unload(existing_entry.entry_id)
entry = ConfigEntry(
version=result["version"],
domain=result["handler"],
title=result["title"],
data=result["data"],
options={},
system_options={},
source=flow.context["source"],
connection_class=flow.CONNECTION_CLASS,
unique_id=flow.unique_id,
)
self._entries.append(entry)
await self.async_setup(entry.entry_id)
if existing_entry is not None:
await self.async_remove(existing_entry.entry_id)
self._async_schedule_save()
result["result"] = entry
return result
async def _async_create_flow(
self, handler_key: str, *, context: Dict[str, Any], data: Dict[str, Any]
) -> "ConfigFlow":
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
"""
try:
integration = await loader.async_get_integration(self.hass, handler_key)
except loader.IntegrationNotFound:
_LOGGER.error("Cannot find integration %s", handler_key)
raise data_entry_flow.UnknownHandler
# Make sure requirements and dependencies of component are resolved
await async_process_deps_reqs(self.hass, self._hass_config, integration)
try:
integration.get_platform("config_flow")
except ImportError as err:
_LOGGER.error(
"Error occurred loading config flow for integration %s: %s",
handler_key,
err,
)
raise data_entry_flow.UnknownHandler
handler = HANDLERS.get(handler_key)
if handler is None:
raise data_entry_flow.UnknownHandler
source = context["source"]
# Create notification.
if source in DISCOVERY_SOURCES:
self.hass.bus.async_fire(EVENT_FLOW_DISCOVERED)
self.hass.components.persistent_notification.async_create(
title="New devices discovered",
message=(
"We have discovered new devices on your network. "
"[Check it out](/config/integrations)"
),
notification_id=DISCOVERY_NOTIFICATION_ID,
)
flow = cast(ConfigFlow, handler())
flow.init_step = source
return flow
def _async_schedule_save(self) -> None: def _async_schedule_save(self) -> None:
"""Save the entity registry to a file.""" """Save the entity registry to a file."""
self._store.async_delay_save(self._data_to_save, SAVE_DELAY) self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@ -854,26 +872,23 @@ class ConfigFlow(data_entry_flow.FlowHandler):
return self.async_abort(reason="not_implemented") return self.async_abort(reason="not_implemented")
class OptionsFlowManager: class OptionsFlowManager(data_entry_flow.FlowManager):
"""Flow to set options for a configuration entry.""" """Flow to set options for a configuration entry."""
def __init__(self, hass: HomeAssistant) -> None: async def async_create_flow(
"""Initialize the options manager.""" self,
self.hass = hass handler_key: Any,
self.flow = data_entry_flow.FlowManager( *,
hass, self._async_create_flow, self._async_finish_flow context: Optional[Dict[str, Any]] = None,
) data: Optional[Dict[str, Any]] = None,
) -> "OptionsFlow":
async def _async_create_flow(
self, entry_id: str, *, context: Dict[str, Any], data: Dict[str, Any]
) -> Optional["OptionsFlow"]:
"""Create an options flow for a config entry. """Create an options flow for a config entry.
Entry_id and flow.handler is the same thing to map entry with flow. Entry_id and flow.handler is the same thing to map entry with flow.
""" """
entry = self.hass.config_entries.async_get_entry(entry_id) entry = self.hass.config_entries.async_get_entry(handler_key)
if entry is None: if entry is None:
return None raise UnknownEntry(handler_key)
if entry.domain not in HANDLERS: if entry.domain not in HANDLERS:
raise data_entry_flow.UnknownHandler raise data_entry_flow.UnknownHandler
@ -881,16 +896,18 @@ class OptionsFlowManager:
flow = cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry)) flow = cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry))
return flow return flow
async def _async_finish_flow( async def async_finish_flow(
self, flow: "OptionsFlow", result: Dict[str, Any] self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
) -> Optional[Dict[str, Any]]: ) -> Dict[str, Any]:
"""Finish an options flow and update options for configuration entry. """Finish an options flow and update options for configuration entry.
Flow.handler and entry_id is the same thing to map flow with entry. Flow.handler and entry_id is the same thing to map flow with entry.
""" """
flow = cast(OptionsFlow, flow)
entry = self.hass.config_entries.async_get_entry(flow.handler) entry = self.hass.config_entries.async_get_entry(flow.handler)
if entry is None: if entry is None:
return None raise UnknownEntry(flow.handler)
self.hass.config_entries.async_update_entry(entry, options=result["data"]) self.hass.config_entries.async_update_entry(entry, options=result["data"])
result["result"] = True result["result"] = True

View File

@ -1,6 +1,7 @@
"""Classes to help gather user submissions.""" """Classes to help gather user submissions."""
import abc
import logging import logging
from typing import Any, Callable, Dict, List, Optional, cast from typing import Any, Dict, List, Optional, cast
import uuid import uuid
import voluptuous as vol import voluptuous as vol
@ -46,20 +47,34 @@ class AbortFlow(FlowError):
self.description_placeholders = description_placeholders self.description_placeholders = description_placeholders
class FlowManager: class FlowManager(abc.ABC):
"""Manage all the flows that are in progress.""" """Manage all the flows that are in progress."""
def __init__( def __init__(self, hass: HomeAssistant,) -> None:
self,
hass: HomeAssistant,
async_create_flow: Callable,
async_finish_flow: Callable,
) -> None:
"""Initialize the flow manager.""" """Initialize the flow manager."""
self.hass = hass self.hass = hass
self._progress: Dict[str, Any] = {} self._progress: Dict[str, Any] = {}
self._async_create_flow = async_create_flow
self._async_finish_flow = async_finish_flow @abc.abstractmethod
async def async_create_flow(
self,
handler_key: Any,
*,
context: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> "FlowHandler":
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
"""
pass
@abc.abstractmethod
async def async_finish_flow(
self, flow: "FlowHandler", result: Dict[str, Any]
) -> Dict[str, Any]:
"""Finish a config flow and add an entry."""
pass
@callback @callback
def async_progress(self) -> List[Dict]: def async_progress(self) -> List[Dict]:
@ -75,7 +90,9 @@ class FlowManager:
"""Start a configuration flow.""" """Start a configuration flow."""
if context is None: if context is None:
context = {} context = {}
flow = await self._async_create_flow(handler, context=context, data=data) flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
raise UnknownFlow("Flow was not created")
flow.hass = self.hass flow.hass = self.hass
flow.handler = handler flow.handler = handler
flow.flow_id = uuid.uuid4().hex flow.flow_id = uuid.uuid4().hex
@ -168,7 +185,7 @@ class FlowManager:
return result return result
# We pass a copy of the result because we're mutating our version # We pass a copy of the result because we're mutating our version
result = await self._async_finish_flow(flow, dict(result)) result = await self.async_finish_flow(flow, dict(result))
# _async_finish_flow may change result type, check it again # _async_finish_flow may change result type, check it again
if result["type"] == RESULT_TYPE_FORM: if result["type"] == RESULT_TYPE_FORM:

View File

@ -436,7 +436,7 @@ async def test_option_flow(hass):
entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=None) entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=None)
hass.config_entries._entries.append(entry) hass.config_entries._entries.append(entry)
flow = await hass.config_entries.options._async_create_flow( flow = await hass.config_entries.options.async_create_flow(
entry.entry_id, context={"source": "test"}, data=None entry.entry_id, context={"source": "test"}, data=None
) )

View File

@ -182,13 +182,13 @@ async def test_options_form(hass):
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
result = await hass.config_entries.options.flow.async_init( result = await hass.config_entries.options.async_init(
entry.entry_id, context={"source": "test"}, data=None entry.entry_id, context={"source": "test"}, data=None
) )
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "init" assert result["step_id"] == "init"
result = await hass.config_entries.options.flow.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], result["flow_id"],
user_input={CONF_FLOOR_TEMP: True, CONF_PRECISION: PRECISION_HALVES}, user_input={CONF_FLOOR_TEMP: True, CONF_PRECISION: PRECISION_HALVES},
) )
@ -197,11 +197,11 @@ async def test_options_form(hass):
assert result["data"][CONF_PRECISION] == PRECISION_HALVES assert result["data"][CONF_PRECISION] == PRECISION_HALVES
assert result["data"][CONF_FLOOR_TEMP] is True assert result["data"][CONF_FLOOR_TEMP] is True
result = await hass.config_entries.options.flow.async_init( result = await hass.config_entries.options.async_init(
entry.entry_id, context={"source": "test"}, data=None entry.entry_id, context={"source": "test"}, data=None
) )
result = await hass.config_entries.options.flow.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={CONF_PRECISION: 0} result["flow_id"], user_input={CONF_PRECISION: 0}
) )

View File

@ -462,13 +462,13 @@ async def test_option_flow(hass):
entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=DEFAULT_OPTIONS) entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=DEFAULT_OPTIONS)
entry.add_to_hass(hass) entry.add_to_hass(hass)
result = await hass.config_entries.options.flow.async_init( result = await hass.config_entries.options.async_init(
entry.entry_id, context={"source": "test"}, data=None entry.entry_id, context={"source": "test"}, data=None
) )
assert result["type"] == "form" assert result["type"] == "form"
assert result["step_id"] == "plex_mp_settings" assert result["step_id"] == "plex_mp_settings"
result = await hass.config_entries.options.flow.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], result["flow_id"],
user_input={ user_input={
config_flow.CONF_USE_EPISODE_ART: True, config_flow.CONF_USE_EPISODE_ART: True,

View File

@ -131,12 +131,12 @@ async def test_option_flow(hass):
entry = MockConfigEntry(domain=DOMAIN, data={}, options=None) entry = MockConfigEntry(domain=DOMAIN, data={}, options=None)
entry.add_to_hass(hass) entry.add_to_hass(hass)
result = await hass.config_entries.options.flow.async_init(entry.entry_id) result = await hass.config_entries.options.async_init(entry.entry_id)
assert result["type"] == "form" assert result["type"] == "form"
assert result["step_id"] == "init" assert result["step_id"] == "init"
result = await hass.config_entries.options.flow.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={CONF_SCAN_INTERVAL: 350} result["flow_id"], user_input={CONF_SCAN_INTERVAL: 350}
) )
assert result["type"] == "create_entry" assert result["type"] == "create_entry"
@ -148,12 +148,12 @@ async def test_option_flow_input_floor(hass):
entry = MockConfigEntry(domain=DOMAIN, data={}, options=None) entry = MockConfigEntry(domain=DOMAIN, data={}, options=None)
entry.add_to_hass(hass) entry.add_to_hass(hass)
result = await hass.config_entries.options.flow.async_init(entry.entry_id) result = await hass.config_entries.options.async_init(entry.entry_id)
assert result["type"] == "form" assert result["type"] == "form"
assert result["step_id"] == "init" assert result["step_id"] == "init"
result = await hass.config_entries.options.flow.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], user_input={CONF_SCAN_INTERVAL: 1} result["flow_id"], user_input={CONF_SCAN_INTERVAL: 1}
) )
assert result["type"] == "create_entry" assert result["type"] == "create_entry"

View File

@ -231,7 +231,7 @@ async def test_option_flow(hass):
entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=None) entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=None)
hass.config_entries._entries.append(entry) hass.config_entries._entries.append(entry)
flow = await hass.config_entries.options._async_create_flow( flow = await hass.config_entries.options.async_create_flow(
entry.entry_id, context={"source": "test"}, data=None entry.entry_id, context={"source": "test"}, data=None
) )

View File

@ -692,13 +692,13 @@ async def test_entry_options(hass, manager):
return OptionsFlowHandler() return OptionsFlowHandler()
config_entries.HANDLERS["test"] = TestFlow() config_entries.HANDLERS["test"] = TestFlow()
flow = await manager.options._async_create_flow( flow = await manager.options.async_create_flow(
entry.entry_id, context={"source": "test"}, data=None entry.entry_id, context={"source": "test"}, data=None
) )
flow.handler = entry.entry_id # Used to keep reference to config entry flow.handler = entry.entry_id # Used to keep reference to config entry
await manager.options._async_finish_flow(flow, {"data": {"second": True}}) await manager.options.async_finish_flow(flow, {"data": {"second": True}})
assert entry.data == {"first": True} assert entry.data == {"first": True}

View File

@ -14,27 +14,32 @@ def manager():
handlers = Registry() handlers = Registry()
entries = [] entries = []
async def async_create_flow(handler_name, *, context, data): class FlowManager(data_entry_flow.FlowManager):
handler = handlers.get(handler_name) """Test flow manager."""
if handler is None: async def async_create_flow(self, handler_key, *, context, data):
raise data_entry_flow.UnknownHandler """Test create flow."""
handler = handlers.get(handler_key)
flow = handler() if handler is None:
flow.init_step = context.get("init_step", "init") raise data_entry_flow.UnknownHandler
flow.source = context.get("source")
return flow
async def async_add_entry(flow, result): flow = handler()
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY: flow.init_step = context.get("init_step", "init")
result["source"] = flow.context.get("source") flow.source = context.get("source")
entries.append(result) return flow
return result
manager = data_entry_flow.FlowManager(None, async_create_flow, async_add_entry) async def async_finish_flow(self, flow, result):
manager.mock_created_entries = entries """Test finish flow."""
manager.mock_reg_handler = handlers.register if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return manager result["source"] = flow.context.get("source")
entries.append(result)
return result
mgr = FlowManager(None)
mgr.mock_created_entries = entries
mgr.mock_reg_handler = handlers.register
return mgr
async def test_configure_reuses_handler_instance(manager): async def test_configure_reuses_handler_instance(manager):
@ -194,22 +199,23 @@ async def test_finish_callback_change_result_type(hass):
step_id="init", data_schema=vol.Schema({"count": int}) step_id="init", data_schema=vol.Schema({"count": int})
) )
async def async_create_flow(handler_name, *, context, data): class FlowManager(data_entry_flow.FlowManager):
"""Create a test flow.""" async def async_create_flow(self, handler_name, *, context, data):
return TestFlow() """Create a test flow."""
return TestFlow()
async def async_finish_flow(flow, result): async def async_finish_flow(self, flow, result):
"""Redirect to init form if count <= 1.""" """Redirect to init form if count <= 1."""
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY: if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
if result["data"] is None or result["data"].get("count", 0) <= 1: if result["data"] is None or result["data"].get("count", 0) <= 1:
return flow.async_show_form( return flow.async_show_form(
step_id="init", data_schema=vol.Schema({"count": int}) step_id="init", data_schema=vol.Schema({"count": int})
) )
else: else:
result["result"] = result["data"]["count"] result["result"] = result["data"]["count"]
return result return result
manager = data_entry_flow.FlowManager(hass, async_create_flow, async_finish_flow) manager = FlowManager(hass)
result = await manager.async_init("test") result = await manager.async_init("test")
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["type"] == data_entry_flow.RESULT_TYPE_FORM