Add context to login flow (#15914)

* Add context to login flow

* source -> context

* Fix unit test

* Update comment
This commit is contained in:
Jason Hu 2018-08-13 02:27:18 -07:00 committed by Paulus Schoutsen
parent 45f12dd3c7
commit 50daef9a52
15 changed files with 36 additions and 38 deletions

View File

@ -215,9 +215,9 @@ class AuthManager:
"""Create a login flow."""
auth_provider = self._providers[handler]
return await auth_provider.async_credential_flow()
return await auth_provider.async_credential_flow(context)
async def _async_finish_login_flow(self, result):
async def _async_finish_login_flow(self, context, result):
"""Result of a credential login flow."""
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None

View File

@ -123,7 +123,7 @@ class AuthProvider:
# Implement by extending class
async def async_credential_flow(self):
async def async_credential_flow(self, context):
"""Return the data flow for logging in with auth provider."""
raise NotImplementedError

View File

@ -158,7 +158,7 @@ class HassAuthProvider(AuthProvider):
self.data = Data(self.hass)
await self.data.async_load()
async def async_credential_flow(self):
async def async_credential_flow(self, context):
"""Return a flow to login."""
return LoginFlow(self)

View File

@ -31,7 +31,7 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords."""
async def async_credential_flow(self):
async def async_credential_flow(self, context):
"""Return a flow to login."""
return LoginFlow(self)

View File

@ -36,7 +36,7 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
DEFAULT_TITLE = 'Legacy API Password'
async def async_credential_flow(self):
async def async_credential_flow(self, context):
"""Return a flow to login."""
return LoginFlow(self)

View File

@ -54,7 +54,6 @@ have type "create_entry" and "result" key will contain an authorization code.
"flow_id": "8f7e42faab604bcab7ac43c44ca34d58",
"handler": ["insecure_example", null],
"result": "411ee2f916e648d691e937ae9344681e",
"source": "user",
"title": "Example",
"type": "create_entry",
"version": 1
@ -152,7 +151,7 @@ class LoginFlowIndexView(HomeAssistantView):
handler = data['handler']
try:
result = await self._flow_mgr.async_init(handler)
result = await self._flow_mgr.async_init(handler, context={})
except data_entry_flow.UnknownHandler:
return self.json_message('Invalid handler specified', 404)
except data_entry_flow.UnknownStep:

View File

@ -96,7 +96,7 @@ class ConfigManagerFlowIndexView(FlowManagerIndexView):
return self.json([
flw for flw in hass.config_entries.flow.async_progress()
if flw['source'] != config_entries.SOURCE_USER])
if flw['context']['source'] != config_entries.SOURCE_USER])
class ConfigManagerFlowResourceView(FlowManagerResourceView):

View File

@ -372,10 +372,10 @@ class ConfigEntries:
return await entry.async_unload(
self.hass, component=getattr(self.hass.components, component))
async def _async_finish_flow(self, result):
async def _async_finish_flow(self, context, result):
"""Finish a config flow and add an entry."""
# If no discovery config entries in progress, remove notification.
if not any(ent['source'] in DISCOVERY_SOURCES for ent
if not any(ent['context']['source'] in DISCOVERY_SOURCES for ent
in self.hass.config_entries.flow.async_progress()):
self.hass.components.persistent_notification.async_dismiss(
DISCOVERY_NOTIFICATION_ID)
@ -383,15 +383,12 @@ class ConfigEntries:
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None
source = result['source']
if source is None:
source = SOURCE_USER
entry = ConfigEntry(
version=result['version'],
domain=result['handler'],
title=result['title'],
data=result['data'],
source=source,
source=context['source'],
)
self._entries.append(entry)
await self._async_schedule_save()
@ -406,7 +403,7 @@ class ConfigEntries:
self.hass, entry.domain, self._hass_config)
# Return Entry if they not from a discovery request
if result['source'] not in DISCOVERY_SOURCES:
if context['source'] not in DISCOVERY_SOURCES:
return entry
return entry
@ -422,10 +419,7 @@ class ConfigEntries:
if handler is None:
raise data_entry_flow.UnknownHandler
if context is not None:
source = context.get('source', SOURCE_USER)
else:
source = SOURCE_USER
source = context['source']
# Make sure requirements and dependencies of component are resolved
await async_process_deps_reqs(
@ -442,7 +436,6 @@ class ConfigEntries:
)
flow = handler()
flow.source = source
flow.init_step = source
return flow

View File

@ -46,7 +46,7 @@ class FlowManager:
return [{
'flow_id': flow.flow_id,
'handler': flow.handler,
'source': flow.source,
'context': flow.context,
} for flow in self._progress.values()]
async def async_init(self, handler: Hashable, *, context: Dict = None,
@ -57,6 +57,7 @@ class FlowManager:
flow.hass = self.hass
flow.handler = handler
flow.flow_id = uuid.uuid4().hex
flow.context = context
self._progress[flow.flow_id] = flow
return await self._async_handle_step(flow, flow.init_step, data)
@ -108,7 +109,7 @@ class FlowManager:
self._progress.pop(flow.flow_id)
# We pass a copy of the result because we're mutating our version
entry = await self._async_finish_flow(dict(result))
entry = await self._async_finish_flow(flow.context, dict(result))
if result['type'] == RESULT_TYPE_CREATE_ENTRY:
result['result'] = entry
@ -122,8 +123,8 @@ class FlowHandler:
flow_id = None
hass = None
handler = None
source = None
cur_step = None
context = None
# Set by _async_create_flow callback
init_step = 'init'
@ -156,7 +157,6 @@ class FlowHandler:
'handler': self.handler,
'title': title,
'data': data,
'source': self.source,
}
@callback

View File

@ -1,7 +1,7 @@
"""Tests for the Cast config flow."""
from unittest.mock import patch
from homeassistant import data_entry_flow
from homeassistant import config_entries, data_entry_flow
from homeassistant.setup import async_setup_component
from homeassistant.components import cast
@ -15,7 +15,8 @@ async def test_creating_entry_sets_up_media_player(hass):
MockDependency('pychromecast', 'discovery'), \
patch('pychromecast.discovery.discover_chromecasts',
return_value=True):
result = await hass.config_entries.flow.async_init(cast.DOMAIN)
result = await hass.config_entries.flow.async_init(
cast.DOMAIN, context={'source': config_entries.SOURCE_USER})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
await hass.async_block_till_done()

View File

@ -202,7 +202,6 @@ def test_create_account(hass, client):
'handler': 'test',
'title': 'Test Entry',
'type': 'create_entry',
'source': 'user',
'version': 1,
}
@ -264,7 +263,6 @@ def test_two_step_flow(hass, client):
'type': 'create_entry',
'title': 'user-title',
'version': 1,
'source': 'user',
}
@ -295,7 +293,7 @@ def test_get_progress_index(hass, client):
{
'flow_id': form['flow_id'],
'handler': 'test',
'source': 'hassio'
'context': {'source': 'hassio'}
}
]

View File

@ -1,7 +1,7 @@
"""Tests for the Sonos config flow."""
from unittest.mock import patch
from homeassistant import data_entry_flow
from homeassistant import config_entries, data_entry_flow
from homeassistant.setup import async_setup_component
from homeassistant.components import sonos
@ -13,7 +13,8 @@ async def test_creating_entry_sets_up_media_player(hass):
with patch('homeassistant.components.media_player.sonos.async_setup_entry',
return_value=mock_coro(True)) as mock_setup, \
patch('soco.discover', return_value=True):
result = await hass.config_entries.flow.async_init(sonos.DOMAIN)
result = await hass.config_entries.flow.async_init(
sonos.DOMAIN, context={'source': config_entries.SOURCE_USER})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
await hass.async_block_till_done()

View File

@ -109,7 +109,8 @@ async def test_user_init_trumps_discovery(hass, flow_conf):
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
# User starts flow
result = await hass.config_entries.flow.async_init('test', data={})
result = await hass.config_entries.flow.async_init(
'test', context={'source': config_entries.SOURCE_USER}, data={})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
# Discovery flow has been aborted

View File

@ -116,7 +116,8 @@ def test_add_entry_calls_setup_entry(hass, manager):
})
with patch.dict(config_entries.HANDLERS, {'comp': TestFlow, 'beer': 5}):
yield from manager.flow.async_init('comp')
yield from manager.flow.async_init(
'comp', context={'source': config_entries.SOURCE_USER})
yield from hass.async_block_till_done()
assert len(mock_setup_entry.mock_calls) == 1
@ -171,7 +172,8 @@ async def test_saving_and_loading(hass):
)
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
await hass.config_entries.flow.async_init('test')
await hass.config_entries.flow.async_init(
'test', context={'source': config_entries.SOURCE_USER})
class Test2Flow(data_entry_flow.FlowHandler):
VERSION = 3
@ -187,7 +189,8 @@ async def test_saving_and_loading(hass):
with patch('homeassistant.config_entries.HANDLERS.get',
return_value=Test2Flow):
await hass.config_entries.flow.async_init('test')
await hass.config_entries.flow.async_init(
'test', context={'source': config_entries.SOURCE_USER})
# To trigger the call_later
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))

View File

@ -25,8 +25,10 @@ def manager():
if context is not None else 'user_input'
return flow
async def async_add_entry(result):
async def async_add_entry(context, result):
if (result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY):
result['source'] = context.get('source') \
if context is not None else 'user'
entries.append(result)
manager = data_entry_flow.FlowManager(
@ -168,7 +170,7 @@ async def test_create_saves_data(manager):
assert entry['handler'] == 'test'
assert entry['title'] == 'Test Title'
assert entry['data'] == 'Test Data'
assert entry['source'] == 'user_input'
assert entry['source'] == 'user'
async def test_discovery_init_flow(manager):