Remove store user as auth result (#60468)

This commit is contained in:
Paulus Schoutsen 2021-11-28 05:14:52 -08:00 committed by GitHub
parent 1aadda4b0f
commit c6ec84d0cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 32 deletions

View File

@ -124,11 +124,7 @@ from aiohttp import web
import voluptuous as vol import voluptuous as vol
from homeassistant.auth import InvalidAuthError from homeassistant.auth import InvalidAuthError
from homeassistant.auth.models import ( from homeassistant.auth.models import TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, Credentials
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
Credentials,
User,
)
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.http.auth import async_sign_path from homeassistant.components.http.auth import async_sign_path
from homeassistant.components.http.ban import log_invalid_auth from homeassistant.components.http.ban import log_invalid_auth
@ -179,15 +175,12 @@ SCHEMA_WS_SIGN_PATH = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
) )
RESULT_TYPE_CREDENTIALS = "credentials" RESULT_TYPE_CREDENTIALS = "credentials"
RESULT_TYPE_USER = "user"
@bind_hass @bind_hass
def create_auth_code( def create_auth_code(hass, client_id: str, credential: Credentials) -> str:
hass, client_id: str, credential_or_user: Credentials | User
) -> str:
"""Create an authorization code to fetch tokens.""" """Create an authorization code to fetch tokens."""
return hass.data[DOMAIN](client_id, credential_or_user) return hass.data[DOMAIN](client_id, credential)
async def async_setup(hass, config): async def async_setup(hass, config):
@ -296,7 +289,7 @@ class TokenView(HomeAssistantView):
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
) )
credential = self._retrieve_auth(client_id, RESULT_TYPE_CREDENTIALS, code) credential = self._retrieve_auth(client_id, code)
if credential is None or not isinstance(credential, Credentials): if credential is None or not isinstance(credential, Credentials):
return self.json( return self.json(
@ -399,9 +392,7 @@ class LinkUserView(HomeAssistantView):
hass = request.app["hass"] hass = request.app["hass"]
user = request["hass_user"] user = request["hass_user"]
credentials = self._retrieve_credentials( credentials = self._retrieve_credentials(data["client_id"], data["code"])
data["client_id"], RESULT_TYPE_CREDENTIALS, data["code"]
)
if credentials is None: if credentials is None:
return self.json_message("Invalid code", status_code=HTTPStatus.BAD_REQUEST) return self.json_message("Invalid code", status_code=HTTPStatus.BAD_REQUEST)
@ -426,30 +417,25 @@ def _create_auth_code_store():
@callback @callback
def store_result(client_id, result): def store_result(client_id, result):
"""Store flow result and return a code to retrieve it.""" """Store flow result and return a code to retrieve it."""
if isinstance(result, User): if not isinstance(result, Credentials):
result_type = RESULT_TYPE_USER raise ValueError("result has to be a Credentials instance")
elif isinstance(result, Credentials):
result_type = RESULT_TYPE_CREDENTIALS
else:
raise ValueError("result has to be either User or Credentials")
code = uuid.uuid4().hex code = uuid.uuid4().hex
temp_results[(client_id, result_type, code)] = ( temp_results[(client_id, code)] = (
dt_util.utcnow(), dt_util.utcnow(),
result_type,
result, result,
) )
return code return code
@callback @callback
def retrieve_result(client_id, result_type, code): def retrieve_result(client_id, code):
"""Retrieve flow result.""" """Retrieve flow result."""
key = (client_id, result_type, code) key = (client_id, code)
if key not in temp_results: if key not in temp_results:
return None return None
created, _, result = temp_results.pop(key) created, result = temp_results.pop(key)
# OAuth 4.2.1 # OAuth 4.2.1
# The authorization code MUST expire shortly after it is issued to # The authorization code MUST expire shortly after it is issued to

View File

@ -3,10 +3,11 @@ from datetime import timedelta
from http import HTTPStatus from http import HTTPStatus
from unittest.mock import patch from unittest.mock import patch
import pytest
from homeassistant.auth import InvalidAuthError from homeassistant.auth import InvalidAuthError
from homeassistant.auth.models import Credentials from homeassistant.auth.models import Credentials
from homeassistant.components import auth from homeassistant.components import auth
from homeassistant.components.auth import RESULT_TYPE_USER
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
@ -15,6 +16,18 @@ from . import async_setup_auth
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser
@pytest.fixture
def mock_credential():
"""Return a mock credential."""
return Credentials(
id="mock-credential-id",
auth_provider_type="insecure_example",
auth_provider_id=None,
data={"username": "test-user"},
is_new=False,
)
async def async_setup_user_refresh_token(hass): async def async_setup_user_refresh_token(hass):
"""Create a testing user with a connected credential.""" """Create a testing user with a connected credential."""
user = await hass.auth.async_create_user("Test User") user = await hass.auth.async_create_user("Test User")
@ -96,29 +109,38 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
def test_auth_code_store_expiration(): def test_auth_code_store_expiration(mock_credential):
"""Test that the auth code store will not return expired tokens.""" """Test that the auth code store will not return expired tokens."""
store, retrieve = auth._create_auth_code_store() store, retrieve = auth._create_auth_code_store()
client_id = "bla" client_id = "bla"
user = MockUser(id="mock_user")
now = utcnow() now = utcnow()
with patch("homeassistant.util.dt.utcnow", return_value=now): with patch("homeassistant.util.dt.utcnow", return_value=now):
code = store(client_id, user) code = store(client_id, mock_credential)
with patch( with patch(
"homeassistant.util.dt.utcnow", return_value=now + timedelta(minutes=10) "homeassistant.util.dt.utcnow", return_value=now + timedelta(minutes=10)
): ):
assert retrieve(client_id, RESULT_TYPE_USER, code) is None assert retrieve(client_id, code) is None
with patch("homeassistant.util.dt.utcnow", return_value=now): with patch("homeassistant.util.dt.utcnow", return_value=now):
code = store(client_id, user) code = store(client_id, mock_credential)
with patch( with patch(
"homeassistant.util.dt.utcnow", "homeassistant.util.dt.utcnow",
return_value=now + timedelta(minutes=9, seconds=59), return_value=now + timedelta(minutes=9, seconds=59),
): ):
assert retrieve(client_id, RESULT_TYPE_USER, code) == user assert retrieve(client_id, code) == mock_credential
def test_auth_code_store_requires_credentials(mock_credential):
"""Test we require credentials."""
store, _retrieve = auth._create_auth_code_store()
with pytest.raises(ValueError):
store(None, MockUser())
store(None, mock_credential)
async def test_ws_current_user(hass, hass_ws_client, hass_access_token): async def test_ws_current_user(hass, hass_ws_client, hass_access_token):