mirror of
https://github.com/home-assistant/core.git
synced 2025-07-17 02:07:09 +00:00
Get user after login flow finished (#16047)
* Get user after login flow finished * Add optional parameter 'type' to /auth/login_flow * Update __init__.py
This commit is contained in:
parent
b1ba11510b
commit
f84a31871e
@ -2,7 +2,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
from typing import Any, Dict, List, Optional, Tuple, cast, Union
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
@ -257,15 +257,20 @@ class AuthManager:
|
|||||||
|
|
||||||
async def _async_finish_login_flow(
|
async def _async_finish_login_flow(
|
||||||
self, context: Optional[Dict], result: Dict[str, Any]) \
|
self, context: Optional[Dict], result: Dict[str, Any]) \
|
||||||
-> Optional[models.Credentials]:
|
-> Optional[Union[models.User, models.Credentials]]:
|
||||||
"""Result of a credential login flow."""
|
"""Return a user as result of login flow."""
|
||||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
auth_provider = self._providers[result['handler']]
|
auth_provider = self._providers[result['handler']]
|
||||||
return await auth_provider.async_get_or_create_credentials(
|
cred = await auth_provider.async_get_or_create_credentials(
|
||||||
result['data'])
|
result['data'])
|
||||||
|
|
||||||
|
if context is not None and context.get('credential_only'):
|
||||||
|
return cred
|
||||||
|
|
||||||
|
return await self.async_get_or_create_user(cred)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_get_auth_provider(
|
def _async_get_auth_provider(
|
||||||
self, credentials: models.Credentials) -> Optional[AuthProvider]:
|
self, credentials: models.Credentials) -> Optional[AuthProvider]:
|
||||||
|
@ -51,6 +51,7 @@ from datetime import timedelta
|
|||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.auth.models import User, Credentials
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.components.http.ban import log_invalid_auth
|
from homeassistant.components.http.ban import log_invalid_auth
|
||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
@ -68,22 +69,25 @@ SCHEMA_WS_CURRENT_USER = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
|||||||
vol.Required('type'): WS_TYPE_CURRENT_USER,
|
vol.Required('type'): WS_TYPE_CURRENT_USER,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
RESULT_TYPE_CREDENTIALS = 'credentials'
|
||||||
|
RESULT_TYPE_USER = 'user'
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass, config):
|
||||||
"""Component to allow users to login."""
|
"""Component to allow users to login."""
|
||||||
store_credentials, retrieve_credentials = _create_cred_store()
|
store_result, retrieve_result = _create_auth_code_store()
|
||||||
|
|
||||||
hass.http.register_view(GrantTokenView(retrieve_credentials))
|
hass.http.register_view(GrantTokenView(retrieve_result))
|
||||||
hass.http.register_view(LinkUserView(retrieve_credentials))
|
hass.http.register_view(LinkUserView(retrieve_result))
|
||||||
|
|
||||||
hass.components.websocket_api.async_register_command(
|
hass.components.websocket_api.async_register_command(
|
||||||
WS_TYPE_CURRENT_USER, websocket_current_user,
|
WS_TYPE_CURRENT_USER, websocket_current_user,
|
||||||
SCHEMA_WS_CURRENT_USER
|
SCHEMA_WS_CURRENT_USER
|
||||||
)
|
)
|
||||||
|
|
||||||
await login_flow.async_setup(hass, store_credentials)
|
await login_flow.async_setup(hass, store_result)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -96,9 +100,9 @@ class GrantTokenView(HomeAssistantView):
|
|||||||
requires_auth = False
|
requires_auth = False
|
||||||
cors_allowed = True
|
cors_allowed = True
|
||||||
|
|
||||||
def __init__(self, retrieve_credentials):
|
def __init__(self, retrieve_user):
|
||||||
"""Initialize the grant token view."""
|
"""Initialize the grant token view."""
|
||||||
self._retrieve_credentials = retrieve_credentials
|
self._retrieve_user = retrieve_user
|
||||||
|
|
||||||
@log_invalid_auth
|
@log_invalid_auth
|
||||||
async def post(self, request):
|
async def post(self, request):
|
||||||
@ -134,15 +138,16 @@ class GrantTokenView(HomeAssistantView):
|
|||||||
'error': 'invalid_request',
|
'error': 'invalid_request',
|
||||||
}, status_code=400)
|
}, status_code=400)
|
||||||
|
|
||||||
credentials = self._retrieve_credentials(client_id, code)
|
user = self._retrieve_user(client_id, RESULT_TYPE_USER, code)
|
||||||
|
|
||||||
if credentials is None:
|
if user is None or not isinstance(user, User):
|
||||||
return self.json({
|
return self.json({
|
||||||
'error': 'invalid_request',
|
'error': 'invalid_request',
|
||||||
'error_description': 'Invalid code',
|
'error_description': 'Invalid code',
|
||||||
}, status_code=400)
|
}, status_code=400)
|
||||||
|
|
||||||
user = await hass.auth.async_get_or_create_user(credentials)
|
# refresh user
|
||||||
|
user = await hass.auth.async_get_user(user.id)
|
||||||
|
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
return self.json({
|
return self.json({
|
||||||
@ -220,7 +225,7 @@ class LinkUserView(HomeAssistantView):
|
|||||||
user = request['hass_user']
|
user = request['hass_user']
|
||||||
|
|
||||||
credentials = self._retrieve_credentials(
|
credentials = self._retrieve_credentials(
|
||||||
data['client_id'], data['code'])
|
data['client_id'], RESULT_TYPE_CREDENTIALS, data['code'])
|
||||||
|
|
||||||
if credentials is None:
|
if credentials is None:
|
||||||
return self.json_message('Invalid code', status_code=400)
|
return self.json_message('Invalid code', status_code=400)
|
||||||
@ -230,37 +235,45 @@ class LinkUserView(HomeAssistantView):
|
|||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _create_cred_store():
|
def _create_auth_code_store():
|
||||||
"""Create a credential store."""
|
"""Create an in memory store."""
|
||||||
temp_credentials = {}
|
temp_results = {}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def store_credentials(client_id, credentials):
|
def store_result(client_id, result):
|
||||||
"""Store credentials and return a code to retrieve it."""
|
"""Store flow result and return a code to retrieve it."""
|
||||||
|
if isinstance(result, User):
|
||||||
|
result_type = RESULT_TYPE_USER
|
||||||
|
elif isinstance(result, Credentials):
|
||||||
|
result_type = RESULT_TYPE_CREDENTIALS
|
||||||
|
else:
|
||||||
|
raise ValueError('result has to be either User or Credentials')
|
||||||
|
|
||||||
code = uuid.uuid4().hex
|
code = uuid.uuid4().hex
|
||||||
temp_credentials[(client_id, code)] = (dt_util.utcnow(), credentials)
|
temp_results[(client_id, result_type, code)] = \
|
||||||
|
(dt_util.utcnow(), result_type, result)
|
||||||
return code
|
return code
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def retrieve_credentials(client_id, code):
|
def retrieve_result(client_id, result_type, code):
|
||||||
"""Retrieve credentials."""
|
"""Retrieve flow result."""
|
||||||
key = (client_id, code)
|
key = (client_id, result_type, code)
|
||||||
|
|
||||||
if key not in temp_credentials:
|
if key not in temp_results:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
created, credentials = temp_credentials.pop(key)
|
created, _, result = temp_results.pop(key)
|
||||||
|
|
||||||
# OAuth 4.2.1
|
# OAuth 4.2.1
|
||||||
# The authorization code MUST expire shortly after it is issued to
|
# The authorization code MUST expire shortly after it is issued to
|
||||||
# mitigate the risk of leaks. A maximum authorization code lifetime of
|
# mitigate the risk of leaks. A maximum authorization code lifetime of
|
||||||
# 10 minutes is RECOMMENDED.
|
# 10 minutes is RECOMMENDED.
|
||||||
if dt_util.utcnow() - created < timedelta(minutes=10):
|
if dt_util.utcnow() - created < timedelta(minutes=10):
|
||||||
return credentials
|
return result
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return store_credentials, retrieve_credentials
|
return store_result, retrieve_result
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -22,10 +22,14 @@ Pass in parameter 'client_id' and 'redirect_url' validate by indieauth.
|
|||||||
Pass in parameter 'handler' to specify the auth provider to use. Auth providers
|
Pass in parameter 'handler' to specify the auth provider to use. Auth providers
|
||||||
are identified by type and id.
|
are identified by type and id.
|
||||||
|
|
||||||
|
And optional parameter 'type' has to set as 'link_user' if login flow used for
|
||||||
|
link credential to exist user. Default 'type' is 'authorize'.
|
||||||
|
|
||||||
{
|
{
|
||||||
"client_id": "https://hassbian.local:8123/",
|
"client_id": "https://hassbian.local:8123/",
|
||||||
"handler": ["local_provider", null],
|
"handler": ["local_provider", null],
|
||||||
"redirect_url": "https://hassbian.local:8123/"
|
"redirect_url": "https://hassbian.local:8123/",
|
||||||
|
"type': "authorize"
|
||||||
}
|
}
|
||||||
|
|
||||||
Return value will be a step in a data entry flow. See the docs for data entry
|
Return value will be a step in a data entry flow. See the docs for data entry
|
||||||
@ -49,6 +53,9 @@ flow for details.
|
|||||||
Progress the flow. Most flows will be 1 page, but could optionally add extra
|
Progress the flow. Most flows will be 1 page, but could optionally add extra
|
||||||
login challenges, like TFA. Once the flow has finished, the returned step will
|
login challenges, like TFA. Once the flow has finished, the returned step will
|
||||||
have type "create_entry" and "result" key will contain an authorization code.
|
have type "create_entry" and "result" key will contain an authorization code.
|
||||||
|
The authorization code associated with an authorized user by default, it will
|
||||||
|
associate with an credential if "type" set to "link_user" in
|
||||||
|
"/auth/login_flow"
|
||||||
|
|
||||||
{
|
{
|
||||||
"flow_id": "8f7e42faab604bcab7ac43c44ca34d58",
|
"flow_id": "8f7e42faab604bcab7ac43c44ca34d58",
|
||||||
@ -71,12 +78,12 @@ from homeassistant.components.http.view import HomeAssistantView
|
|||||||
from . import indieauth
|
from . import indieauth
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass, store_credentials):
|
async def async_setup(hass, store_result):
|
||||||
"""Component to allow users to login."""
|
"""Component to allow users to login."""
|
||||||
hass.http.register_view(AuthProvidersView)
|
hass.http.register_view(AuthProvidersView)
|
||||||
hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow))
|
hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow))
|
||||||
hass.http.register_view(
|
hass.http.register_view(
|
||||||
LoginFlowResourceView(hass.auth.login_flow, store_credentials))
|
LoginFlowResourceView(hass.auth.login_flow, store_result))
|
||||||
|
|
||||||
|
|
||||||
class AuthProvidersView(HomeAssistantView):
|
class AuthProvidersView(HomeAssistantView):
|
||||||
@ -138,6 +145,7 @@ class LoginFlowIndexView(HomeAssistantView):
|
|||||||
vol.Required('client_id'): str,
|
vol.Required('client_id'): str,
|
||||||
vol.Required('handler'): vol.Any(str, list),
|
vol.Required('handler'): vol.Any(str, list),
|
||||||
vol.Required('redirect_uri'): str,
|
vol.Required('redirect_uri'): str,
|
||||||
|
vol.Optional('type', default='authorize'): str,
|
||||||
}))
|
}))
|
||||||
@log_invalid_auth
|
@log_invalid_auth
|
||||||
async def post(self, request, data):
|
async def post(self, request, data):
|
||||||
@ -153,7 +161,10 @@ class LoginFlowIndexView(HomeAssistantView):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = await self._flow_mgr.async_init(
|
result = await self._flow_mgr.async_init(
|
||||||
handler, context={'ip_address': request[KEY_REAL_IP]})
|
handler, context={
|
||||||
|
'ip_address': request[KEY_REAL_IP],
|
||||||
|
'credential_only': data.get('type') == 'link_user',
|
||||||
|
})
|
||||||
except data_entry_flow.UnknownHandler:
|
except data_entry_flow.UnknownHandler:
|
||||||
return self.json_message('Invalid handler specified', 404)
|
return self.json_message('Invalid handler specified', 404)
|
||||||
except data_entry_flow.UnknownStep:
|
except data_entry_flow.UnknownStep:
|
||||||
@ -169,10 +180,10 @@ class LoginFlowResourceView(HomeAssistantView):
|
|||||||
name = 'api:auth:login_flow:resource'
|
name = 'api:auth:login_flow:resource'
|
||||||
requires_auth = False
|
requires_auth = False
|
||||||
|
|
||||||
def __init__(self, flow_mgr, store_credentials):
|
def __init__(self, flow_mgr, store_result):
|
||||||
"""Initialize the login flow resource view."""
|
"""Initialize the login flow resource view."""
|
||||||
self._flow_mgr = flow_mgr
|
self._flow_mgr = flow_mgr
|
||||||
self._store_credentials = store_credentials
|
self._store_result = store_result
|
||||||
|
|
||||||
async def get(self, request):
|
async def get(self, request):
|
||||||
"""Do not allow getting status of a flow in progress."""
|
"""Do not allow getting status of a flow in progress."""
|
||||||
@ -212,7 +223,7 @@ class LoginFlowResourceView(HomeAssistantView):
|
|||||||
return self.json(_prepare_result_json(result))
|
return self.json(_prepare_result_json(result))
|
||||||
|
|
||||||
result.pop('data')
|
result.pop('data')
|
||||||
result['result'] = self._store_credentials(client_id, result['result'])
|
result['result'] = self._store_result(client_id, result['result'])
|
||||||
|
|
||||||
return self.json(result)
|
return self.json(result)
|
||||||
|
|
||||||
|
@ -77,8 +77,7 @@ async def test_create_new_user(hass, hass_storage):
|
|||||||
'password': 'test-pass',
|
'password': 'test-pass',
|
||||||
})
|
})
|
||||||
assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
credentials = step['result']
|
user = step['result']
|
||||||
user = await manager.async_get_or_create_user(credentials)
|
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.is_owner is False
|
assert user.is_owner is False
|
||||||
assert user.name == 'Test Name'
|
assert user.name == 'Test Name'
|
||||||
@ -134,9 +133,8 @@ async def test_login_as_existing_user(mock_hass):
|
|||||||
'password': 'test-pass',
|
'password': 'test-pass',
|
||||||
})
|
})
|
||||||
assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
credentials = step['result']
|
|
||||||
|
|
||||||
user = await manager.async_get_or_create_user(credentials)
|
user = step['result']
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.id == 'mock-user'
|
assert user.id == 'mock-user'
|
||||||
assert user.is_owner is False
|
assert user.is_owner is False
|
||||||
@ -166,16 +164,18 @@ async def test_linking_user_to_two_auth_providers(hass, hass_storage):
|
|||||||
'username': 'test-user',
|
'username': 'test-user',
|
||||||
'password': 'test-pass',
|
'password': 'test-pass',
|
||||||
})
|
})
|
||||||
user = await manager.async_get_or_create_user(step['result'])
|
user = step['result']
|
||||||
assert user is not None
|
assert user is not None
|
||||||
|
|
||||||
step = await manager.login_flow.async_init(('insecure_example',
|
step = await manager.login_flow.async_init(
|
||||||
'another-provider'))
|
('insecure_example', 'another-provider'),
|
||||||
|
context={'credential_only': True})
|
||||||
step = await manager.login_flow.async_configure(step['flow_id'], {
|
step = await manager.login_flow.async_configure(step['flow_id'], {
|
||||||
'username': 'another-user',
|
'username': 'another-user',
|
||||||
'password': 'another-password',
|
'password': 'another-password',
|
||||||
})
|
})
|
||||||
await manager.async_link_user(user, step['result'])
|
new_credential = step['result']
|
||||||
|
await manager.async_link_user(user, new_credential)
|
||||||
assert len(user.credentials) == 2
|
assert len(user.credentials) == 2
|
||||||
|
|
||||||
|
|
||||||
@ -197,7 +197,7 @@ async def test_saving_loading(hass, hass_storage):
|
|||||||
'username': 'test-user',
|
'username': 'test-user',
|
||||||
'password': 'test-pass',
|
'password': 'test-pass',
|
||||||
})
|
})
|
||||||
user = await manager.async_get_or_create_user(step['result'])
|
user = step['result']
|
||||||
await manager.async_activate_user(user)
|
await manager.async_activate_user(user)
|
||||||
await manager.async_create_refresh_token(user, CLIENT_ID)
|
await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||||
|
|
||||||
|
@ -3,13 +3,14 @@ from datetime import timedelta
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from homeassistant.auth.models import Credentials
|
from homeassistant.auth.models import Credentials
|
||||||
|
from homeassistant.components.auth import RESULT_TYPE_USER
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
from homeassistant.components import auth
|
from homeassistant.components import auth
|
||||||
|
|
||||||
from . import async_setup_auth
|
from . import async_setup_auth
|
||||||
|
|
||||||
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI
|
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser
|
||||||
|
|
||||||
|
|
||||||
async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
||||||
@ -74,26 +75,26 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
|||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
|
|
||||||
|
|
||||||
def test_credential_store_expiration():
|
def test_auth_code_store_expiration():
|
||||||
"""Test that the credential store will not return expired tokens."""
|
"""Test that the auth code store will not return expired tokens."""
|
||||||
store, retrieve = auth._create_cred_store()
|
store, retrieve = auth._create_auth_code_store()
|
||||||
client_id = 'bla'
|
client_id = 'bla'
|
||||||
credentials = 'creds'
|
user = MockUser(id='mock_user')
|
||||||
now = utcnow()
|
now = utcnow()
|
||||||
|
|
||||||
with patch('homeassistant.util.dt.utcnow', return_value=now):
|
with patch('homeassistant.util.dt.utcnow', return_value=now):
|
||||||
code = store(client_id, credentials)
|
code = store(client_id, user)
|
||||||
|
|
||||||
with patch('homeassistant.util.dt.utcnow',
|
with patch('homeassistant.util.dt.utcnow',
|
||||||
return_value=now + timedelta(minutes=10)):
|
return_value=now + timedelta(minutes=10)):
|
||||||
assert retrieve(client_id, code) is None
|
assert retrieve(client_id, RESULT_TYPE_USER, code) is None
|
||||||
|
|
||||||
with patch('homeassistant.util.dt.utcnow', return_value=now):
|
with patch('homeassistant.util.dt.utcnow', return_value=now):
|
||||||
code = store(client_id, credentials)
|
code = store(client_id, user)
|
||||||
|
|
||||||
with patch('homeassistant.util.dt.utcnow',
|
with patch('homeassistant.util.dt.utcnow',
|
||||||
return_value=now + timedelta(minutes=9, seconds=59)):
|
return_value=now + timedelta(minutes=9, seconds=59)):
|
||||||
assert retrieve(client_id, code) == credentials
|
assert retrieve(client_id, RESULT_TYPE_USER, code) == user
|
||||||
|
|
||||||
|
|
||||||
async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
|
async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
|
||||||
|
@ -34,6 +34,7 @@ async def async_get_code(hass, aiohttp_client):
|
|||||||
'client_id': CLIENT_ID,
|
'client_id': CLIENT_ID,
|
||||||
'handler': ['insecure_example', '2nd auth'],
|
'handler': ['insecure_example', '2nd auth'],
|
||||||
'redirect_uri': CLIENT_REDIRECT_URI,
|
'redirect_uri': CLIENT_REDIRECT_URI,
|
||||||
|
'type': 'link_user',
|
||||||
})
|
})
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
step = await resp.json()
|
step = await resp.json()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user