From f5d439799b0886396eba364cd6c7ec5b674a4027 Mon Sep 17 00:00:00 2001 From: Michael <35783820+mib1185@users.noreply.github.com> Date: Thu, 25 Jan 2024 00:24:22 +0100 Subject: [PATCH] Add expiration of unused refresh tokens (#108428) Co-authored-by: J. Nick Koston --- homeassistant/auth/__init__.py | 71 +++++++++++++++++++++++++++-- homeassistant/auth/auth_store.py | 33 +++++++++++++- homeassistant/auth/const.py | 1 + homeassistant/auth/models.py | 2 + tests/auth/test_auth_store.py | 67 ++++++++++++++++++++++++++++ tests/auth/test_init.py | 76 +++++++++++++++++++++++++++++++- 6 files changed, 243 insertions(+), 7 deletions(-) diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index 0194be10ba9..15094681454 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio from collections import OrderedDict from collections.abc import Mapping -from datetime import timedelta +from datetime import datetime, timedelta from functools import partial import time from typing import Any, cast @@ -12,11 +12,19 @@ from typing import Any, cast import jwt from homeassistant import data_entry_flow -from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback +from homeassistant.core import ( + CALLBACK_TYPE, + HassJob, + HassJobType, + HomeAssistant, + callback, +) from homeassistant.data_entry_flow import FlowResult +from homeassistant.helpers.event import async_track_point_in_utc_time +from homeassistant.util import dt as dt_util from . import auth_store, jwt_wrapper, models -from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN +from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config from .providers import AuthProvider, LoginFlow, auth_provider_from_config @@ -75,7 +83,9 @@ async def auth_manager_from_config( for module in modules: module_hash[module.id] = module - return AuthManager(hass, store, provider_hash, module_hash) + manager = AuthManager(hass, store, provider_hash, module_hash) + manager.async_setup() + return manager class AuthManagerFlowManager(data_entry_flow.FlowManager): @@ -159,6 +169,21 @@ class AuthManager: self._mfa_modules = mfa_modules self.login_flow = AuthManagerFlowManager(hass, self) self._revoke_callbacks: dict[str, set[CALLBACK_TYPE]] = {} + self._expire_callback: CALLBACK_TYPE | None = None + self._remove_expired_job = HassJob( + self._async_remove_expired_refresh_tokens, job_type=HassJobType.Callback + ) + + @callback + def async_setup(self) -> None: + """Set up the auth manager.""" + hass = self.hass + hass.async_add_shutdown_job( + HassJob( + self._async_cancel_expiration_schedule, job_type=HassJobType.Callback + ) + ) + self._async_track_next_refresh_token_expiration() @property def auth_providers(self) -> list[AuthProvider]: @@ -424,6 +449,11 @@ class AuthManager: else: token_type = models.TOKEN_TYPE_NORMAL + if token_type is models.TOKEN_TYPE_NORMAL: + expire_at = time.time() + REFRESH_TOKEN_EXPIRATION + else: + expire_at = None + if user.system_generated != (token_type == models.TOKEN_TYPE_SYSTEM): raise ValueError( "System generated users can only have system type refresh tokens" @@ -455,6 +485,7 @@ class AuthManager: client_icon, token_type, access_token_expiration, + expire_at, credential, ) @@ -479,6 +510,38 @@ class AuthManager: for revoke_callback in callbacks: revoke_callback() + @callback + def _async_remove_expired_refresh_tokens(self, _: datetime | None = None) -> None: + """Remove expired refresh tokens.""" + now = time.time() + for token in self._store.async_get_refresh_tokens()[:]: + if (expire_at := token.expire_at) is not None and expire_at <= now: + self.async_remove_refresh_token(token) + self._async_track_next_refresh_token_expiration() + + @callback + def _async_track_next_refresh_token_expiration(self) -> None: + """Initialise all token expiration scheduled tasks.""" + next_expiration = time.time() + REFRESH_TOKEN_EXPIRATION + for token in self._store.async_get_refresh_tokens(): + if ( + expire_at := token.expire_at + ) is not None and expire_at < next_expiration: + next_expiration = expire_at + + self._expire_callback = async_track_point_in_utc_time( + self.hass, + self._remove_expired_job, + dt_util.utc_from_timestamp(next_expiration), + ) + + @callback + def _async_cancel_expiration_schedule(self) -> None: + """Cancel tracking of expired refresh tokens.""" + if self._expire_callback: + self._expire_callback() + self._expire_callback = None + @callback def _async_unregister( self, callbacks: set[CALLBACK_TYPE], callback_: CALLBACK_TYPE diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index 6d63f9bfd50..983ba7da6a1 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import timedelta import hmac +import itertools from logging import getLogger from typing import Any @@ -17,6 +18,7 @@ from .const import ( GROUP_ID_ADMIN, GROUP_ID_READ_ONLY, GROUP_ID_USER, + REFRESH_TOKEN_EXPIRATION, ) from .permissions import system_policies from .permissions.models import PermissionLookup @@ -186,6 +188,7 @@ class AuthStore: client_icon: str | None = None, token_type: str = models.TOKEN_TYPE_NORMAL, access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION, + expire_at: float | None = None, credential: models.Credentials | None = None, ) -> models.RefreshToken: """Create a new token for a user.""" @@ -194,6 +197,7 @@ class AuthStore: "client_id": client_id, "token_type": token_type, "access_token_expiration": access_token_expiration, + "expire_at": expire_at, "credential": credential, } if client_name: @@ -239,6 +243,15 @@ class AuthStore: return found + @callback + def async_get_refresh_tokens(self) -> list[models.RefreshToken]: + """Get all refresh tokens.""" + return list( + itertools.chain.from_iterable( + user.refresh_tokens.values() for user in self._users.values() + ) + ) + @callback def async_log_refresh_token_usage( self, refresh_token: models.RefreshToken, remote_ip: str | None = None @@ -246,9 +259,13 @@ class AuthStore: """Update refresh token last used information.""" refresh_token.last_used_at = dt_util.utcnow() refresh_token.last_used_ip = remote_ip + if refresh_token.expire_at: + refresh_token.expire_at = ( + refresh_token.last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION + ) self._async_schedule_save() - async def async_load(self) -> None: + async def async_load(self) -> None: # noqa: C901 """Load the users.""" if self._loaded: raise RuntimeError("Auth storage is already loaded") @@ -261,6 +278,8 @@ class AuthStore: perm_lookup = PermissionLookup(ent_reg, dev_reg) self._perm_lookup = perm_lookup + now_ts = dt_util.utcnow().timestamp() + if data is None or not isinstance(data, dict): self._set_defaults() return @@ -414,6 +433,14 @@ class AuthStore: else: last_used_at = None + if ( + expire_at := rt_dict.get("expire_at") + ) is None and token_type == models.TOKEN_TYPE_NORMAL: + if last_used_at: + expire_at = last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION + else: + expire_at = now_ts + REFRESH_TOKEN_EXPIRATION + token = models.RefreshToken( id=rt_dict["id"], user=users[rt_dict["user_id"]], @@ -430,6 +457,7 @@ class AuthStore: jwt_key=rt_dict["jwt_key"], last_used_at=last_used_at, last_used_ip=rt_dict.get("last_used_ip"), + expire_at=expire_at, version=rt_dict.get("version"), ) if "credential_id" in rt_dict: @@ -439,6 +467,8 @@ class AuthStore: self._groups = groups self._users = users + self._async_schedule_save() + @callback def _async_schedule_save(self) -> None: """Save users.""" @@ -503,6 +533,7 @@ class AuthStore: if refresh_token.last_used_at else None, "last_used_ip": refresh_token.last_used_ip, + "expire_at": refresh_token.expire_at, "credential_id": refresh_token.credential.id if refresh_token.credential else None, diff --git a/homeassistant/auth/const.py b/homeassistant/auth/const.py index 5e17e752bdd..704f5d1d57c 100644 --- a/homeassistant/auth/const.py +++ b/homeassistant/auth/const.py @@ -3,6 +3,7 @@ from datetime import timedelta ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30) MFA_SESSION_EXPIRATION = timedelta(minutes=5) +REFRESH_TOKEN_EXPIRATION = timedelta(days=90).total_seconds() GROUP_ID_ADMIN = "system-admin" GROUP_ID_USER = "system-users" diff --git a/homeassistant/auth/models.py b/homeassistant/auth/models.py index 574f0cc75c0..4cf94401478 100644 --- a/homeassistant/auth/models.py +++ b/homeassistant/auth/models.py @@ -117,6 +117,8 @@ class RefreshToken: last_used_at: datetime | None = attr.ib(default=None) last_used_ip: str | None = attr.ib(default=None) + expire_at: float | None = attr.ib(default=None) + credential: Credentials | None = attr.ib(default=None) version: str | None = attr.ib(default=__version__) diff --git a/tests/auth/test_auth_store.py b/tests/auth/test_auth_store.py index 778095388a8..858d4d082b1 100644 --- a/tests/auth/test_auth_store.py +++ b/tests/auth/test_auth_store.py @@ -1,12 +1,15 @@ """Tests for the auth store.""" import asyncio +from datetime import timedelta from typing import Any from unittest.mock import patch +from freezegun import freeze_time import pytest from homeassistant.auth import auth_store from homeassistant.core import HomeAssistant +from homeassistant.util import dt as dt_util async def test_loading_no_group_data_format( @@ -267,3 +270,67 @@ async def test_loading_only_once(hass: HomeAssistant) -> None: mock_dev_registry.assert_called_once_with(hass) mock_load.assert_called_once_with() assert results[0] == results[1] + + +async def test_add_expire_at_property( + hass: HomeAssistant, hass_storage: dict[str, Any] +) -> None: + """Test we correctly add expired_at property if not existing.""" + now = dt_util.utcnow() + with freeze_time(now): + hass_storage[auth_store.STORAGE_KEY] = { + "version": 1, + "data": { + "credentials": [], + "users": [ + { + "id": "user-id", + "is_active": True, + "is_owner": True, + "name": "Paulus", + "system_generated": False, + }, + { + "id": "system-id", + "is_active": True, + "is_owner": True, + "name": "Hass.io", + "system_generated": True, + }, + ], + "refresh_tokens": [ + { + "access_token_expiration": 1800.0, + "client_id": "http://localhost:8123/", + "created_at": "2018-10-03T13:43:19.774637+00:00", + "id": "user-token-id", + "jwt_key": "some-key", + "last_used_at": str(now - timedelta(days=10)), + "token": "some-token", + "user_id": "user-id", + "version": "1.2.3", + }, + { + "access_token_expiration": 1800.0, + "client_id": "http://localhost:8123/", + "created_at": "2018-10-03T13:43:19.774637+00:00", + "id": "user-token-id2", + "jwt_key": "some-key2", + "token": "some-token", + "user_id": "user-id", + }, + ], + }, + } + + store = auth_store.AuthStore(hass) + await store.async_load() + + users = await store.async_get_users() + + assert len(users[0].refresh_tokens) == 2 + token1, token2 = users[0].refresh_tokens.values() + assert token1.expire_at + assert token1.expire_at == now.timestamp() + timedelta(days=80).total_seconds() + assert token2.expire_at + assert token2.expire_at == now.timestamp() + timedelta(days=90).total_seconds() diff --git a/tests/auth/test_init.py b/tests/auth/test_init.py index 5e08f5e3aeb..b561b17112b 100644 --- a/tests/auth/test_init.py +++ b/tests/auth/test_init.py @@ -26,6 +26,7 @@ from tests.common import ( CLIENT_ID, MockUser, async_capture_events, + async_fire_time_changed, ensure_auth_manager_loaded, flush_store, ) @@ -406,6 +407,8 @@ async def test_generating_system_user(hass: HomeAssistant) -> None: assert not user.local_only assert token is not None assert token.client_id is None + assert token.token_type == auth.models.TOKEN_TYPE_SYSTEM + assert token.expire_at is None await hass.async_block_till_done() assert len(events) == 1 @@ -421,6 +424,8 @@ async def test_generating_system_user(hass: HomeAssistant) -> None: assert user.local_only assert token is not None assert token.client_id is None + assert token.token_type == auth.models.TOKEN_TYPE_SYSTEM + assert token.expire_at is None await hass.async_block_till_done() assert len(events) == 2 @@ -474,6 +479,8 @@ async def test_refresh_token_with_specific_access_token_expiration( assert token is not None assert token.client_id == CLIENT_ID assert token.access_token_expiration == timedelta(days=100) + assert token.token_type == auth.models.TOKEN_TYPE_NORMAL + assert token.expire_at is not None async def test_refresh_token_type(hass: HomeAssistant) -> None: @@ -515,6 +522,7 @@ async def test_refresh_token_type_long_lived_access_token(hass: HomeAssistant) - assert token.client_name == "GPS LOGGER" assert token.client_icon == "mdi:home" assert token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN + assert token.expire_at is None async def test_refresh_token_provider_validation(mock_hass) -> None: @@ -565,9 +573,9 @@ async def test_cannot_deactive_owner(mock_hass) -> None: await manager.async_deactivate_user(owner) -async def test_remove_refresh_token(mock_hass) -> None: +async def test_remove_refresh_token(hass: HomeAssistant) -> None: """Test that we can remove a refresh token.""" - manager = await auth.auth_manager_from_config(mock_hass, [], []) + manager = await auth.auth_manager_from_config(hass, [], []) user = MockUser().add_to_auth_manager(manager) refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID) access_token = manager.async_create_access_token(refresh_token) @@ -578,6 +586,70 @@ async def test_remove_refresh_token(mock_hass) -> None: assert manager.async_validate_access_token(access_token) is None +async def test_remove_expired_refresh_token(hass: HomeAssistant) -> None: + """Test that expired refresh tokens are deleted.""" + manager = await auth.auth_manager_from_config(hass, [], []) + user = MockUser().add_to_auth_manager(manager) + now = dt_util.utcnow() + with freeze_time(now): + refresh_token1 = await manager.async_create_refresh_token(user, CLIENT_ID) + assert ( + refresh_token1.expire_at + == now.timestamp() + timedelta(days=90).total_seconds() + ) + + with freeze_time(now + timedelta(days=30)): + async_fire_time_changed(hass, now + timedelta(days=30)) + refresh_token2 = await manager.async_create_refresh_token(user, CLIENT_ID) + assert ( + refresh_token2.expire_at + == now.timestamp() + timedelta(days=120).total_seconds() + ) + + with freeze_time(now + timedelta(days=89, hours=23)): + async_fire_time_changed(hass, now + timedelta(days=89, hours=23)) + await hass.async_block_till_done() + assert manager.async_get_refresh_token(refresh_token1.id) + assert manager.async_get_refresh_token(refresh_token2.id) + + with freeze_time(now + timedelta(days=90, seconds=5)): + async_fire_time_changed(hass, now + timedelta(days=90, seconds=5)) + await hass.async_block_till_done() + assert manager.async_get_refresh_token(refresh_token1.id) is None + assert manager.async_get_refresh_token(refresh_token2.id) + + with freeze_time(now + timedelta(days=120, seconds=5)): + async_fire_time_changed(hass, now + timedelta(days=120, seconds=5)) + await hass.async_block_till_done() + assert manager.async_get_refresh_token(refresh_token1.id) is None + assert manager.async_get_refresh_token(refresh_token2.id) is None + + +async def test_update_expire_at_refresh_token(hass: HomeAssistant) -> None: + """Test that expire at is updated when refresh token is used.""" + manager = await auth.auth_manager_from_config(hass, [], []) + user = MockUser().add_to_auth_manager(manager) + now = dt_util.utcnow() + with freeze_time(now): + refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID) + assert ( + refresh_token.expire_at + == now.timestamp() + timedelta(days=90).total_seconds() + ) + + with freeze_time(now + timedelta(days=30)): + async_fire_time_changed(hass, now + timedelta(days=30)) + await hass.async_block_till_done() + assert manager.async_create_access_token(refresh_token) + await hass.async_block_till_done() + assert ( + refresh_token.expire_at + == now.timestamp() + + timedelta(days=30).total_seconds() + + timedelta(days=90).total_seconds() + ) + + async def test_register_revoke_token_callback(mock_hass) -> None: """Test that a registered revoke token callback is called.""" manager = await auth.auth_manager_from_config(mock_hass, [], [])