Add area permission check (#21835)

This commit is contained in:
Paulus Schoutsen 2019-03-11 11:02:37 -07:00 committed by GitHub
parent 4f49bdf262
commit 4f5446ff02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 230 additions and 146 deletions

View File

@ -281,8 +281,9 @@ class AuthStore:
async def _async_load_task(self) -> None: async def _async_load_task(self) -> None:
"""Load the users.""" """Load the users."""
[ent_reg, data] = await asyncio.gather( [ent_reg, dev_reg, data] = await asyncio.gather(
self.hass.helpers.entity_registry.async_get_registry(), self.hass.helpers.entity_registry.async_get_registry(),
self.hass.helpers.device_registry.async_get_registry(),
self._store.async_load(), self._store.async_load(),
) )
@ -291,7 +292,9 @@ class AuthStore:
if self._users is not None: if self._users is not None:
return return
self._perm_lookup = perm_lookup = PermissionLookup(ent_reg) self._perm_lookup = perm_lookup = PermissionLookup(
ent_reg, dev_reg
)
if data is None: if data is None:
self._set_defaults() self._set_defaults()

View File

@ -1,12 +1,14 @@
"""Entity permissions.""" """Entity permissions."""
from functools import wraps from collections import OrderedDict
from typing import Callable, List, Union # noqa: F401 from typing import Callable, Optional # noqa: F401
import voluptuous as vol import voluptuous as vol
from .const import SUBCAT_ALL, POLICY_READ, POLICY_CONTROL, POLICY_EDIT from .const import SUBCAT_ALL, POLICY_READ, POLICY_CONTROL, POLICY_EDIT
from .models import PermissionLookup from .models import PermissionLookup
from .types import CategoryType, ValueType from .types import CategoryType, SubCategoryDict, ValueType
# pylint: disable=unused-import
from .util import SubCatLookupType, lookup_all, compile_policy # noqa
SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({ SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({
vol.Optional(POLICY_READ): True, vol.Optional(POLICY_READ): True,
@ -15,6 +17,7 @@ SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({
})) }))
ENTITY_DOMAINS = 'domains' ENTITY_DOMAINS = 'domains'
ENTITY_AREAS = 'area_ids'
ENTITY_DEVICE_IDS = 'device_ids' ENTITY_DEVICE_IDS = 'device_ids'
ENTITY_ENTITY_IDS = 'entity_ids' ENTITY_ENTITY_IDS = 'entity_ids'
@ -24,148 +27,65 @@ ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({
ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({ ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({
vol.Optional(SUBCAT_ALL): SINGLE_ENTITY_SCHEMA, vol.Optional(SUBCAT_ALL): SINGLE_ENTITY_SCHEMA,
vol.Optional(ENTITY_AREAS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_DEVICE_IDS): ENTITY_VALUES_SCHEMA, vol.Optional(ENTITY_DEVICE_IDS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_DOMAINS): ENTITY_VALUES_SCHEMA, vol.Optional(ENTITY_DOMAINS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_ENTITY_IDS): ENTITY_VALUES_SCHEMA, vol.Optional(ENTITY_ENTITY_IDS): ENTITY_VALUES_SCHEMA,
})) }))
def _entity_allowed(schema: ValueType, key: str) \ def _lookup_domain(perm_lookup: PermissionLookup,
-> Union[bool, None]: domains_dict: SubCategoryDict,
"""Test if an entity is allowed based on the keys.""" entity_id: str) -> Optional[ValueType]:
if schema is None or isinstance(schema, bool): """Look up entity permissions by domain."""
return schema return domains_dict.get(entity_id.split(".", 1)[0])
assert isinstance(schema, dict)
return schema.get(key)
def _lookup_area(perm_lookup: PermissionLookup, area_dict: SubCategoryDict,
entity_id: str) -> Optional[ValueType]:
"""Look up entity permissions by area."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
if entity_entry is None or entity_entry.device_id is None:
return None
device_entry = perm_lookup.device_registry.async_get(
entity_entry.device_id
)
if device_entry is None or device_entry.area_id is None:
return None
return area_dict.get(device_entry.area_id)
def _lookup_device(perm_lookup: PermissionLookup,
devices_dict: SubCategoryDict,
entity_id: str) -> Optional[ValueType]:
"""Look up entity permissions by device."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
if entity_entry is None or entity_entry.device_id is None:
return None
return devices_dict.get(entity_entry.device_id)
def _lookup_entity_id(perm_lookup: PermissionLookup,
entities_dict: SubCategoryDict,
entity_id: str) -> Optional[ValueType]:
"""Look up entity permission by entity id."""
return entities_dict.get(entity_id)
def compile_entities(policy: CategoryType, perm_lookup: PermissionLookup) \ def compile_entities(policy: CategoryType, perm_lookup: PermissionLookup) \
-> Callable[[str, str], bool]: -> Callable[[str, str], bool]:
"""Compile policy into a function that tests policy.""" """Compile policy into a function that tests policy."""
# None, Empty Dict, False subcategories = OrderedDict() # type: SubCatLookupType
if not policy: subcategories[ENTITY_ENTITY_IDS] = _lookup_entity_id
def apply_policy_deny_all(entity_id: str, key: str) -> bool: subcategories[ENTITY_DEVICE_IDS] = _lookup_device
"""Decline all.""" subcategories[ENTITY_AREAS] = _lookup_area
return False subcategories[ENTITY_DOMAINS] = _lookup_domain
subcategories[SUBCAT_ALL] = lookup_all
return apply_policy_deny_all return compile_policy(policy, subcategories, perm_lookup)
if policy is True:
def apply_policy_allow_all(entity_id: str, key: str) -> bool:
"""Approve all."""
return True
return apply_policy_allow_all
assert isinstance(policy, dict)
domains = policy.get(ENTITY_DOMAINS)
device_ids = policy.get(ENTITY_DEVICE_IDS)
entity_ids = policy.get(ENTITY_ENTITY_IDS)
all_entities = policy.get(SUBCAT_ALL)
funcs = [] # type: List[Callable[[str, str], Union[None, bool]]]
# The order of these functions matter. The more precise are at the top.
# If a function returns None, they cannot handle it.
# If a function returns a boolean, that's the result to return.
# Setting entity_ids to a boolean is final decision for permissions
# So return right away.
if isinstance(entity_ids, bool):
def allowed_entity_id_bool(entity_id: str, key: str) -> bool:
"""Test if allowed entity_id."""
return entity_ids # type: ignore
return allowed_entity_id_bool
if entity_ids is not None:
def allowed_entity_id_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed entity_id."""
return _entity_allowed(
entity_ids.get(entity_id), key) # type: ignore
funcs.append(allowed_entity_id_dict)
if isinstance(device_ids, bool):
def allowed_device_id_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed device_id."""
return device_ids
funcs.append(allowed_device_id_bool)
elif device_ids is not None:
def allowed_device_id_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed device_id."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
if entity_entry is None or entity_entry.device_id is None:
return None
return _entity_allowed(
device_ids.get(entity_entry.device_id), key # type: ignore
)
funcs.append(allowed_device_id_dict)
if isinstance(domains, bool):
def allowed_domain_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return domains
funcs.append(allowed_domain_bool)
elif domains is not None:
def allowed_domain_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
domain = entity_id.split(".", 1)[0]
return _entity_allowed(domains.get(domain), key) # type: ignore
funcs.append(allowed_domain_dict)
if isinstance(all_entities, bool):
def allowed_all_entities_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return all_entities
funcs.append(allowed_all_entities_bool)
elif all_entities is not None:
def allowed_all_entities_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return _entity_allowed(all_entities, key)
funcs.append(allowed_all_entities_dict)
# Can happen if no valid subcategories specified
if not funcs:
def apply_policy_deny_all_2(entity_id: str, key: str) -> bool:
"""Decline all."""
return False
return apply_policy_deny_all_2
if len(funcs) == 1:
func = funcs[0]
@wraps(func)
def apply_policy_func(entity_id: str, key: str) -> bool:
"""Apply a single policy function."""
return func(entity_id, key) is True
return apply_policy_func
def apply_policy_funcs(entity_id: str, key: str) -> bool:
"""Apply several policy functions."""
for func in funcs:
result = func(entity_id, key)
if result is not None:
return result
return False
return apply_policy_funcs

View File

@ -8,6 +8,9 @@ if TYPE_CHECKING:
from homeassistant.helpers import ( # noqa from homeassistant.helpers import ( # noqa
entity_registry as ent_reg, entity_registry as ent_reg,
) )
from homeassistant.helpers import ( # noqa
device_registry as dev_reg,
)
@attr.s(slots=True) @attr.s(slots=True)
@ -15,3 +18,4 @@ class PermissionLookup:
"""Class to hold data for permission lookups.""" """Class to hold data for permission lookups."""
entity_registry = attr.ib(type='ent_reg.EntityRegistry') entity_registry = attr.ib(type='ent_reg.EntityRegistry')
device_registry = attr.ib(type='dev_reg.DeviceRegistry')

View File

@ -10,9 +10,11 @@ ValueType = Union[
None None
] ]
# Example: entities.domains = { light: … }
SubCategoryDict = Mapping[str, ValueType]
SubCategoryType = Union[ SubCategoryType = Union[
# Example: entities.domains = { light: … } SubCategoryDict,
Mapping[str, ValueType],
bool, bool,
None None
] ]

View File

@ -0,0 +1,98 @@
"""Helpers to deal with permissions."""
from functools import wraps
from typing import Callable, Dict, List, Optional, Union, cast # noqa: F401
from .models import PermissionLookup
from .types import CategoryType, SubCategoryDict, ValueType
LookupFunc = Callable[[PermissionLookup, SubCategoryDict, str],
Optional[ValueType]]
SubCatLookupType = Dict[str, LookupFunc]
def lookup_all(perm_lookup: PermissionLookup, lookup_dict: SubCategoryDict,
object_id: str) -> ValueType:
"""Look up permission for all."""
# In case of ALL category, lookup_dict IS the schema.
return cast(ValueType, lookup_dict)
def compile_policy(
policy: CategoryType, subcategories: SubCatLookupType,
perm_lookup: PermissionLookup
) -> Callable[[str, str], bool]: # noqa
"""Compile policy into a function that tests policy.
Subcategories are mapping key -> lookup function, ordered by highest
priority first.
"""
# None, False, empty dict
if not policy:
def apply_policy_deny_all(entity_id: str, key: str) -> bool:
"""Decline all."""
return False
return apply_policy_deny_all
if policy is True:
def apply_policy_allow_all(entity_id: str, key: str) -> bool:
"""Approve all."""
return True
return apply_policy_allow_all
assert isinstance(policy, dict)
funcs = [] # type: List[Callable[[str, str], Union[None, bool]]]
for key, lookup_func in subcategories.items():
lookup_value = policy.get(key)
# If any lookup value is `True`, it will always be positive
if isinstance(lookup_value, bool):
return lambda object_id, key: True
if lookup_value is not None:
funcs.append(_gen_dict_test_func(
perm_lookup, lookup_func, lookup_value))
if len(funcs) == 1:
func = funcs[0]
@wraps(func)
def apply_policy_func(object_id: str, key: str) -> bool:
"""Apply a single policy function."""
return func(object_id, key) is True
return apply_policy_func
def apply_policy_funcs(object_id: str, key: str) -> bool:
"""Apply several policy functions."""
for func in funcs:
result = func(object_id, key)
if result is not None:
return result
return False
return apply_policy_funcs
def _gen_dict_test_func(
perm_lookup: PermissionLookup,
lookup_func: LookupFunc,
lookup_dict: SubCategoryDict
) -> Callable[[str, str], Optional[bool]]: # noqa
"""Generate a lookup function."""
def test_value(object_id: str, key: str) -> Optional[bool]:
"""Test if permission is allowed based on the keys."""
schema = lookup_func(
perm_lookup, lookup_dict, object_id) # type: ValueType
if schema is None or isinstance(schema, bool):
return schema
assert isinstance(schema, dict)
return schema.get(key)
return test_value

View File

@ -1,7 +1,7 @@
"""Provide a way to connect entities belonging to one device.""" """Provide a way to connect entities belonging to one device."""
import logging import logging
import uuid import uuid
from typing import List from typing import List, Optional
from collections import OrderedDict from collections import OrderedDict
@ -71,6 +71,11 @@ class DeviceRegistry:
self.devices = None self.devices = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
@callback
def async_get(self, device_id: str) -> Optional[DeviceEntry]:
"""Get device."""
return self.devices.get(device_id)
@callback @callback
def async_get_device(self, identifiers: set, connections: set): def async_get_device(self, identifiers: set, connections: set):
"""Check if device is registered.""" """Check if device is registered."""

View File

@ -6,8 +6,9 @@ from homeassistant.auth.permissions.entities import (
compile_entities, ENTITY_POLICY_SCHEMA) compile_entities, ENTITY_POLICY_SCHEMA)
from homeassistant.auth.permissions.models import PermissionLookup from homeassistant.auth.permissions.models import PermissionLookup
from homeassistant.helpers.entity_registry import RegistryEntry from homeassistant.helpers.entity_registry import RegistryEntry
from homeassistant.helpers.device_registry import DeviceEntry
from tests.common import mock_registry from tests.common import mock_registry, mock_device_registry
def test_entities_none(): def test_entities_none():
@ -193,7 +194,7 @@ def test_entities_all_control():
def test_entities_device_id_boolean(hass): def test_entities_device_id_boolean(hass):
"""Test entity ID policy applying control on device id.""" """Test entity ID policy applying control on device id."""
registry = mock_registry(hass, { entity_registry = mock_registry(hass, {
'test_domain.allowed': RegistryEntry( 'test_domain.allowed': RegistryEntry(
entity_id='test_domain.allowed', entity_id='test_domain.allowed',
unique_id='1234', unique_id='1234',
@ -207,6 +208,7 @@ def test_entities_device_id_boolean(hass):
device_id='mock-not-allowed-dev-id' device_id='mock-not-allowed-dev-id'
), ),
}) })
device_registry = mock_device_registry(hass)
policy = { policy = {
'device_ids': { 'device_ids': {
@ -216,8 +218,55 @@ def test_entities_device_id_boolean(hass):
} }
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy, PermissionLookup(registry)) compiled = compile_entities(policy, PermissionLookup(
entity_registry, device_registry
))
assert compiled('test_domain.allowed', 'read') is True assert compiled('test_domain.allowed', 'read') is True
assert compiled('test_domain.allowed', 'control') is False assert compiled('test_domain.allowed', 'control') is False
assert compiled('test_domain.not_allowed', 'read') is False assert compiled('test_domain.not_allowed', 'read') is False
assert compiled('test_domain.not_allowed', 'control') is False assert compiled('test_domain.not_allowed', 'control') is False
def test_entities_areas_true():
"""Test entity ID policy for areas."""
policy = {
'area_ids': True
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True
def test_entities_areas_area_true(hass):
"""Test entity ID policy for areas with specific area."""
entity_registry = mock_registry(hass, {
'light.kitchen': RegistryEntry(
entity_id='light.kitchen',
unique_id='1234',
platform='test_platform',
device_id='mock-dev-id'
),
})
device_registry = mock_device_registry(hass, {
'mock-dev-id': DeviceEntry(
id='mock-dev-id',
area_id='mock-area-id'
)
})
policy = {
'area_ids': {
'mock-area-id': {
'read': True,
'control': True,
}
}
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy, PermissionLookup(
entity_registry, device_registry
))
assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is True
assert compiled('light.kitchen', 'edit') is False
assert compiled('switch.kitchen', 'read') is False

View File

@ -245,7 +245,9 @@ async def test_loading_race_condition(hass):
store = auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
with asynctest.patch( with asynctest.patch(
'homeassistant.helpers.entity_registry.async_get_registry', 'homeassistant.helpers.entity_registry.async_get_registry',
) as mock_registry, asynctest.patch( ) as mock_ent_registry, asynctest.patch(
'homeassistant.helpers.device_registry.async_get_registry',
) as mock_dev_registry, asynctest.patch(
'homeassistant.helpers.storage.Store.async_load', 'homeassistant.helpers.storage.Store.async_load',
) as mock_load: ) as mock_load:
results = await asyncio.gather( results = await asyncio.gather(
@ -253,6 +255,7 @@ async def test_loading_race_condition(hass):
store.async_get_users(), store.async_get_users(),
) )
mock_registry.assert_called_once_with(hass) mock_ent_registry.assert_called_once_with(hass)
mock_dev_registry.assert_called_once_with(hass)
mock_load.assert_called_once_with() mock_load.assert_called_once_with()
assert results[0] == results[1] assert results[0] == results[1]