mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Add get_states faster (#23315)
This commit is contained in:
parent
00d26b3049
commit
5b0ee473b6
@ -11,6 +11,7 @@ from .models import PermissionLookup
|
|||||||
from .types import PolicyType
|
from .types import PolicyType
|
||||||
from .entities import ENTITY_POLICY_SCHEMA, compile_entities
|
from .entities import ENTITY_POLICY_SCHEMA, compile_entities
|
||||||
from .merge import merge_policies # noqa
|
from .merge import merge_policies # noqa
|
||||||
|
from .util import test_all
|
||||||
|
|
||||||
|
|
||||||
POLICY_SCHEMA = vol.Schema({
|
POLICY_SCHEMA = vol.Schema({
|
||||||
@ -29,6 +30,10 @@ class AbstractPermissions:
|
|||||||
"""Return a function that can test entity access."""
|
"""Return a function that can test entity access."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def access_all_entities(self, key: str) -> bool:
|
||||||
|
"""Check if we have a certain access to all entities."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def check_entity(self, entity_id: str, key: str) -> bool:
|
def check_entity(self, entity_id: str, key: str) -> bool:
|
||||||
"""Check if we can access entity."""
|
"""Check if we can access entity."""
|
||||||
entity_func = self._cached_entity_func
|
entity_func = self._cached_entity_func
|
||||||
@ -48,6 +53,10 @@ class PolicyPermissions(AbstractPermissions):
|
|||||||
self._policy = policy
|
self._policy = policy
|
||||||
self._perm_lookup = perm_lookup
|
self._perm_lookup = perm_lookup
|
||||||
|
|
||||||
|
def access_all_entities(self, key: str) -> bool:
|
||||||
|
"""Check if we have a certain access to all entities."""
|
||||||
|
return test_all(self._policy.get(CAT_ENTITIES), key)
|
||||||
|
|
||||||
def _entity_func(self) -> Callable[[str, str], bool]:
|
def _entity_func(self) -> Callable[[str, str], bool]:
|
||||||
"""Return a function that can test entity access."""
|
"""Return a function that can test entity access."""
|
||||||
return compile_entities(self._policy.get(CAT_ENTITIES),
|
return compile_entities(self._policy.get(CAT_ENTITIES),
|
||||||
@ -65,6 +74,10 @@ class _OwnerPermissions(AbstractPermissions):
|
|||||||
|
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
|
|
||||||
|
def access_all_entities(self, key: str) -> bool:
|
||||||
|
"""Check if we have a certain access to all entities."""
|
||||||
|
return True
|
||||||
|
|
||||||
def _entity_func(self) -> Callable[[str, str], bool]:
|
def _entity_func(self) -> Callable[[str, str], bool]:
|
||||||
"""Return a function that can test entity access."""
|
"""Return a function that can test entity access."""
|
||||||
return lambda entity_id, key: True
|
return lambda entity_id, key: True
|
||||||
|
@ -3,6 +3,7 @@ from functools import wraps
|
|||||||
|
|
||||||
from typing import Callable, Dict, List, Optional, Union, cast # noqa: F401
|
from typing import Callable, Dict, List, Optional, Union, cast # noqa: F401
|
||||||
|
|
||||||
|
from .const import SUBCAT_ALL
|
||||||
from .models import PermissionLookup
|
from .models import PermissionLookup
|
||||||
from .types import CategoryType, SubCategoryDict, ValueType
|
from .types import CategoryType, SubCategoryDict, ValueType
|
||||||
|
|
||||||
@ -96,3 +97,16 @@ def _gen_dict_test_func(
|
|||||||
return schema.get(key)
|
return schema.get(key)
|
||||||
|
|
||||||
return test_value
|
return test_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_all(policy: CategoryType, key: str) -> bool:
|
||||||
|
"""Test if a policy has an ALL access for a specific key."""
|
||||||
|
if not isinstance(policy, dict):
|
||||||
|
return bool(policy)
|
||||||
|
|
||||||
|
all_policy = policy.get(SUBCAT_ALL)
|
||||||
|
|
||||||
|
if not isinstance(all_policy, dict):
|
||||||
|
return bool(all_policy)
|
||||||
|
|
||||||
|
return all_policy.get(key, False)
|
||||||
|
@ -142,11 +142,14 @@ def handle_get_states(hass, connection, msg):
|
|||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
entity_perm = connection.user.permissions.check_entity
|
if connection.user.permissions.access_all_entities('read'):
|
||||||
states = [
|
states = hass.states.async_all()
|
||||||
state for state in hass.states.async_all()
|
else:
|
||||||
if entity_perm(state.entity_id, 'read')
|
entity_perm = connection.user.permissions.check_entity
|
||||||
]
|
states = [
|
||||||
|
state for state in hass.states.async_all()
|
||||||
|
if entity_perm(state.entity_id, 'read')
|
||||||
|
]
|
||||||
|
|
||||||
connection.send_message(messages.result_message(
|
connection.send_message(messages.result_message(
|
||||||
msg['id'], states))
|
msg['id'], states))
|
||||||
|
21
tests/auth/permissions/test_util.py
Normal file
21
tests/auth/permissions/test_util.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
"""Test the permission utils."""
|
||||||
|
|
||||||
|
from homeassistant.auth.permissions import util
|
||||||
|
|
||||||
|
|
||||||
|
def test_test_all():
|
||||||
|
"""Test if we can test the all group."""
|
||||||
|
for val in (
|
||||||
|
None,
|
||||||
|
{},
|
||||||
|
{'all': None},
|
||||||
|
{'all': {}},
|
||||||
|
):
|
||||||
|
assert util.test_all(val, 'read') is False
|
||||||
|
|
||||||
|
for val in (
|
||||||
|
True,
|
||||||
|
{'all': True},
|
||||||
|
{'all': {'read': True}},
|
||||||
|
):
|
||||||
|
assert util.test_all(val, 'read') is True
|
Loading…
x
Reference in New Issue
Block a user