mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 17:27:52 +00:00
Add context to login flow (#15914)
* Add context to login flow * source -> context * Fix unit test * Update comment
This commit is contained in:
parent
45f12dd3c7
commit
50daef9a52
@ -215,9 +215,9 @@ class AuthManager:
|
||||
"""Create a login flow."""
|
||||
auth_provider = self._providers[handler]
|
||||
|
||||
return await auth_provider.async_credential_flow()
|
||||
return await auth_provider.async_credential_flow(context)
|
||||
|
||||
async def _async_finish_login_flow(self, result):
|
||||
async def _async_finish_login_flow(self, context, result):
|
||||
"""Result of a credential login flow."""
|
||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return None
|
||||
|
@ -123,7 +123,7 @@ class AuthProvider:
|
||||
|
||||
# Implement by extending class
|
||||
|
||||
async def async_credential_flow(self):
|
||||
async def async_credential_flow(self, context):
|
||||
"""Return the data flow for logging in with auth provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -158,7 +158,7 @@ class HassAuthProvider(AuthProvider):
|
||||
self.data = Data(self.hass)
|
||||
await self.data.async_load()
|
||||
|
||||
async def async_credential_flow(self):
|
||||
async def async_credential_flow(self, context):
|
||||
"""Return a flow to login."""
|
||||
return LoginFlow(self)
|
||||
|
||||
|
@ -31,7 +31,7 @@ class InvalidAuthError(HomeAssistantError):
|
||||
class ExampleAuthProvider(AuthProvider):
|
||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||
|
||||
async def async_credential_flow(self):
|
||||
async def async_credential_flow(self, context):
|
||||
"""Return a flow to login."""
|
||||
return LoginFlow(self)
|
||||
|
||||
|
@ -36,7 +36,7 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
||||
|
||||
DEFAULT_TITLE = 'Legacy API Password'
|
||||
|
||||
async def async_credential_flow(self):
|
||||
async def async_credential_flow(self, context):
|
||||
"""Return a flow to login."""
|
||||
return LoginFlow(self)
|
||||
|
||||
|
@ -54,7 +54,6 @@ have type "create_entry" and "result" key will contain an authorization code.
|
||||
"flow_id": "8f7e42faab604bcab7ac43c44ca34d58",
|
||||
"handler": ["insecure_example", null],
|
||||
"result": "411ee2f916e648d691e937ae9344681e",
|
||||
"source": "user",
|
||||
"title": "Example",
|
||||
"type": "create_entry",
|
||||
"version": 1
|
||||
@ -152,7 +151,7 @@ class LoginFlowIndexView(HomeAssistantView):
|
||||
handler = data['handler']
|
||||
|
||||
try:
|
||||
result = await self._flow_mgr.async_init(handler)
|
||||
result = await self._flow_mgr.async_init(handler, context={})
|
||||
except data_entry_flow.UnknownHandler:
|
||||
return self.json_message('Invalid handler specified', 404)
|
||||
except data_entry_flow.UnknownStep:
|
||||
|
@ -96,7 +96,7 @@ class ConfigManagerFlowIndexView(FlowManagerIndexView):
|
||||
|
||||
return self.json([
|
||||
flw for flw in hass.config_entries.flow.async_progress()
|
||||
if flw['source'] != config_entries.SOURCE_USER])
|
||||
if flw['context']['source'] != config_entries.SOURCE_USER])
|
||||
|
||||
|
||||
class ConfigManagerFlowResourceView(FlowManagerResourceView):
|
||||
|
@ -372,10 +372,10 @@ class ConfigEntries:
|
||||
return await entry.async_unload(
|
||||
self.hass, component=getattr(self.hass.components, component))
|
||||
|
||||
async def _async_finish_flow(self, result):
|
||||
async def _async_finish_flow(self, context, result):
|
||||
"""Finish a config flow and add an entry."""
|
||||
# If no discovery config entries in progress, remove notification.
|
||||
if not any(ent['source'] in DISCOVERY_SOURCES for ent
|
||||
if not any(ent['context']['source'] in DISCOVERY_SOURCES for ent
|
||||
in self.hass.config_entries.flow.async_progress()):
|
||||
self.hass.components.persistent_notification.async_dismiss(
|
||||
DISCOVERY_NOTIFICATION_ID)
|
||||
@ -383,15 +383,12 @@ class ConfigEntries:
|
||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return None
|
||||
|
||||
source = result['source']
|
||||
if source is None:
|
||||
source = SOURCE_USER
|
||||
entry = ConfigEntry(
|
||||
version=result['version'],
|
||||
domain=result['handler'],
|
||||
title=result['title'],
|
||||
data=result['data'],
|
||||
source=source,
|
||||
source=context['source'],
|
||||
)
|
||||
self._entries.append(entry)
|
||||
await self._async_schedule_save()
|
||||
@ -406,7 +403,7 @@ class ConfigEntries:
|
||||
self.hass, entry.domain, self._hass_config)
|
||||
|
||||
# Return Entry if they not from a discovery request
|
||||
if result['source'] not in DISCOVERY_SOURCES:
|
||||
if context['source'] not in DISCOVERY_SOURCES:
|
||||
return entry
|
||||
|
||||
return entry
|
||||
@ -422,10 +419,7 @@ class ConfigEntries:
|
||||
if handler is None:
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
if context is not None:
|
||||
source = context.get('source', SOURCE_USER)
|
||||
else:
|
||||
source = SOURCE_USER
|
||||
source = context['source']
|
||||
|
||||
# Make sure requirements and dependencies of component are resolved
|
||||
await async_process_deps_reqs(
|
||||
@ -442,7 +436,6 @@ class ConfigEntries:
|
||||
)
|
||||
|
||||
flow = handler()
|
||||
flow.source = source
|
||||
flow.init_step = source
|
||||
return flow
|
||||
|
||||
|
@ -46,7 +46,7 @@ class FlowManager:
|
||||
return [{
|
||||
'flow_id': flow.flow_id,
|
||||
'handler': flow.handler,
|
||||
'source': flow.source,
|
||||
'context': flow.context,
|
||||
} for flow in self._progress.values()]
|
||||
|
||||
async def async_init(self, handler: Hashable, *, context: Dict = None,
|
||||
@ -57,6 +57,7 @@ class FlowManager:
|
||||
flow.hass = self.hass
|
||||
flow.handler = handler
|
||||
flow.flow_id = uuid.uuid4().hex
|
||||
flow.context = context
|
||||
self._progress[flow.flow_id] = flow
|
||||
|
||||
return await self._async_handle_step(flow, flow.init_step, data)
|
||||
@ -108,7 +109,7 @@ class FlowManager:
|
||||
self._progress.pop(flow.flow_id)
|
||||
|
||||
# We pass a copy of the result because we're mutating our version
|
||||
entry = await self._async_finish_flow(dict(result))
|
||||
entry = await self._async_finish_flow(flow.context, dict(result))
|
||||
|
||||
if result['type'] == RESULT_TYPE_CREATE_ENTRY:
|
||||
result['result'] = entry
|
||||
@ -122,8 +123,8 @@ class FlowHandler:
|
||||
flow_id = None
|
||||
hass = None
|
||||
handler = None
|
||||
source = None
|
||||
cur_step = None
|
||||
context = None
|
||||
|
||||
# Set by _async_create_flow callback
|
||||
init_step = 'init'
|
||||
@ -156,7 +157,6 @@ class FlowHandler:
|
||||
'handler': self.handler,
|
||||
'title': title,
|
||||
'data': data,
|
||||
'source': self.source,
|
||||
}
|
||||
|
||||
@callback
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Tests for the Cast config flow."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components import cast
|
||||
|
||||
@ -15,7 +15,8 @@ async def test_creating_entry_sets_up_media_player(hass):
|
||||
MockDependency('pychromecast', 'discovery'), \
|
||||
patch('pychromecast.discovery.discover_chromecasts',
|
||||
return_value=True):
|
||||
result = await hass.config_entries.flow.async_init(cast.DOMAIN)
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
cast.DOMAIN, context={'source': config_entries.SOURCE_USER})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
@ -202,7 +202,6 @@ def test_create_account(hass, client):
|
||||
'handler': 'test',
|
||||
'title': 'Test Entry',
|
||||
'type': 'create_entry',
|
||||
'source': 'user',
|
||||
'version': 1,
|
||||
}
|
||||
|
||||
@ -264,7 +263,6 @@ def test_two_step_flow(hass, client):
|
||||
'type': 'create_entry',
|
||||
'title': 'user-title',
|
||||
'version': 1,
|
||||
'source': 'user',
|
||||
}
|
||||
|
||||
|
||||
@ -295,7 +293,7 @@ def test_get_progress_index(hass, client):
|
||||
{
|
||||
'flow_id': form['flow_id'],
|
||||
'handler': 'test',
|
||||
'source': 'hassio'
|
||||
'context': {'source': 'hassio'}
|
||||
}
|
||||
]
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Tests for the Sonos config flow."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components import sonos
|
||||
|
||||
@ -13,7 +13,8 @@ async def test_creating_entry_sets_up_media_player(hass):
|
||||
with patch('homeassistant.components.media_player.sonos.async_setup_entry',
|
||||
return_value=mock_coro(True)) as mock_setup, \
|
||||
patch('soco.discover', return_value=True):
|
||||
result = await hass.config_entries.flow.async_init(sonos.DOMAIN)
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
sonos.DOMAIN, context={'source': config_entries.SOURCE_USER})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
@ -109,7 +109,8 @@ async def test_user_init_trumps_discovery(hass, flow_conf):
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
|
||||
|
||||
# User starts flow
|
||||
result = await hass.config_entries.flow.async_init('test', data={})
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
'test', context={'source': config_entries.SOURCE_USER}, data={})
|
||||
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
|
||||
# Discovery flow has been aborted
|
||||
|
@ -116,7 +116,8 @@ def test_add_entry_calls_setup_entry(hass, manager):
|
||||
})
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {'comp': TestFlow, 'beer': 5}):
|
||||
yield from manager.flow.async_init('comp')
|
||||
yield from manager.flow.async_init(
|
||||
'comp', context={'source': config_entries.SOURCE_USER})
|
||||
yield from hass.async_block_till_done()
|
||||
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
@ -171,7 +172,8 @@ async def test_saving_and_loading(hass):
|
||||
)
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
|
||||
await hass.config_entries.flow.async_init('test')
|
||||
await hass.config_entries.flow.async_init(
|
||||
'test', context={'source': config_entries.SOURCE_USER})
|
||||
|
||||
class Test2Flow(data_entry_flow.FlowHandler):
|
||||
VERSION = 3
|
||||
@ -187,7 +189,8 @@ async def test_saving_and_loading(hass):
|
||||
|
||||
with patch('homeassistant.config_entries.HANDLERS.get',
|
||||
return_value=Test2Flow):
|
||||
await hass.config_entries.flow.async_init('test')
|
||||
await hass.config_entries.flow.async_init(
|
||||
'test', context={'source': config_entries.SOURCE_USER})
|
||||
|
||||
# To trigger the call_later
|
||||
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
|
||||
|
@ -25,8 +25,10 @@ def manager():
|
||||
if context is not None else 'user_input'
|
||||
return flow
|
||||
|
||||
async def async_add_entry(result):
|
||||
async def async_add_entry(context, result):
|
||||
if (result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY):
|
||||
result['source'] = context.get('source') \
|
||||
if context is not None else 'user'
|
||||
entries.append(result)
|
||||
|
||||
manager = data_entry_flow.FlowManager(
|
||||
@ -168,7 +170,7 @@ async def test_create_saves_data(manager):
|
||||
assert entry['handler'] == 'test'
|
||||
assert entry['title'] == 'Test Title'
|
||||
assert entry['data'] == 'Test Data'
|
||||
assert entry['source'] == 'user_input'
|
||||
assert entry['source'] == 'user'
|
||||
|
||||
|
||||
async def test_discovery_init_flow(manager):
|
||||
|
Loading…
x
Reference in New Issue
Block a user