From 00c6f56cc8a10b371107e9ba1cb44df45b017ff5 Mon Sep 17 00:00:00 2001 From: Jason Hu Date: Tue, 21 Aug 2018 10:48:24 -0700 Subject: [PATCH] Allow finish_flow callback to change data entry result type (#16100) * Allow finish_flow callback to change data entry result type * Add unit test --- homeassistant/auth/__init__.py | 19 ++++++----- homeassistant/config_entries.py | 18 +++++------ homeassistant/data_entry_flow.py | 17 ++++++---- tests/test_data_entry_flow.py | 55 +++++++++++++++++++++++++++++--- 4 files changed, 80 insertions(+), 29 deletions(-) diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 52240ab78c6..9f5252be67d 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -2,7 +2,7 @@ import asyncio import logging from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, cast, Union +from typing import Any, Dict, List, Optional, Tuple, cast import jwt @@ -256,20 +256,23 @@ class AuthManager: return await auth_provider.async_credential_flow(context) async def _async_finish_login_flow( - self, context: Optional[Dict], result: Dict[str, Any]) \ - -> Optional[Union[models.User, models.Credentials]]: + self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]) \ + -> Dict[str, Any]: """Return a user as result of login flow.""" if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: - return None + return result auth_provider = self._providers[result['handler']] - cred = await auth_provider.async_get_or_create_credentials( + credentials = await auth_provider.async_get_or_create_credentials( result['data']) - if context is not None and context.get('credential_only'): - return cred + if flow.context is not None and flow.context.get('credential_only'): + result['result'] = credentials + return result - return await self.async_get_or_create_user(cred) + user = await self.async_get_or_create_user(credentials) + result['result'] = user + return result @callback def _async_get_auth_provider( diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index ad3ed896dd4..1858937ec82 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -372,23 +372,24 @@ class ConfigEntries: return await entry.async_unload( self.hass, component=getattr(self.hass.components, component)) - async def _async_finish_flow(self, context, result): + async def _async_finish_flow(self, flow, result): """Finish a config flow and add an entry.""" - # If no discovery config entries in progress, remove notification. + # Remove notification if no other discovery config entries in progress if not any(ent['context']['source'] in DISCOVERY_SOURCES for ent - in self.hass.config_entries.flow.async_progress()): + in self.hass.config_entries.flow.async_progress() + if ent['flow_id'] != flow.flow_id): self.hass.components.persistent_notification.async_dismiss( DISCOVERY_NOTIFICATION_ID) if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: - return None + return result entry = ConfigEntry( version=result['version'], domain=result['handler'], title=result['title'], data=result['data'], - source=context['source'], + source=flow.context['source'], ) self._entries.append(entry) self._async_schedule_save() @@ -402,11 +403,8 @@ class ConfigEntries: await async_setup_component( self.hass, entry.domain, self._hass_config) - # Return Entry if they not from a discovery request - if context['source'] not in DISCOVERY_SOURCES: - return entry - - return entry + result['result'] = entry + return result async def _async_create_flow(self, handler_key, *, context, data): """Create a flow for specified handler. diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index d99d70ce2ec..a54c07fc1b8 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -64,7 +64,7 @@ class FlowManager: return await self._async_handle_step(flow, flow.init_step, data) async def async_configure( - self, flow_id: str, user_input: Optional[str] = None) -> Any: + self, flow_id: str, user_input: Optional[Dict] = None) -> Any: """Continue a configuration flow.""" flow = self._progress.get(flow_id) @@ -86,7 +86,7 @@ class FlowManager: raise UnknownFlow async def _async_handle_step(self, flow: Any, step_id: str, - user_input: Optional[str]) -> Dict: + user_input: Optional[Dict]) -> Dict: """Handle a step of a flow.""" method = "async_step_{}".format(step_id) @@ -106,14 +106,17 @@ class FlowManager: flow.cur_step = (result['step_id'], result['data_schema']) return result + # We pass a copy of the result because we're mutating our version + result = await self._async_finish_flow(flow, dict(result)) + + # _async_finish_flow may change result type, check it again + if result['type'] == RESULT_TYPE_FORM: + flow.cur_step = (result['step_id'], result['data_schema']) + return result + # Abort and Success results both finish the flow self._progress.pop(flow.flow_id) - # We pass a copy of the result because we're mutating our version - entry = await self._async_finish_flow(flow.context, dict(result)) - - if result['type'] == RESULT_TYPE_CREATE_ENTRY: - result['result'] = entry return result diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index c5d5bbb50bf..aa8240ff567 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -25,11 +25,12 @@ def manager(): if context is not None else 'user_input' return flow - async def async_add_entry(context, result): - if (result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY): - result['source'] = context.get('source') \ - if context is not None else 'user' + async def async_add_entry(flow, result): + if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY: + result['source'] = flow.context.get('source') \ + if flow.context is not None else 'user' entries.append(result) + return result manager = data_entry_flow.FlowManager( None, async_create_flow, async_add_entry) @@ -198,3 +199,49 @@ async def test_discovery_init_flow(manager): assert entry['title'] == 'hello' assert entry['data'] == data assert entry['source'] == 'discovery' + + +async def test_finish_callback_change_result_type(hass): + """Test finish callback can change result type.""" + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 1 + + async def async_step_init(self, input): + """Return init form with one input field 'count'.""" + if input is not None: + return self.async_create_entry(title='init', data=input) + return self.async_show_form( + step_id='init', + data_schema=vol.Schema({'count': int})) + + async def async_create_flow(handler_name, *, context, data): + """Create a test flow.""" + return TestFlow() + + async def async_finish_flow(flow, result): + """Redirect to init form if count <= 1.""" + if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY: + if (result['data'] is None or + result['data'].get('count', 0) <= 1): + return flow.async_show_form( + step_id='init', + data_schema=vol.Schema({'count': int})) + else: + result['result'] = result['data']['count'] + return result + + manager = data_entry_flow.FlowManager( + hass, async_create_flow, async_finish_flow) + + result = await manager.async_init('test') + assert result['type'] == data_entry_flow.RESULT_TYPE_FORM + assert result['step_id'] == 'init' + + result = await manager.async_configure(result['flow_id'], {'count': 0}) + assert result['type'] == data_entry_flow.RESULT_TYPE_FORM + assert result['step_id'] == 'init' + assert 'result' not in result + + result = await manager.async_configure(result['flow_id'], {'count': 2}) + assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result['result'] == 2