From a57f4b8f42319ab327d7d31416b973ec91838d6c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 5 May 2024 15:47:26 -0500 Subject: [PATCH] Index auth token ids to avoid linear search (#116583) * Index auth token ids to avoid linear search * async_remove_refresh_token * coverage --- homeassistant/auth/auth_store.py | 38 ++++++++++++++++++++++---------- tests/auth/test_auth_store.py | 21 ++++++++++++++++++ 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index 826bec57ee6..bf93011355c 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -63,6 +63,7 @@ class AuthStore: self._store = Store[dict[str, list[dict[str, Any]]]]( hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True ) + self._token_id_to_user_id: dict[str, str] = {} async def async_get_groups(self) -> list[models.Group]: """Retrieve all users.""" @@ -136,7 +137,10 @@ class AuthStore: async def async_remove_user(self, user: models.User) -> None: """Remove a user.""" - self._users.pop(user.id) + user = self._users.pop(user.id) + for refresh_token_id in user.refresh_tokens: + del self._token_id_to_user_id[refresh_token_id] + user.refresh_tokens.clear() self._async_schedule_save() async def async_update_user( @@ -219,7 +223,9 @@ class AuthStore: kwargs["client_icon"] = client_icon refresh_token = models.RefreshToken(**kwargs) - user.refresh_tokens[refresh_token.id] = refresh_token + token_id = refresh_token.id + user.refresh_tokens[token_id] = refresh_token + self._token_id_to_user_id[token_id] = user.id self._async_schedule_save() return refresh_token @@ -227,19 +233,17 @@ class AuthStore: @callback def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None: """Remove a refresh token.""" - for user in self._users.values(): - if user.refresh_tokens.pop(refresh_token.id, None): - self._async_schedule_save() - break + refresh_token_id = refresh_token.id + if user_id := self._token_id_to_user_id.get(refresh_token_id): + del self._users[user_id].refresh_tokens[refresh_token_id] + del self._token_id_to_user_id[refresh_token_id] + self._async_schedule_save() @callback def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None: """Get refresh token by id.""" - for user in self._users.values(): - refresh_token = user.refresh_tokens.get(token_id) - if refresh_token is not None: - return refresh_token - + if user_id := self._token_id_to_user_id.get(token_id): + return self._users[user_id].refresh_tokens.get(token_id) return None @callback @@ -479,9 +483,18 @@ class AuthStore: self._groups = groups self._users = users - + self._build_token_id_to_user_id() self._async_schedule_save(INITIAL_LOAD_SAVE_DELAY) + @callback + def _build_token_id_to_user_id(self) -> None: + """Build a map of token id to user id.""" + self._token_id_to_user_id = { + token_id: user_id + for user_id, user in self._users.items() + for token_id in user.refresh_tokens + } + @callback def _async_schedule_save(self, delay: float = DEFAULT_SAVE_DELAY) -> None: """Save users.""" @@ -575,6 +588,7 @@ class AuthStore: read_only_group = _system_read_only_group() groups[read_only_group.id] = read_only_group self._groups = groups + self._build_token_id_to_user_id() def _system_admin_group() -> models.Group: diff --git a/tests/auth/test_auth_store.py b/tests/auth/test_auth_store.py index 3d62190eab6..8ef8a4e3946 100644 --- a/tests/auth/test_auth_store.py +++ b/tests/auth/test_auth_store.py @@ -305,3 +305,24 @@ async def test_loading_does_not_write_right_away( # Once for the task await hass.async_block_till_done() assert hass_storage[auth_store.STORAGE_KEY] != {} + + +async def test_add_remove_user_affects_tokens( + hass: HomeAssistant, hass_storage: dict[str, Any] +) -> None: + """Test adding and removing a user removes the tokens.""" + store = auth_store.AuthStore(hass) + await store.async_load() + user = await store.async_create_user("Test User") + assert user.name == "Test User" + refresh_token = await store.async_create_refresh_token( + user, "client_id", "access_token_expiration" + ) + assert user.refresh_tokens == {refresh_token.id: refresh_token} + assert await store.async_get_user(user.id) == user + assert store.async_get_refresh_token(refresh_token.id) == refresh_token + assert store.async_get_refresh_token_by_token(refresh_token.token) == refresh_token + await store.async_remove_user(user) + assert store.async_get_refresh_token(refresh_token.id) is None + assert store.async_get_refresh_token_by_token(refresh_token.token) is None + assert user.refresh_tokens == {}