Add get_states faster (#23315)

This commit is contained in:
Paulus Schoutsen 2019-04-23 03:46:23 -07:00 committed by Pascal Vizeli
parent 00d26b3049
commit 5b0ee473b6
4 changed files with 56 additions and 5 deletions

View File

@ -11,6 +11,7 @@ from .models import PermissionLookup
from .types import PolicyType
from .entities import ENTITY_POLICY_SCHEMA, compile_entities
from .merge import merge_policies # noqa
from .util import test_all
POLICY_SCHEMA = vol.Schema({
@ -29,6 +30,10 @@ class AbstractPermissions:
"""Return a function that can test entity access."""
raise NotImplementedError
def access_all_entities(self, key: str) -> bool:
"""Check if we have a certain access to all entities."""
raise NotImplementedError
def check_entity(self, entity_id: str, key: str) -> bool:
"""Check if we can access entity."""
entity_func = self._cached_entity_func
@ -48,6 +53,10 @@ class PolicyPermissions(AbstractPermissions):
self._policy = policy
self._perm_lookup = perm_lookup
def access_all_entities(self, key: str) -> bool:
"""Check if we have a certain access to all entities."""
return test_all(self._policy.get(CAT_ENTITIES), key)
def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access."""
return compile_entities(self._policy.get(CAT_ENTITIES),
@ -65,6 +74,10 @@ class _OwnerPermissions(AbstractPermissions):
# pylint: disable=no-self-use
def access_all_entities(self, key: str) -> bool:
"""Check if we have a certain access to all entities."""
return True
def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access."""
return lambda entity_id, key: True

View File

@ -3,6 +3,7 @@ from functools import wraps
from typing import Callable, Dict, List, Optional, Union, cast # noqa: F401
from .const import SUBCAT_ALL
from .models import PermissionLookup
from .types import CategoryType, SubCategoryDict, ValueType
@ -96,3 +97,16 @@ def _gen_dict_test_func(
return schema.get(key)
return test_value
def test_all(policy: CategoryType, key: str) -> bool:
"""Test if a policy has an ALL access for a specific key."""
if not isinstance(policy, dict):
return bool(policy)
all_policy = policy.get(SUBCAT_ALL)
if not isinstance(all_policy, dict):
return bool(all_policy)
return all_policy.get(key, False)

View File

@ -142,11 +142,14 @@ def handle_get_states(hass, connection, msg):
Async friendly.
"""
entity_perm = connection.user.permissions.check_entity
states = [
state for state in hass.states.async_all()
if entity_perm(state.entity_id, 'read')
]
if connection.user.permissions.access_all_entities('read'):
states = hass.states.async_all()
else:
entity_perm = connection.user.permissions.check_entity
states = [
state for state in hass.states.async_all()
if entity_perm(state.entity_id, 'read')
]
connection.send_message(messages.result_message(
msg['id'], states))

View File

@ -0,0 +1,21 @@
"""Test the permission utils."""
from homeassistant.auth.permissions import util
def test_test_all():
"""Test if we can test the all group."""
for val in (
None,
{},
{'all': None},
{'all': {}},
):
assert util.test_all(val, 'read') is False
for val in (
True,
{'all': True},
{'all': {'read': True}},
):
assert util.test_all(val, 'read') is True