diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index b710ca9999e..52240ab78c6 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -2,7 +2,7 @@ import asyncio import logging from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, Dict, List, Optional, Tuple, cast, Union import jwt @@ -257,15 +257,20 @@ class AuthManager: async def _async_finish_login_flow( self, context: Optional[Dict], result: Dict[str, Any]) \ - -> Optional[models.Credentials]: - """Result of a credential login flow.""" + -> Optional[Union[models.User, models.Credentials]]: + """Return a user as result of login flow.""" if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: return None auth_provider = self._providers[result['handler']] - return await auth_provider.async_get_or_create_credentials( + cred = await auth_provider.async_get_or_create_credentials( result['data']) + if context is not None and context.get('credential_only'): + return cred + + return await self.async_get_or_create_user(cred) + @callback def _async_get_auth_provider( self, credentials: models.Credentials) -> Optional[AuthProvider]: diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index 102bfe58b55..08bb3e679b8 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -51,6 +51,7 @@ from datetime import timedelta import voluptuous as vol +from homeassistant.auth.models import User, Credentials from homeassistant.components import websocket_api from homeassistant.components.http.ban import log_invalid_auth from homeassistant.components.http.data_validator import RequestDataValidator @@ -68,22 +69,25 @@ SCHEMA_WS_CURRENT_USER = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): WS_TYPE_CURRENT_USER, }) +RESULT_TYPE_CREDENTIALS = 'credentials' +RESULT_TYPE_USER = 'user' + _LOGGER = logging.getLogger(__name__) async def async_setup(hass, config): """Component to allow users to login.""" - store_credentials, retrieve_credentials = _create_cred_store() + store_result, retrieve_result = _create_auth_code_store() - hass.http.register_view(GrantTokenView(retrieve_credentials)) - hass.http.register_view(LinkUserView(retrieve_credentials)) + hass.http.register_view(GrantTokenView(retrieve_result)) + hass.http.register_view(LinkUserView(retrieve_result)) hass.components.websocket_api.async_register_command( WS_TYPE_CURRENT_USER, websocket_current_user, SCHEMA_WS_CURRENT_USER ) - await login_flow.async_setup(hass, store_credentials) + await login_flow.async_setup(hass, store_result) return True @@ -96,9 +100,9 @@ class GrantTokenView(HomeAssistantView): requires_auth = False cors_allowed = True - def __init__(self, retrieve_credentials): + def __init__(self, retrieve_user): """Initialize the grant token view.""" - self._retrieve_credentials = retrieve_credentials + self._retrieve_user = retrieve_user @log_invalid_auth async def post(self, request): @@ -134,15 +138,16 @@ class GrantTokenView(HomeAssistantView): 'error': 'invalid_request', }, status_code=400) - credentials = self._retrieve_credentials(client_id, code) + user = self._retrieve_user(client_id, RESULT_TYPE_USER, code) - if credentials is None: + if user is None or not isinstance(user, User): return self.json({ 'error': 'invalid_request', 'error_description': 'Invalid code', }, status_code=400) - user = await hass.auth.async_get_or_create_user(credentials) + # refresh user + user = await hass.auth.async_get_user(user.id) if not user.is_active: return self.json({ @@ -220,7 +225,7 @@ class LinkUserView(HomeAssistantView): user = request['hass_user'] credentials = self._retrieve_credentials( - data['client_id'], data['code']) + data['client_id'], RESULT_TYPE_CREDENTIALS, data['code']) if credentials is None: return self.json_message('Invalid code', status_code=400) @@ -230,37 +235,45 @@ class LinkUserView(HomeAssistantView): @callback -def _create_cred_store(): - """Create a credential store.""" - temp_credentials = {} +def _create_auth_code_store(): + """Create an in memory store.""" + temp_results = {} @callback - def store_credentials(client_id, credentials): - """Store credentials and return a code to retrieve it.""" + def store_result(client_id, result): + """Store flow result and return a code to retrieve it.""" + if isinstance(result, User): + result_type = RESULT_TYPE_USER + elif isinstance(result, Credentials): + result_type = RESULT_TYPE_CREDENTIALS + else: + raise ValueError('result has to be either User or Credentials') + code = uuid.uuid4().hex - temp_credentials[(client_id, code)] = (dt_util.utcnow(), credentials) + temp_results[(client_id, result_type, code)] = \ + (dt_util.utcnow(), result_type, result) return code @callback - def retrieve_credentials(client_id, code): - """Retrieve credentials.""" - key = (client_id, code) + def retrieve_result(client_id, result_type, code): + """Retrieve flow result.""" + key = (client_id, result_type, code) - if key not in temp_credentials: + if key not in temp_results: return None - created, credentials = temp_credentials.pop(key) + created, _, result = temp_results.pop(key) # OAuth 4.2.1 # The authorization code MUST expire shortly after it is issued to # mitigate the risk of leaks. A maximum authorization code lifetime of # 10 minutes is RECOMMENDED. if dt_util.utcnow() - created < timedelta(minutes=10): - return credentials + return result return None - return store_credentials, retrieve_credentials + return store_result, retrieve_result @callback diff --git a/homeassistant/components/auth/login_flow.py b/homeassistant/components/auth/login_flow.py index e1d21bbb632..a518bdde415 100644 --- a/homeassistant/components/auth/login_flow.py +++ b/homeassistant/components/auth/login_flow.py @@ -22,10 +22,14 @@ Pass in parameter 'client_id' and 'redirect_url' validate by indieauth. Pass in parameter 'handler' to specify the auth provider to use. Auth providers are identified by type and id. +And optional parameter 'type' has to set as 'link_user' if login flow used for +link credential to exist user. Default 'type' is 'authorize'. + { "client_id": "https://hassbian.local:8123/", "handler": ["local_provider", null], - "redirect_url": "https://hassbian.local:8123/" + "redirect_url": "https://hassbian.local:8123/", + "type': "authorize" } Return value will be a step in a data entry flow. See the docs for data entry @@ -49,6 +53,9 @@ flow for details. Progress the flow. Most flows will be 1 page, but could optionally add extra login challenges, like TFA. Once the flow has finished, the returned step will have type "create_entry" and "result" key will contain an authorization code. +The authorization code associated with an authorized user by default, it will +associate with an credential if "type" set to "link_user" in +"/auth/login_flow" { "flow_id": "8f7e42faab604bcab7ac43c44ca34d58", @@ -71,12 +78,12 @@ from homeassistant.components.http.view import HomeAssistantView from . import indieauth -async def async_setup(hass, store_credentials): +async def async_setup(hass, store_result): """Component to allow users to login.""" hass.http.register_view(AuthProvidersView) hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow)) hass.http.register_view( - LoginFlowResourceView(hass.auth.login_flow, store_credentials)) + LoginFlowResourceView(hass.auth.login_flow, store_result)) class AuthProvidersView(HomeAssistantView): @@ -138,6 +145,7 @@ class LoginFlowIndexView(HomeAssistantView): vol.Required('client_id'): str, vol.Required('handler'): vol.Any(str, list), vol.Required('redirect_uri'): str, + vol.Optional('type', default='authorize'): str, })) @log_invalid_auth async def post(self, request, data): @@ -153,7 +161,10 @@ class LoginFlowIndexView(HomeAssistantView): try: result = await self._flow_mgr.async_init( - handler, context={'ip_address': request[KEY_REAL_IP]}) + handler, context={ + 'ip_address': request[KEY_REAL_IP], + 'credential_only': data.get('type') == 'link_user', + }) except data_entry_flow.UnknownHandler: return self.json_message('Invalid handler specified', 404) except data_entry_flow.UnknownStep: @@ -169,10 +180,10 @@ class LoginFlowResourceView(HomeAssistantView): name = 'api:auth:login_flow:resource' requires_auth = False - def __init__(self, flow_mgr, store_credentials): + def __init__(self, flow_mgr, store_result): """Initialize the login flow resource view.""" self._flow_mgr = flow_mgr - self._store_credentials = store_credentials + self._store_result = store_result async def get(self, request): """Do not allow getting status of a flow in progress.""" @@ -212,7 +223,7 @@ class LoginFlowResourceView(HomeAssistantView): return self.json(_prepare_result_json(result)) result.pop('data') - result['result'] = self._store_credentials(client_id, result['result']) + result['result'] = self._store_result(client_id, result['result']) return self.json(result) diff --git a/tests/auth/test_init.py b/tests/auth/test_init.py index da5daca7cf6..5ea3b528b4e 100644 --- a/tests/auth/test_init.py +++ b/tests/auth/test_init.py @@ -77,8 +77,7 @@ async def test_create_new_user(hass, hass_storage): 'password': 'test-pass', }) assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - credentials = step['result'] - user = await manager.async_get_or_create_user(credentials) + user = step['result'] assert user is not None assert user.is_owner is False assert user.name == 'Test Name' @@ -134,9 +133,8 @@ async def test_login_as_existing_user(mock_hass): 'password': 'test-pass', }) assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - credentials = step['result'] - user = await manager.async_get_or_create_user(credentials) + user = step['result'] assert user is not None assert user.id == 'mock-user' assert user.is_owner is False @@ -166,16 +164,18 @@ async def test_linking_user_to_two_auth_providers(hass, hass_storage): 'username': 'test-user', 'password': 'test-pass', }) - user = await manager.async_get_or_create_user(step['result']) + user = step['result'] assert user is not None - step = await manager.login_flow.async_init(('insecure_example', - 'another-provider')) + step = await manager.login_flow.async_init( + ('insecure_example', 'another-provider'), + context={'credential_only': True}) step = await manager.login_flow.async_configure(step['flow_id'], { 'username': 'another-user', 'password': 'another-password', }) - await manager.async_link_user(user, step['result']) + new_credential = step['result'] + await manager.async_link_user(user, new_credential) assert len(user.credentials) == 2 @@ -197,7 +197,7 @@ async def test_saving_loading(hass, hass_storage): 'username': 'test-user', 'password': 'test-pass', }) - user = await manager.async_get_or_create_user(step['result']) + user = step['result'] await manager.async_activate_user(user) await manager.async_create_refresh_token(user, CLIENT_ID) diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index f1a1bb5bd3c..79749da1461 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -3,13 +3,14 @@ from datetime import timedelta from unittest.mock import patch from homeassistant.auth.models import Credentials +from homeassistant.components.auth import RESULT_TYPE_USER from homeassistant.setup import async_setup_component from homeassistant.util.dt import utcnow from homeassistant.components import auth from . import async_setup_auth -from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI +from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client): @@ -74,26 +75,26 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client): assert resp.status == 200 -def test_credential_store_expiration(): - """Test that the credential store will not return expired tokens.""" - store, retrieve = auth._create_cred_store() +def test_auth_code_store_expiration(): + """Test that the auth code store will not return expired tokens.""" + store, retrieve = auth._create_auth_code_store() client_id = 'bla' - credentials = 'creds' + user = MockUser(id='mock_user') now = utcnow() with patch('homeassistant.util.dt.utcnow', return_value=now): - code = store(client_id, credentials) + code = store(client_id, user) with patch('homeassistant.util.dt.utcnow', return_value=now + timedelta(minutes=10)): - assert retrieve(client_id, code) is None + assert retrieve(client_id, RESULT_TYPE_USER, code) is None with patch('homeassistant.util.dt.utcnow', return_value=now): - code = store(client_id, credentials) + code = store(client_id, user) with patch('homeassistant.util.dt.utcnow', return_value=now + timedelta(minutes=9, seconds=59)): - assert retrieve(client_id, code) == credentials + assert retrieve(client_id, RESULT_TYPE_USER, code) == user async def test_ws_current_user(hass, hass_ws_client, hass_access_token): diff --git a/tests/components/auth/test_init_link_user.py b/tests/components/auth/test_init_link_user.py index e209e0ee856..5166f661491 100644 --- a/tests/components/auth/test_init_link_user.py +++ b/tests/components/auth/test_init_link_user.py @@ -34,6 +34,7 @@ async def async_get_code(hass, aiohttp_client): 'client_id': CLIENT_ID, 'handler': ['insecure_example', '2nd auth'], 'redirect_uri': CLIENT_REDIRECT_URI, + 'type': 'link_user', }) assert resp.status == 200 step = await resp.json()