From fdfedd086bcc6f19bf8e3023381f5c0d12f30624 Mon Sep 17 00:00:00 2001 From: Jc2k Date: Fri, 3 Jan 2020 10:52:01 +0000 Subject: [PATCH] Rework FlowManager to use inheritance (#30133) * Pull async_finish_flow/async_create_flow out of ConfigEntries * Towards refactoring * mypy fixes * Mark Flow manager with abc.* annotations * Flake8 fixes * Mypy fixes * Blacken data_entry_flow * Blacken longer signatures caused by mypy changes * test fixes * Test fixes * Fix typo * Avoid protected member lint (W0212) in config_entries * More protected member fixes * Missing await --- homeassistant/auth/__init__.py | 111 ++++--- .../components/auth/mfa_setup_flow.py | 24 +- .../components/config/config_entries.py | 8 +- homeassistant/config_entries.py | 291 +++++++++--------- homeassistant/data_entry_flow.py | 41 ++- tests/components/deconz/test_config_flow.py | 2 +- .../opentherm_gw/test_config_flow.py | 8 +- tests/components/plex/test_config_flow.py | 4 +- tests/components/tesla/test_config_flow.py | 8 +- tests/components/unifi/test_config_flow.py | 2 +- tests/test_config_entries.py | 4 +- tests/test_data_entry_flow.py | 68 ++-- 12 files changed, 313 insertions(+), 258 deletions(-) diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index e4437bea840..9b3cf49fa22 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -67,6 +67,69 @@ async def auth_manager_from_config( return manager +class AuthManagerFlowManager(data_entry_flow.FlowManager): + """Manage authentication flows.""" + + def __init__(self, hass: HomeAssistant, auth_manager: "AuthManager"): + """Init auth manager flows.""" + super().__init__(hass) + self.auth_manager = auth_manager + + async def async_create_flow( + self, + handler_key: Any, + *, + context: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + ) -> data_entry_flow.FlowHandler: + """Create a login flow.""" + auth_provider = self.auth_manager.get_auth_provider(*handler_key) + if not auth_provider: + raise KeyError(f"Unknown auth provider {handler_key}") + return await auth_provider.async_login_flow(context) + + async def async_finish_flow( + self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any] + ) -> Dict[str, Any]: + """Return a user as result of login flow.""" + flow = cast(LoginFlow, flow) + + if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: + return result + + # we got final result + if isinstance(result["data"], models.User): + result["result"] = result["data"] + return result + + auth_provider = self.auth_manager.get_auth_provider(*result["handler"]) + if not auth_provider: + raise KeyError(f"Unknown auth provider {result['handler']}") + + credentials = await auth_provider.async_get_or_create_credentials( + result["data"] + ) + + if flow.context.get("credential_only"): + result["result"] = credentials + return result + + # multi-factor module cannot enabled for new credential + # which has not linked to a user yet + if auth_provider.support_mfa and not credentials.is_new: + user = await self.auth_manager.async_get_user_by_credentials(credentials) + if user is not None: + modules = await self.auth_manager.async_get_enabled_mfa(user) + + if modules: + flow.user = user + flow.available_mfa_modules = modules + return await flow.async_step_select_mfa_module() + + result["result"] = await self.auth_manager.async_get_or_create_user(credentials) + return result + + class AuthManager: """Manage the authentication for Home Assistant.""" @@ -82,9 +145,7 @@ class AuthManager: self._store = store self._providers = providers self._mfa_modules = mfa_modules - self.login_flow = data_entry_flow.FlowManager( - hass, self._async_create_login_flow, self._async_finish_login_flow - ) + self.login_flow = AuthManagerFlowManager(hass, self) @property def auth_providers(self) -> List[AuthProvider]: @@ -417,50 +478,6 @@ class AuthManager: return refresh_token - async def _async_create_login_flow( - self, handler: _ProviderKey, *, context: Optional[Dict], data: Optional[Any] - ) -> data_entry_flow.FlowHandler: - """Create a login flow.""" - auth_provider = self._providers[handler] - - return await auth_provider.async_login_flow(context) - - async def _async_finish_login_flow( - self, flow: LoginFlow, result: Dict[str, Any] - ) -> Dict[str, Any]: - """Return a user as result of login flow.""" - if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: - return result - - # we got final result - if isinstance(result["data"], models.User): - result["result"] = result["data"] - return result - - auth_provider = self._providers[result["handler"]] - credentials = await auth_provider.async_get_or_create_credentials( - result["data"] - ) - - if flow.context.get("credential_only"): - result["result"] = credentials - return result - - # multi-factor module cannot enabled for new credential - # which has not linked to a user yet - if auth_provider.support_mfa and not credentials.is_new: - user = await self.async_get_user_by_credentials(credentials) - if user is not None: - modules = await self.async_get_enabled_mfa(user) - - if modules: - flow.user = user - flow.available_mfa_modules = modules - return await flow.async_step_select_mfa_module() - - result["result"] = await self.async_get_or_create_user(credentials) - return result - @callback def _async_get_auth_provider( self, credentials: models.Credentials diff --git a/homeassistant/components/auth/mfa_setup_flow.py b/homeassistant/components/auth/mfa_setup_flow.py index 92926e2e7c5..1b199551a14 100644 --- a/homeassistant/components/auth/mfa_setup_flow.py +++ b/homeassistant/components/auth/mfa_setup_flow.py @@ -28,25 +28,27 @@ DATA_SETUP_FLOW_MGR = "auth_mfa_setup_flow_manager" _LOGGER = logging.getLogger(__name__) -async def async_setup(hass): - """Init mfa setup flow manager.""" +class MfaFlowManager(data_entry_flow.FlowManager): + """Manage multi factor authentication flows.""" - async def _async_create_setup_flow(handler, context, data): + async def async_create_flow(self, handler_key, *, context, data): """Create a setup flow. handler is a mfa module.""" - mfa_module = hass.auth.get_auth_mfa_module(handler) + mfa_module = self.hass.auth.get_auth_mfa_module(handler_key) if mfa_module is None: - raise ValueError(f"Mfa module {handler} is not found") + raise ValueError(f"Mfa module {handler_key} is not found") user_id = data.pop("user_id") return await mfa_module.async_setup_flow(user_id) - async def _async_finish_setup_flow(flow, flow_result): - _LOGGER.debug("flow_result: %s", flow_result) - return flow_result + async def async_finish_flow(self, flow, result): + """Complete an mfs setup flow.""" + _LOGGER.debug("flow_result: %s", result) + return result - hass.data[DATA_SETUP_FLOW_MGR] = data_entry_flow.FlowManager( - hass, _async_create_setup_flow, _async_finish_setup_flow - ) + +async def async_setup(hass): + """Init mfa setup flow manager.""" + hass.data[DATA_SETUP_FLOW_MGR] = MfaFlowManager(hass) hass.components.websocket_api.async_register_command( WS_TYPE_SETUP_MFA, websocket_setup_mfa, SCHEMA_WS_SETUP_MFA diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index dbf0ee8f283..22df26cce4e 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -23,12 +23,8 @@ async def async_setup(hass): hass.http.register_view(ConfigManagerFlowResourceView(hass.config_entries.flow)) hass.http.register_view(ConfigManagerAvailableFlowView) - hass.http.register_view( - OptionManagerFlowIndexView(hass.config_entries.options.flow) - ) - hass.http.register_view( - OptionManagerFlowResourceView(hass.config_entries.options.flow) - ) + hass.http.register_view(OptionManagerFlowIndexView(hass.config_entries.options)) + hass.http.register_view(OptionManagerFlowResourceView(hass.config_entries.options)) hass.components.websocket_api.async_register_command(config_entries_progress) hass.components.websocket_api.async_register_command(system_options_list) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 942998767a1..d1b5c927a2b 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -399,6 +399,137 @@ class ConfigEntry: } +class ConfigEntriesFlowManager(data_entry_flow.FlowManager): + """Manage all the config entry flows that are in progress.""" + + def __init__( + self, hass: HomeAssistant, config_entries: "ConfigEntries", hass_config: dict + ): + """Initialize the config entry flow manager.""" + super().__init__(hass) + self.config_entries = config_entries + self._hass_config = hass_config + + async def async_finish_flow( + self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any] + ) -> Dict[str, Any]: + """Finish a config flow and add an entry.""" + flow = cast(ConfigFlow, flow) + + # Remove notification if no other discovery config entries in progress + if not any( + ent["context"]["source"] in DISCOVERY_SOURCES + for ent in self.hass.config_entries.flow.async_progress() + if ent["flow_id"] != flow.flow_id + ): + self.hass.components.persistent_notification.async_dismiss( + DISCOVERY_NOTIFICATION_ID + ) + + if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: + return result + + # Check if config entry exists with unique ID. Unload it. + existing_entry = None + + if flow.unique_id is not None: + # Abort all flows in progress with same unique ID. + for progress_flow in self.async_progress(): + if ( + progress_flow["handler"] == flow.handler + and progress_flow["flow_id"] != flow.flow_id + and progress_flow["context"].get("unique_id") == flow.unique_id + ): + self.async_abort(progress_flow["flow_id"]) + + # Find existing entry. + for check_entry in self.config_entries.async_entries(result["handler"]): + if check_entry.unique_id == flow.unique_id: + existing_entry = check_entry + break + + # Unload the entry before setting up the new one. + # We will remove it only after the other one is set up, + # so that device customizations are not getting lost. + if ( + existing_entry is not None + and existing_entry.state not in UNRECOVERABLE_STATES + ): + await self.config_entries.async_unload(existing_entry.entry_id) + + entry = ConfigEntry( + version=result["version"], + domain=result["handler"], + title=result["title"], + data=result["data"], + options={}, + system_options={}, + source=flow.context["source"], + connection_class=flow.CONNECTION_CLASS, + unique_id=flow.unique_id, + ) + + await self.config_entries.async_add(entry) + + if existing_entry is not None: + await self.config_entries.async_remove(existing_entry.entry_id) + + result["result"] = entry + return result + + async def async_create_flow( + self, handler_key: Any, *, context: Optional[Dict] = None, data: Any = None + ) -> "ConfigFlow": + """Create a flow for specified handler. + + Handler key is the domain of the component that we want to set up. + """ + try: + integration = await loader.async_get_integration(self.hass, handler_key) + except loader.IntegrationNotFound: + _LOGGER.error("Cannot find integration %s", handler_key) + raise data_entry_flow.UnknownHandler + + # Make sure requirements and dependencies of component are resolved + await async_process_deps_reqs(self.hass, self._hass_config, integration) + + try: + integration.get_platform("config_flow") + except ImportError as err: + _LOGGER.error( + "Error occurred loading config flow for integration %s: %s", + handler_key, + err, + ) + raise data_entry_flow.UnknownHandler + + handler = HANDLERS.get(handler_key) + + if handler is None: + raise data_entry_flow.UnknownHandler + + if not context or "source" not in context: + raise KeyError("Context not set or doesn't have a source set") + + source = context["source"] + + # Create notification. + if source in DISCOVERY_SOURCES: + self.hass.bus.async_fire(EVENT_FLOW_DISCOVERED) + self.hass.components.persistent_notification.async_create( + title="New devices discovered", + message=( + "We have discovered new devices on your network. " + "[Check it out](/config/integrations)" + ), + notification_id=DISCOVERY_NOTIFICATION_ID, + ) + + flow = cast(ConfigFlow, handler()) + flow.init_step = source + return flow + + class ConfigEntries: """Manage the configuration entries. @@ -408,9 +539,7 @@ class ConfigEntries: def __init__(self, hass: HomeAssistant, hass_config: dict) -> None: """Initialize the entry manager.""" self.hass = hass - self.flow = data_entry_flow.FlowManager( - hass, self._async_create_flow, self._async_finish_flow - ) + self.flow = ConfigEntriesFlowManager(hass, self, hass_config) self.options = OptionsFlowManager(hass) self._hass_config = hass_config self._entries: List[ConfigEntry] = [] @@ -445,6 +574,12 @@ class ConfigEntries: return list(self._entries) return [entry for entry in self._entries if entry.domain == domain] + async def async_add(self, entry: ConfigEntry) -> None: + """Add and setup an entry.""" + self._entries.append(entry) + await self.async_setup(entry.entry_id) + self._async_schedule_save() + async def async_remove(self, entry_id: str) -> Dict[str, Any]: """Remove an entry.""" entry = self.async_get_entry(entry_id) @@ -630,123 +765,6 @@ class ConfigEntries: return await entry.async_unload(self.hass, integration=integration) - async def _async_finish_flow( - self, flow: "ConfigFlow", result: Dict[str, Any] - ) -> Dict[str, Any]: - """Finish a config flow and add an entry.""" - # Remove notification if no other discovery config entries in progress - if not any( - ent["context"]["source"] in DISCOVERY_SOURCES - for ent in self.hass.config_entries.flow.async_progress() - if ent["flow_id"] != flow.flow_id - ): - self.hass.components.persistent_notification.async_dismiss( - DISCOVERY_NOTIFICATION_ID - ) - - if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: - return result - - # Check if config entry exists with unique ID. Unload it. - existing_entry = None - - if flow.unique_id is not None: - # Abort all flows in progress with same unique ID. - for progress_flow in self.flow.async_progress(): - if ( - progress_flow["handler"] == flow.handler - and progress_flow["flow_id"] != flow.flow_id - and progress_flow["context"].get("unique_id") == flow.unique_id - ): - self.flow.async_abort(progress_flow["flow_id"]) - - # Find existing entry. - for check_entry in self.async_entries(result["handler"]): - if check_entry.unique_id == flow.unique_id: - existing_entry = check_entry - break - - # Unload the entry before setting up the new one. - # We will remove it only after the other one is set up, - # so that device customizations are not getting lost. - if ( - existing_entry is not None - and existing_entry.state not in UNRECOVERABLE_STATES - ): - await self.async_unload(existing_entry.entry_id) - - entry = ConfigEntry( - version=result["version"], - domain=result["handler"], - title=result["title"], - data=result["data"], - options={}, - system_options={}, - source=flow.context["source"], - connection_class=flow.CONNECTION_CLASS, - unique_id=flow.unique_id, - ) - self._entries.append(entry) - - await self.async_setup(entry.entry_id) - - if existing_entry is not None: - await self.async_remove(existing_entry.entry_id) - - self._async_schedule_save() - - result["result"] = entry - return result - - async def _async_create_flow( - self, handler_key: str, *, context: Dict[str, Any], data: Dict[str, Any] - ) -> "ConfigFlow": - """Create a flow for specified handler. - - Handler key is the domain of the component that we want to set up. - """ - try: - integration = await loader.async_get_integration(self.hass, handler_key) - except loader.IntegrationNotFound: - _LOGGER.error("Cannot find integration %s", handler_key) - raise data_entry_flow.UnknownHandler - - # Make sure requirements and dependencies of component are resolved - await async_process_deps_reqs(self.hass, self._hass_config, integration) - - try: - integration.get_platform("config_flow") - except ImportError as err: - _LOGGER.error( - "Error occurred loading config flow for integration %s: %s", - handler_key, - err, - ) - raise data_entry_flow.UnknownHandler - - handler = HANDLERS.get(handler_key) - - if handler is None: - raise data_entry_flow.UnknownHandler - - source = context["source"] - - # Create notification. - if source in DISCOVERY_SOURCES: - self.hass.bus.async_fire(EVENT_FLOW_DISCOVERED) - self.hass.components.persistent_notification.async_create( - title="New devices discovered", - message=( - "We have discovered new devices on your network. " - "[Check it out](/config/integrations)" - ), - notification_id=DISCOVERY_NOTIFICATION_ID, - ) - - flow = cast(ConfigFlow, handler()) - flow.init_step = source - return flow - def _async_schedule_save(self) -> None: """Save the entity registry to a file.""" self._store.async_delay_save(self._data_to_save, SAVE_DELAY) @@ -854,26 +872,23 @@ class ConfigFlow(data_entry_flow.FlowHandler): return self.async_abort(reason="not_implemented") -class OptionsFlowManager: +class OptionsFlowManager(data_entry_flow.FlowManager): """Flow to set options for a configuration entry.""" - def __init__(self, hass: HomeAssistant) -> None: - """Initialize the options manager.""" - self.hass = hass - self.flow = data_entry_flow.FlowManager( - hass, self._async_create_flow, self._async_finish_flow - ) - - async def _async_create_flow( - self, entry_id: str, *, context: Dict[str, Any], data: Dict[str, Any] - ) -> Optional["OptionsFlow"]: + async def async_create_flow( + self, + handler_key: Any, + *, + context: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + ) -> "OptionsFlow": """Create an options flow for a config entry. Entry_id and flow.handler is the same thing to map entry with flow. """ - entry = self.hass.config_entries.async_get_entry(entry_id) + entry = self.hass.config_entries.async_get_entry(handler_key) if entry is None: - return None + raise UnknownEntry(handler_key) if entry.domain not in HANDLERS: raise data_entry_flow.UnknownHandler @@ -881,16 +896,18 @@ class OptionsFlowManager: flow = cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry)) return flow - async def _async_finish_flow( - self, flow: "OptionsFlow", result: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: + async def async_finish_flow( + self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any] + ) -> Dict[str, Any]: """Finish an options flow and update options for configuration entry. Flow.handler and entry_id is the same thing to map flow with entry. """ + flow = cast(OptionsFlow, flow) + entry = self.hass.config_entries.async_get_entry(flow.handler) if entry is None: - return None + raise UnknownEntry(flow.handler) self.hass.config_entries.async_update_entry(entry, options=result["data"]) result["result"] = True diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 7c2b4ab6ddc..6a9f5b1dc5a 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -1,6 +1,7 @@ """Classes to help gather user submissions.""" +import abc import logging -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, cast import uuid import voluptuous as vol @@ -46,20 +47,34 @@ class AbortFlow(FlowError): self.description_placeholders = description_placeholders -class FlowManager: +class FlowManager(abc.ABC): """Manage all the flows that are in progress.""" - def __init__( - self, - hass: HomeAssistant, - async_create_flow: Callable, - async_finish_flow: Callable, - ) -> None: + def __init__(self, hass: HomeAssistant,) -> None: """Initialize the flow manager.""" self.hass = hass self._progress: Dict[str, Any] = {} - self._async_create_flow = async_create_flow - self._async_finish_flow = async_finish_flow + + @abc.abstractmethod + async def async_create_flow( + self, + handler_key: Any, + *, + context: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + ) -> "FlowHandler": + """Create a flow for specified handler. + + Handler key is the domain of the component that we want to set up. + """ + pass + + @abc.abstractmethod + async def async_finish_flow( + self, flow: "FlowHandler", result: Dict[str, Any] + ) -> Dict[str, Any]: + """Finish a config flow and add an entry.""" + pass @callback def async_progress(self) -> List[Dict]: @@ -75,7 +90,9 @@ class FlowManager: """Start a configuration flow.""" if context is None: context = {} - flow = await self._async_create_flow(handler, context=context, data=data) + flow = await self.async_create_flow(handler, context=context, data=data) + if not flow: + raise UnknownFlow("Flow was not created") flow.hass = self.hass flow.handler = handler flow.flow_id = uuid.uuid4().hex @@ -168,7 +185,7 @@ class FlowManager: return result # We pass a copy of the result because we're mutating our version - result = await self._async_finish_flow(flow, dict(result)) + result = await self.async_finish_flow(flow, dict(result)) # _async_finish_flow may change result type, check it again if result["type"] == RESULT_TYPE_FORM: diff --git a/tests/components/deconz/test_config_flow.py b/tests/components/deconz/test_config_flow.py index f8fe42d10d8..3fef85611c8 100644 --- a/tests/components/deconz/test_config_flow.py +++ b/tests/components/deconz/test_config_flow.py @@ -436,7 +436,7 @@ async def test_option_flow(hass): entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=None) hass.config_entries._entries.append(entry) - flow = await hass.config_entries.options._async_create_flow( + flow = await hass.config_entries.options.async_create_flow( entry.entry_id, context={"source": "test"}, data=None ) diff --git a/tests/components/opentherm_gw/test_config_flow.py b/tests/components/opentherm_gw/test_config_flow.py index 26048543a22..0adcdb188d0 100644 --- a/tests/components/opentherm_gw/test_config_flow.py +++ b/tests/components/opentherm_gw/test_config_flow.py @@ -182,13 +182,13 @@ async def test_options_form(hass): ) entry.add_to_hass(hass) - result = await hass.config_entries.options.flow.async_init( + result = await hass.config_entries.options.async_init( entry.entry_id, context={"source": "test"}, data=None ) assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "init" - result = await hass.config_entries.options.flow.async_configure( + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={CONF_FLOOR_TEMP: True, CONF_PRECISION: PRECISION_HALVES}, ) @@ -197,11 +197,11 @@ async def test_options_form(hass): assert result["data"][CONF_PRECISION] == PRECISION_HALVES assert result["data"][CONF_FLOOR_TEMP] is True - result = await hass.config_entries.options.flow.async_init( + result = await hass.config_entries.options.async_init( entry.entry_id, context={"source": "test"}, data=None ) - result = await hass.config_entries.options.flow.async_configure( + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={CONF_PRECISION: 0} ) diff --git a/tests/components/plex/test_config_flow.py b/tests/components/plex/test_config_flow.py index 0fb1f850809..8f9342c4f72 100644 --- a/tests/components/plex/test_config_flow.py +++ b/tests/components/plex/test_config_flow.py @@ -462,13 +462,13 @@ async def test_option_flow(hass): entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=DEFAULT_OPTIONS) entry.add_to_hass(hass) - result = await hass.config_entries.options.flow.async_init( + result = await hass.config_entries.options.async_init( entry.entry_id, context={"source": "test"}, data=None ) assert result["type"] == "form" assert result["step_id"] == "plex_mp_settings" - result = await hass.config_entries.options.flow.async_configure( + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={ config_flow.CONF_USE_EPISODE_ART: True, diff --git a/tests/components/tesla/test_config_flow.py b/tests/components/tesla/test_config_flow.py index b6eeff54a50..7b7e822ce58 100644 --- a/tests/components/tesla/test_config_flow.py +++ b/tests/components/tesla/test_config_flow.py @@ -131,12 +131,12 @@ async def test_option_flow(hass): entry = MockConfigEntry(domain=DOMAIN, data={}, options=None) entry.add_to_hass(hass) - result = await hass.config_entries.options.flow.async_init(entry.entry_id) + result = await hass.config_entries.options.async_init(entry.entry_id) assert result["type"] == "form" assert result["step_id"] == "init" - result = await hass.config_entries.options.flow.async_configure( + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={CONF_SCAN_INTERVAL: 350} ) assert result["type"] == "create_entry" @@ -148,12 +148,12 @@ async def test_option_flow_input_floor(hass): entry = MockConfigEntry(domain=DOMAIN, data={}, options=None) entry.add_to_hass(hass) - result = await hass.config_entries.options.flow.async_init(entry.entry_id) + result = await hass.config_entries.options.async_init(entry.entry_id) assert result["type"] == "form" assert result["step_id"] == "init" - result = await hass.config_entries.options.flow.async_configure( + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={CONF_SCAN_INTERVAL: 1} ) assert result["type"] == "create_entry" diff --git a/tests/components/unifi/test_config_flow.py b/tests/components/unifi/test_config_flow.py index 1b973aee9a5..cc8896d55ce 100644 --- a/tests/components/unifi/test_config_flow.py +++ b/tests/components/unifi/test_config_flow.py @@ -231,7 +231,7 @@ async def test_option_flow(hass): entry = MockConfigEntry(domain=config_flow.DOMAIN, data={}, options=None) hass.config_entries._entries.append(entry) - flow = await hass.config_entries.options._async_create_flow( + flow = await hass.config_entries.options.async_create_flow( entry.entry_id, context={"source": "test"}, data=None ) diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 5b694b2de87..c3a87bcf3a0 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -692,13 +692,13 @@ async def test_entry_options(hass, manager): return OptionsFlowHandler() config_entries.HANDLERS["test"] = TestFlow() - flow = await manager.options._async_create_flow( + flow = await manager.options.async_create_flow( entry.entry_id, context={"source": "test"}, data=None ) flow.handler = entry.entry_id # Used to keep reference to config entry - await manager.options._async_finish_flow(flow, {"data": {"second": True}}) + await manager.options.async_finish_flow(flow, {"data": {"second": True}}) assert entry.data == {"first": True} diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index a6bdd2b5cb6..664304c9ef6 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -14,27 +14,32 @@ def manager(): handlers = Registry() entries = [] - async def async_create_flow(handler_name, *, context, data): - handler = handlers.get(handler_name) + class FlowManager(data_entry_flow.FlowManager): + """Test flow manager.""" - if handler is None: - raise data_entry_flow.UnknownHandler + async def async_create_flow(self, handler_key, *, context, data): + """Test create flow.""" + handler = handlers.get(handler_key) - flow = handler() - flow.init_step = context.get("init_step", "init") - flow.source = context.get("source") - return flow + if handler is None: + raise data_entry_flow.UnknownHandler - async def async_add_entry(flow, result): - if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY: - result["source"] = flow.context.get("source") - entries.append(result) - return result + flow = handler() + flow.init_step = context.get("init_step", "init") + flow.source = context.get("source") + return flow - manager = data_entry_flow.FlowManager(None, async_create_flow, async_add_entry) - manager.mock_created_entries = entries - manager.mock_reg_handler = handlers.register - return manager + async def async_finish_flow(self, flow, result): + """Test finish flow.""" + if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY: + result["source"] = flow.context.get("source") + entries.append(result) + return result + + mgr = FlowManager(None) + mgr.mock_created_entries = entries + mgr.mock_reg_handler = handlers.register + return mgr async def test_configure_reuses_handler_instance(manager): @@ -194,22 +199,23 @@ async def test_finish_callback_change_result_type(hass): step_id="init", data_schema=vol.Schema({"count": int}) ) - async def async_create_flow(handler_name, *, context, data): - """Create a test flow.""" - return TestFlow() + class FlowManager(data_entry_flow.FlowManager): + async def async_create_flow(self, handler_name, *, context, data): + """Create a test flow.""" + return TestFlow() - async def async_finish_flow(flow, result): - """Redirect to init form if count <= 1.""" - if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY: - if result["data"] is None or result["data"].get("count", 0) <= 1: - return flow.async_show_form( - step_id="init", data_schema=vol.Schema({"count": int}) - ) - else: - result["result"] = result["data"]["count"] - return result + async def async_finish_flow(self, flow, result): + """Redirect to init form if count <= 1.""" + if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY: + if result["data"] is None or result["data"].get("count", 0) <= 1: + return flow.async_show_form( + step_id="init", data_schema=vol.Schema({"count": int}) + ) + else: + result["result"] = result["data"]["count"] + return result - manager = data_entry_flow.FlowManager(hass, async_create_flow, async_finish_flow) + manager = FlowManager(hass) result = await manager.async_init("test") assert result["type"] == data_entry_flow.RESULT_TYPE_FORM