From b1d0c6a4f13b759ad4a07deec1f6ea3801fd98f1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 13 Jan 2024 10:10:50 -1000 Subject: [PATCH] Refactor User attribute caching to be safer and more efficient (#96723) * Cache construction of is_admin This has to be checked for a lot of api calls and the websocket every time the call is made * Cache construction of is_admin This has to be checked for a lot of api calls and the websocket every time the call is made * Cache construction of is_admin This has to be checked for a lot of api calls and the websocket every time the call is made * modernize * coverage * coverage * verify caching * verify caching * fix type * fix mocking --- homeassistant/auth/auth_store.py | 1 - homeassistant/auth/models.py | 63 +++++++++++++++++-------------- tests/auth/test_models.py | 34 +++++++++++++++++ tests/common.py | 2 +- tests/components/api/test_init.py | 2 + 5 files changed, 72 insertions(+), 30 deletions(-) diff --git a/homeassistant/auth/auth_store.py b/homeassistant/auth/auth_store.py index 50d5d630429..c8f5001a515 100644 --- a/homeassistant/auth/auth_store.py +++ b/homeassistant/auth/auth_store.py @@ -171,7 +171,6 @@ class AuthStore: groups.append(group) user.groups = groups - user.invalidate_permission_cache() for attr_name, value in ( ("name", name), diff --git a/homeassistant/auth/models.py b/homeassistant/auth/models.py index 32a700d65f9..574f0cc75c0 100644 --- a/homeassistant/auth/models.py +++ b/homeassistant/auth/models.py @@ -3,10 +3,12 @@ from __future__ import annotations from datetime import datetime, timedelta import secrets -from typing import NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple import uuid import attr +from attr import Attribute +from attr.setters import validate from homeassistant.const import __version__ from homeassistant.util import dt as dt_util @@ -14,6 +16,12 @@ from homeassistant.util import dt as dt_util from . import permissions as perm_mdl from .const import GROUP_ID_ADMIN +if TYPE_CHECKING: + from functools import cached_property +else: + from homeassistant.backports.functools import cached_property + + TOKEN_TYPE_NORMAL = "normal" TOKEN_TYPE_SYSTEM = "system" TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token" @@ -29,19 +37,27 @@ class Group: system_generated: bool = attr.ib(default=False) -@attr.s(slots=True) +def _handle_permissions_change(self: User, user_attr: Attribute, new: Any) -> Any: + """Handle a change to a permissions.""" + self.invalidate_cache() + return validate(self, user_attr, new) + + +@attr.s(slots=False) class User: """A user.""" name: str | None = attr.ib() perm_lookup: perm_mdl.PermissionLookup = attr.ib(eq=False, order=False) id: str = attr.ib(factory=lambda: uuid.uuid4().hex) - is_owner: bool = attr.ib(default=False) - is_active: bool = attr.ib(default=False) + is_owner: bool = attr.ib(default=False, on_setattr=_handle_permissions_change) + is_active: bool = attr.ib(default=False, on_setattr=_handle_permissions_change) system_generated: bool = attr.ib(default=False) local_only: bool = attr.ib(default=False) - groups: list[Group] = attr.ib(factory=list, eq=False, order=False) + groups: list[Group] = attr.ib( + factory=list, eq=False, order=False, on_setattr=_handle_permissions_change + ) # List of credentials of a user. credentials: list[Credentials] = attr.ib(factory=list, eq=False, order=False) @@ -51,40 +67,31 @@ class User: factory=dict, eq=False, order=False ) - _permissions: perm_mdl.PolicyPermissions | None = attr.ib( - init=False, - eq=False, - order=False, - default=None, - ) - - @property + @cached_property def permissions(self) -> perm_mdl.AbstractPermissions: """Return permissions object for user.""" if self.is_owner: return perm_mdl.OwnerPermissions - - if self._permissions is not None: - return self._permissions - - self._permissions = perm_mdl.PolicyPermissions( + return perm_mdl.PolicyPermissions( perm_mdl.merge_policies([group.policy for group in self.groups]), self.perm_lookup, ) - return self._permissions - - @property + @cached_property def is_admin(self) -> bool: """Return if user is part of the admin group.""" - if self.is_owner: - return True + return self.is_owner or ( + self.is_active and any(gr.id == GROUP_ID_ADMIN for gr in self.groups) + ) - return self.is_active and any(gr.id == GROUP_ID_ADMIN for gr in self.groups) - - def invalidate_permission_cache(self) -> None: - """Invalidate permission cache.""" - self._permissions = None + def invalidate_cache(self) -> None: + """Invalidate permission and is_admin cache.""" + for attr_to_invalidate in ("permissions", "is_admin"): + # try is must more efficient than suppress + try: # noqa: SIM105 + delattr(self, attr_to_invalidate) + except AttributeError: + pass @attr.s(slots=True) diff --git a/tests/auth/test_models.py b/tests/auth/test_models.py index 1c518cf061d..3f0ad7acc1d 100644 --- a/tests/auth/test_models.py +++ b/tests/auth/test_models.py @@ -26,3 +26,37 @@ def test_permissions_merged() -> None: assert user.permissions.check_entity("switch.bla", "read") is True assert user.permissions.check_entity("light.kitchen", "read") is True assert user.permissions.check_entity("light.not_kitchen", "read") is False + + +def test_cache_cleared_on_group_change() -> None: + """Test we clear the cache when a group changes.""" + group = models.Group( + name="Test Group", policy={"entities": {"domains": {"switch": True}}} + ) + admin_group = models.Group( + name="Admin group", id=models.GROUP_ID_ADMIN, policy={"entities": {}} + ) + user = models.User( + name="Test User", perm_lookup=None, groups=[group], is_active=True + ) + # Make sure we cache instance + assert user.permissions is user.permissions + + # Make sure we cache is_admin + assert user.is_admin is user.is_admin + assert user.is_active is True + + user.groups = [] + assert user.groups == [] + assert user.is_admin is False + + user.is_owner = True + assert user.is_admin is True + user.is_owner = False + + assert user.is_admin is False + user.groups = [admin_group] + assert user.is_admin is True + + user.is_active = False + assert user.is_admin is False diff --git a/tests/common.py b/tests/common.py index 02c7150588d..35171799728 100644 --- a/tests/common.py +++ b/tests/common.py @@ -669,7 +669,7 @@ class MockUser(auth_models.User): def mock_policy(self, policy): """Mock a policy for a user.""" - self._permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup) + self.permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup) async def register_auth_provider( diff --git a/tests/components/api/test_init.py b/tests/components/api/test_init.py index 08cb77b4559..d9c8e7481fa 100644 --- a/tests/components/api/test_init.py +++ b/tests/components/api/test_init.py @@ -684,6 +684,8 @@ async def test_get_entity_state_read_perm( ) -> None: """Test getting a state requires read permission.""" hass_admin_user.mock_policy({}) + hass_admin_user.groups = [] + assert hass_admin_user.is_admin is False resp = await mock_api_client.get("/api/states/light.test") assert resp.status == HTTPStatus.UNAUTHORIZED