Small performance improvements to handing revoke token callbacks (#108625)

- Use a set to avoid linear search for remove
- Avoid recreating the unregister function each time
This commit is contained in:
J. Nick Koston 2024-01-21 17:49:06 -10:00 committed by GitHub
parent 3d3f4ac293
commit 740209912c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,6 +5,7 @@ import asyncio
from collections import OrderedDict
from collections.abc import Mapping
from datetime import timedelta
from functools import partial
import time
from typing import Any, cast
@ -157,7 +158,7 @@ class AuthManager:
self._providers = providers
self._mfa_modules = mfa_modules
self.login_flow = AuthManagerFlowManager(hass, self)
self._revoke_callbacks: dict[str, list[CALLBACK_TYPE]] = {}
self._revoke_callbacks: dict[str, set[CALLBACK_TYPE]] = {}
@property
def auth_providers(self) -> list[AuthProvider]:
@ -475,27 +476,28 @@ class AuthManager:
"""Delete a refresh token."""
await self._store.async_remove_refresh_token(refresh_token)
callbacks = self._revoke_callbacks.pop(refresh_token.id, [])
callbacks = self._revoke_callbacks.pop(refresh_token.id, ())
for revoke_callback in callbacks:
revoke_callback()
@callback
def _async_unregister(
self, callbacks: set[CALLBACK_TYPE], callback_: CALLBACK_TYPE
) -> None:
"""Unregister a callback."""
callbacks.remove(callback_)
@callback
def async_register_revoke_token_callback(
self, refresh_token_id: str, revoke_callback: CALLBACK_TYPE
) -> CALLBACK_TYPE:
"""Register a callback to be called when the refresh token id is revoked."""
if refresh_token_id not in self._revoke_callbacks:
self._revoke_callbacks[refresh_token_id] = []
self._revoke_callbacks[refresh_token_id] = set()
callbacks = self._revoke_callbacks[refresh_token_id]
callbacks.append(revoke_callback)
@callback
def unregister() -> None:
if revoke_callback in callbacks:
callbacks.remove(revoke_callback)
return unregister
callbacks.add(revoke_callback)
return partial(self._async_unregister, callbacks, revoke_callback)
@callback
def async_create_access_token(