mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 01:07:10 +00:00
Remove store user as auth result (#60468)
This commit is contained in:
parent
1aadda4b0f
commit
c6ec84d0cf
@ -124,11 +124,7 @@ from aiohttp import web
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.auth import InvalidAuthError
|
from homeassistant.auth import InvalidAuthError
|
||||||
from homeassistant.auth.models import (
|
from homeassistant.auth.models import TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, Credentials
|
||||||
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
|
||||||
Credentials,
|
|
||||||
User,
|
|
||||||
)
|
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.components.http.auth import async_sign_path
|
from homeassistant.components.http.auth import async_sign_path
|
||||||
from homeassistant.components.http.ban import log_invalid_auth
|
from homeassistant.components.http.ban import log_invalid_auth
|
||||||
@ -179,15 +175,12 @@ SCHEMA_WS_SIGN_PATH = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
|||||||
)
|
)
|
||||||
|
|
||||||
RESULT_TYPE_CREDENTIALS = "credentials"
|
RESULT_TYPE_CREDENTIALS = "credentials"
|
||||||
RESULT_TYPE_USER = "user"
|
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def create_auth_code(
|
def create_auth_code(hass, client_id: str, credential: Credentials) -> str:
|
||||||
hass, client_id: str, credential_or_user: Credentials | User
|
|
||||||
) -> str:
|
|
||||||
"""Create an authorization code to fetch tokens."""
|
"""Create an authorization code to fetch tokens."""
|
||||||
return hass.data[DOMAIN](client_id, credential_or_user)
|
return hass.data[DOMAIN](client_id, credential)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass, config):
|
||||||
@ -296,7 +289,7 @@ class TokenView(HomeAssistantView):
|
|||||||
status_code=HTTPStatus.BAD_REQUEST,
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
credential = self._retrieve_auth(client_id, RESULT_TYPE_CREDENTIALS, code)
|
credential = self._retrieve_auth(client_id, code)
|
||||||
|
|
||||||
if credential is None or not isinstance(credential, Credentials):
|
if credential is None or not isinstance(credential, Credentials):
|
||||||
return self.json(
|
return self.json(
|
||||||
@ -399,9 +392,7 @@ class LinkUserView(HomeAssistantView):
|
|||||||
hass = request.app["hass"]
|
hass = request.app["hass"]
|
||||||
user = request["hass_user"]
|
user = request["hass_user"]
|
||||||
|
|
||||||
credentials = self._retrieve_credentials(
|
credentials = self._retrieve_credentials(data["client_id"], data["code"])
|
||||||
data["client_id"], RESULT_TYPE_CREDENTIALS, data["code"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if credentials is None:
|
if credentials is None:
|
||||||
return self.json_message("Invalid code", status_code=HTTPStatus.BAD_REQUEST)
|
return self.json_message("Invalid code", status_code=HTTPStatus.BAD_REQUEST)
|
||||||
@ -426,30 +417,25 @@ def _create_auth_code_store():
|
|||||||
@callback
|
@callback
|
||||||
def store_result(client_id, result):
|
def store_result(client_id, result):
|
||||||
"""Store flow result and return a code to retrieve it."""
|
"""Store flow result and return a code to retrieve it."""
|
||||||
if isinstance(result, User):
|
if not isinstance(result, Credentials):
|
||||||
result_type = RESULT_TYPE_USER
|
raise ValueError("result has to be a Credentials instance")
|
||||||
elif isinstance(result, Credentials):
|
|
||||||
result_type = RESULT_TYPE_CREDENTIALS
|
|
||||||
else:
|
|
||||||
raise ValueError("result has to be either User or Credentials")
|
|
||||||
|
|
||||||
code = uuid.uuid4().hex
|
code = uuid.uuid4().hex
|
||||||
temp_results[(client_id, result_type, code)] = (
|
temp_results[(client_id, code)] = (
|
||||||
dt_util.utcnow(),
|
dt_util.utcnow(),
|
||||||
result_type,
|
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
return code
|
return code
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def retrieve_result(client_id, result_type, code):
|
def retrieve_result(client_id, code):
|
||||||
"""Retrieve flow result."""
|
"""Retrieve flow result."""
|
||||||
key = (client_id, result_type, code)
|
key = (client_id, code)
|
||||||
|
|
||||||
if key not in temp_results:
|
if key not in temp_results:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
created, _, result = temp_results.pop(key)
|
created, result = temp_results.pop(key)
|
||||||
|
|
||||||
# OAuth 4.2.1
|
# OAuth 4.2.1
|
||||||
# The authorization code MUST expire shortly after it is issued to
|
# The authorization code MUST expire shortly after it is issued to
|
||||||
|
@ -3,10 +3,11 @@ from datetime import timedelta
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from homeassistant.auth import InvalidAuthError
|
from homeassistant.auth import InvalidAuthError
|
||||||
from homeassistant.auth.models import Credentials
|
from homeassistant.auth.models import Credentials
|
||||||
from homeassistant.components import auth
|
from homeassistant.components import auth
|
||||||
from homeassistant.components.auth import RESULT_TYPE_USER
|
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
@ -15,6 +16,18 @@ from . import async_setup_auth
|
|||||||
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser
|
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_credential():
|
||||||
|
"""Return a mock credential."""
|
||||||
|
return Credentials(
|
||||||
|
id="mock-credential-id",
|
||||||
|
auth_provider_type="insecure_example",
|
||||||
|
auth_provider_id=None,
|
||||||
|
data={"username": "test-user"},
|
||||||
|
is_new=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_user_refresh_token(hass):
|
async def async_setup_user_refresh_token(hass):
|
||||||
"""Create a testing user with a connected credential."""
|
"""Create a testing user with a connected credential."""
|
||||||
user = await hass.auth.async_create_user("Test User")
|
user = await hass.auth.async_create_user("Test User")
|
||||||
@ -96,29 +109,38 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
|||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
|
|
||||||
|
|
||||||
def test_auth_code_store_expiration():
|
def test_auth_code_store_expiration(mock_credential):
|
||||||
"""Test that the auth code store will not return expired tokens."""
|
"""Test that the auth code store will not return expired tokens."""
|
||||||
store, retrieve = auth._create_auth_code_store()
|
store, retrieve = auth._create_auth_code_store()
|
||||||
client_id = "bla"
|
client_id = "bla"
|
||||||
user = MockUser(id="mock_user")
|
|
||||||
now = utcnow()
|
now = utcnow()
|
||||||
|
|
||||||
with patch("homeassistant.util.dt.utcnow", return_value=now):
|
with patch("homeassistant.util.dt.utcnow", return_value=now):
|
||||||
code = store(client_id, user)
|
code = store(client_id, mock_credential)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.util.dt.utcnow", return_value=now + timedelta(minutes=10)
|
"homeassistant.util.dt.utcnow", return_value=now + timedelta(minutes=10)
|
||||||
):
|
):
|
||||||
assert retrieve(client_id, RESULT_TYPE_USER, code) is None
|
assert retrieve(client_id, code) is None
|
||||||
|
|
||||||
with patch("homeassistant.util.dt.utcnow", return_value=now):
|
with patch("homeassistant.util.dt.utcnow", return_value=now):
|
||||||
code = store(client_id, user)
|
code = store(client_id, mock_credential)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.util.dt.utcnow",
|
"homeassistant.util.dt.utcnow",
|
||||||
return_value=now + timedelta(minutes=9, seconds=59),
|
return_value=now + timedelta(minutes=9, seconds=59),
|
||||||
):
|
):
|
||||||
assert retrieve(client_id, RESULT_TYPE_USER, code) == user
|
assert retrieve(client_id, code) == mock_credential
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_code_store_requires_credentials(mock_credential):
|
||||||
|
"""Test we require credentials."""
|
||||||
|
store, _retrieve = auth._create_auth_code_store()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
store(None, MockUser())
|
||||||
|
|
||||||
|
store(None, mock_credential)
|
||||||
|
|
||||||
|
|
||||||
async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
|
async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user