Add ws endpoint to remove expiration date from refresh tokens (#117546)

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Robert Resch 2024-05-29 09:09:59 +02:00 committed by GitHub
parent 7e62061b9a
commit e087abe802
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 235 additions and 73 deletions

View File

@ -516,6 +516,13 @@ class AuthManager:
for revoke_callback in callbacks: for revoke_callback in callbacks:
revoke_callback() revoke_callback()
@callback
def async_set_expiry(
self, refresh_token: models.RefreshToken, *, enable_expiry: bool
) -> None:
"""Enable or disable expiry of a refresh token."""
self._store.async_set_expiry(refresh_token, enable_expiry=enable_expiry)
@callback @callback
def _async_remove_expired_refresh_tokens(self, _: datetime | None = None) -> None: def _async_remove_expired_refresh_tokens(self, _: datetime | None = None) -> None:
"""Remove expired refresh tokens.""" """Remove expired refresh tokens."""

View File

@ -6,7 +6,6 @@ from datetime import timedelta
import hmac import hmac
import itertools import itertools
from logging import getLogger from logging import getLogger
import time
from typing import Any from typing import Any
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
@ -282,6 +281,21 @@ class AuthStore:
) )
self._async_schedule_save() self._async_schedule_save()
@callback
def async_set_expiry(
self, refresh_token: models.RefreshToken, *, enable_expiry: bool
) -> None:
"""Enable or disable expiry of a refresh token."""
if enable_expiry:
if refresh_token.expire_at is None:
refresh_token.expire_at = (
refresh_token.last_used_at or dt_util.utcnow()
).timestamp() + REFRESH_TOKEN_EXPIRATION
self._async_schedule_save()
else:
refresh_token.expire_at = None
self._async_schedule_save()
async def async_load(self) -> None: # noqa: C901 async def async_load(self) -> None: # noqa: C901
"""Load the users.""" """Load the users."""
if self._loaded: if self._loaded:
@ -295,8 +309,6 @@ class AuthStore:
perm_lookup = PermissionLookup(ent_reg, dev_reg) perm_lookup = PermissionLookup(ent_reg, dev_reg)
self._perm_lookup = perm_lookup self._perm_lookup = perm_lookup
now_ts = time.time()
if data is None or not isinstance(data, dict): if data is None or not isinstance(data, dict):
self._set_defaults() self._set_defaults()
return return
@ -450,14 +462,6 @@ class AuthStore:
else: else:
last_used_at = None last_used_at = None
if (
expire_at := rt_dict.get("expire_at")
) is None and token_type == models.TOKEN_TYPE_NORMAL:
if last_used_at:
expire_at = last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
else:
expire_at = now_ts + REFRESH_TOKEN_EXPIRATION
token = models.RefreshToken( token = models.RefreshToken(
id=rt_dict["id"], id=rt_dict["id"],
user=users[rt_dict["user_id"]], user=users[rt_dict["user_id"]],
@ -474,7 +478,7 @@ class AuthStore:
jwt_key=rt_dict["jwt_key"], jwt_key=rt_dict["jwt_key"],
last_used_at=last_used_at, last_used_at=last_used_at,
last_used_ip=rt_dict.get("last_used_ip"), last_used_ip=rt_dict.get("last_used_ip"),
expire_at=expire_at, expire_at=rt_dict.get("expire_at"),
version=rt_dict.get("version"), version=rt_dict.get("version"),
) )
if "credential_id" in rt_dict: if "credential_id" in rt_dict:

View File

@ -197,6 +197,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
websocket_api.async_register_command(hass, websocket_delete_refresh_token) websocket_api.async_register_command(hass, websocket_delete_refresh_token)
websocket_api.async_register_command(hass, websocket_delete_all_refresh_tokens) websocket_api.async_register_command(hass, websocket_delete_all_refresh_tokens)
websocket_api.async_register_command(hass, websocket_sign_path) websocket_api.async_register_command(hass, websocket_sign_path)
websocket_api.async_register_command(hass, websocket_refresh_token_set_expiry)
login_flow.async_setup(hass, store_result) login_flow.async_setup(hass, store_result)
mfa_setup_flow.async_setup(hass) mfa_setup_flow.async_setup(hass)
@ -565,18 +566,23 @@ def websocket_refresh_tokens(
else: else:
auth_provider_type = None auth_provider_type = None
expire_at = None
if refresh.expire_at:
expire_at = dt_util.utc_from_timestamp(refresh.expire_at)
tokens.append( tokens.append(
{ {
"id": refresh.id, "auth_provider_type": auth_provider_type,
"client_icon": refresh.client_icon,
"client_id": refresh.client_id, "client_id": refresh.client_id,
"client_name": refresh.client_name, "client_name": refresh.client_name,
"client_icon": refresh.client_icon,
"type": refresh.token_type,
"created_at": refresh.created_at, "created_at": refresh.created_at,
"expire_at": expire_at,
"id": refresh.id,
"is_current": refresh.id == current_id, "is_current": refresh.id == current_id,
"last_used_at": refresh.last_used_at, "last_used_at": refresh.last_used_at,
"last_used_ip": refresh.last_used_ip, "last_used_ip": refresh.last_used_ip,
"auth_provider_type": auth_provider_type, "type": refresh.token_type,
} }
) )
@ -702,3 +708,26 @@ def websocket_sign_path(
}, },
) )
) )
@callback
@websocket_api.websocket_command(
{
vol.Required("type"): "auth/refresh_token_set_expiry",
vol.Required("refresh_token_id"): str,
vol.Required("enable_expiry"): bool,
}
)
@websocket_api.ws_require_user()
def websocket_refresh_token_set_expiry(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle a set expiry of a refresh token request."""
refresh_token = connection.user.refresh_tokens.get(msg["refresh_token_id"])
if refresh_token is None:
connection.send_error(msg["id"], "invalid_token_id", "Received invalid token")
return
hass.auth.async_set_expiry(refresh_token, enable_expiry=msg["enable_expiry"])
connection.send_result(msg["id"], {})

View File

@ -1,17 +1,14 @@
"""Tests for the auth store.""" """Tests for the auth store."""
import asyncio import asyncio
from datetime import timedelta
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
from freezegun import freeze_time
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
import pytest import pytest
from homeassistant.auth import auth_store from homeassistant.auth import auth_store
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util
MOCK_STORAGE_DATA = { MOCK_STORAGE_DATA = {
"version": 1, "version": 1,
@ -220,68 +217,64 @@ async def test_loading_only_once(hass: HomeAssistant) -> None:
assert results[0] == results[1] assert results[0] == results[1]
async def test_add_expire_at_property( async def test_dont_change_expire_at_on_load(
hass: HomeAssistant, hass_storage: dict[str, Any] hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None: ) -> None:
"""Test we correctly add expired_at property if not existing.""" """Test we correctly don't modify expired_at store load."""
now = dt_util.utcnow() hass_storage[auth_store.STORAGE_KEY] = {
with freeze_time(now): "version": 1,
hass_storage[auth_store.STORAGE_KEY] = { "data": {
"version": 1, "credentials": [],
"data": { "users": [
"credentials": [], {
"users": [ "id": "user-id",
{ "is_active": True,
"id": "user-id", "is_owner": True,
"is_active": True, "name": "Paulus",
"is_owner": True, "system_generated": False,
"name": "Paulus", },
"system_generated": False, {
}, "id": "system-id",
{ "is_active": True,
"id": "system-id", "is_owner": True,
"is_active": True, "name": "Hass.io",
"is_owner": True, "system_generated": True,
"name": "Hass.io", },
"system_generated": True, ],
}, "refresh_tokens": [
], {
"refresh_tokens": [ "access_token_expiration": 1800.0,
{ "client_id": "http://localhost:8123/",
"access_token_expiration": 1800.0, "created_at": "2018-10-03T13:43:19.774637+00:00",
"client_id": "http://localhost:8123/", "id": "user-token-id",
"created_at": "2018-10-03T13:43:19.774637+00:00", "jwt_key": "some-key",
"id": "user-token-id", "token": "some-token",
"jwt_key": "some-key", "user_id": "user-id",
"last_used_at": str(now - timedelta(days=10)), "version": "1.2.3",
"token": "some-token", },
"user_id": "user-id", {
"version": "1.2.3", "access_token_expiration": 1800.0,
}, "client_id": "http://localhost:8123/",
{ "created_at": "2018-10-03T13:43:19.774637+00:00",
"access_token_expiration": 1800.0, "id": "user-token-id2",
"client_id": "http://localhost:8123/", "jwt_key": "some-key2",
"created_at": "2018-10-03T13:43:19.774637+00:00", "token": "some-token",
"id": "user-token-id2", "user_id": "user-id",
"jwt_key": "some-key2", "expire_at": 1724133771.079745,
"token": "some-token", },
"user_id": "user-id", ],
}, },
], }
},
}
store = auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
await store.async_load() await store.async_load()
users = await store.async_get_users() users = await store.async_get_users()
assert len(users[0].refresh_tokens) == 2 assert len(users[0].refresh_tokens) == 2
token1, token2 = users[0].refresh_tokens.values() token1, token2 = users[0].refresh_tokens.values()
assert token1.expire_at assert not token1.expire_at
assert token1.expire_at == now.timestamp() + timedelta(days=80).total_seconds() assert token2.expire_at == 1724133771.079745
assert token2.expire_at
assert token2.expire_at == now.timestamp() + timedelta(days=90).total_seconds()
async def test_loading_does_not_write_right_away( async def test_loading_does_not_write_right_away(
@ -326,3 +319,63 @@ async def test_add_remove_user_affects_tokens(
assert store.async_get_refresh_token(refresh_token.id) is None assert store.async_get_refresh_token(refresh_token.id) is None
assert store.async_get_refresh_token_by_token(refresh_token.token) is None assert store.async_get_refresh_token_by_token(refresh_token.token) is None
assert user.refresh_tokens == {} assert user.refresh_tokens == {}
async def test_set_expiry_date(
hass: HomeAssistant, hass_storage: dict[str, Any], freezer: FrozenDateTimeFactory
) -> None:
"""Test set expiry date of a refresh token."""
hass_storage[auth_store.STORAGE_KEY] = {
"version": 1,
"data": {
"credentials": [],
"users": [
{
"id": "user-id",
"is_active": True,
"is_owner": True,
"name": "Paulus",
"system_generated": False,
},
],
"refresh_tokens": [
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id",
"jwt_key": "some-key",
"token": "some-token",
"user_id": "user-id",
"expire_at": 1724133771.079745,
},
],
},
}
store = auth_store.AuthStore(hass)
await store.async_load()
users = await store.async_get_users()
assert len(users[0].refresh_tokens) == 1
(token,) = users[0].refresh_tokens.values()
assert token.expire_at == 1724133771.079745
store.async_set_expiry(token, enable_expiry=False)
assert token.expire_at is None
freezer.tick(auth_store.DEFAULT_SAVE_DELAY * 2)
# Once for scheduling the task
await hass.async_block_till_done()
# Once for the task
await hass.async_block_till_done()
# verify token is saved without expire_at
assert (
hass_storage[auth_store.STORAGE_KEY]["data"]["refresh_tokens"][0]["expire_at"]
is None
)
store.async_set_expiry(token, enable_expiry=True)
assert token.expire_at is not None

View File

@ -690,3 +690,72 @@ async def test_ws_sign_path(
hass, path, expires = mock_sign.mock_calls[0][1] hass, path, expires = mock_sign.mock_calls[0][1]
assert path == "/api/hello" assert path == "/api/hello"
assert expires.total_seconds() == 20 assert expires.total_seconds() == 20
async def test_ws_refresh_token_set_expiry(
hass: HomeAssistant,
hass_admin_user: MockUser,
hass_admin_credential: Credentials,
hass_ws_client: WebSocketGenerator,
hass_access_token: str,
) -> None:
"""Test setting expiry of a refresh token."""
assert await async_setup_component(hass, "auth", {"http": {}})
refresh_token = await hass.auth.async_create_refresh_token(
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
)
assert refresh_token.expire_at is not None
ws_client = await hass_ws_client(hass, hass_access_token)
await ws_client.send_json_auto_id(
{
"type": "auth/refresh_token_set_expiry",
"refresh_token_id": refresh_token.id,
"enable_expiry": False,
}
)
result = await ws_client.receive_json()
assert result["success"], result
refresh_token = hass.auth.async_get_refresh_token(refresh_token.id)
assert refresh_token.expire_at is None
await ws_client.send_json_auto_id(
{
"type": "auth/refresh_token_set_expiry",
"refresh_token_id": refresh_token.id,
"enable_expiry": True,
}
)
result = await ws_client.receive_json()
assert result["success"], result
refresh_token = hass.auth.async_get_refresh_token(refresh_token.id)
assert refresh_token.expire_at is not None
async def test_ws_refresh_token_set_expiry_error(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
hass_access_token: str,
) -> None:
"""Test setting expiry of a invalid refresh token returns error."""
assert await async_setup_component(hass, "auth", {"http": {}})
ws_client = await hass_ws_client(hass, hass_access_token)
await ws_client.send_json_auto_id(
{
"type": "auth/refresh_token_set_expiry",
"refresh_token_id": "invalid",
"enable_expiry": False,
}
)
result = await ws_client.receive_json()
assert result, result["success"] is False
assert result["error"] == {
"code": "invalid_token_id",
"message": "Received invalid token",
}