Allow checking entity permissions based on devices (#19007)

* Allow checking entity permissions based on devices

* Fix tests
This commit is contained in:
Paulus Schoutsen 2018-12-05 11:41:00 +01:00 committed by Pascal Vizeli
parent 2680bf8a61
commit 3928d034a3
11 changed files with 143 additions and 27 deletions

View File

@ -1,4 +1,5 @@
"""Storage for auth models.""" """Storage for auth models."""
import asyncio
from collections import OrderedDict from collections import OrderedDict
from datetime import timedelta from datetime import timedelta
import hmac import hmac
@ -11,7 +12,7 @@ from homeassistant.util import dt as dt_util
from . import models from . import models
from .const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY from .const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
from .permissions import system_policies from .permissions import PermissionLookup, system_policies
from .permissions.types import PolicyType # noqa: F401 from .permissions.types import PolicyType # noqa: F401
STORAGE_VERSION = 1 STORAGE_VERSION = 1
@ -34,6 +35,7 @@ class AuthStore:
self.hass = hass self.hass = hass
self._users = None # type: Optional[Dict[str, models.User]] self._users = None # type: Optional[Dict[str, models.User]]
self._groups = None # type: Optional[Dict[str, models.Group]] self._groups = None # type: Optional[Dict[str, models.Group]]
self._perm_lookup = None # type: Optional[PermissionLookup]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY, self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY,
private=True) private=True)
@ -94,6 +96,7 @@ class AuthStore:
# Until we get group management, we just put everyone in the # Until we get group management, we just put everyone in the
# same group. # same group.
'groups': groups, 'groups': groups,
'perm_lookup': self._perm_lookup,
} # type: Dict[str, Any] } # type: Dict[str, Any]
if is_owner is not None: if is_owner is not None:
@ -269,13 +272,18 @@ class AuthStore:
async def _async_load(self) -> None: async def _async_load(self) -> None:
"""Load the users.""" """Load the users."""
data = await self._store.async_load() [ent_reg, data] = await asyncio.gather(
self.hass.helpers.entity_registry.async_get_registry(),
self._store.async_load(),
)
# Make sure that we're not overriding data if 2 loads happened at the # Make sure that we're not overriding data if 2 loads happened at the
# same time # same time
if self._users is not None: if self._users is not None:
return return
self._perm_lookup = perm_lookup = PermissionLookup(ent_reg)
if data is None: if data is None:
self._set_defaults() self._set_defaults()
return return
@ -374,6 +382,7 @@ class AuthStore:
is_owner=user_dict['is_owner'], is_owner=user_dict['is_owner'],
is_active=user_dict['is_active'], is_active=user_dict['is_active'],
system_generated=user_dict['system_generated'], system_generated=user_dict['system_generated'],
perm_lookup=perm_lookup,
) )
for cred_dict in data['credentials']: for cred_dict in data['credentials']:

View File

@ -31,6 +31,9 @@ class User:
"""A user.""" """A user."""
name = attr.ib(type=str) # type: Optional[str] name = attr.ib(type=str) # type: Optional[str]
perm_lookup = attr.ib(
type=perm_mdl.PermissionLookup, cmp=False,
) # type: perm_mdl.PermissionLookup
id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex) id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex)
is_owner = attr.ib(type=bool, default=False) is_owner = attr.ib(type=bool, default=False)
is_active = attr.ib(type=bool, default=False) is_active = attr.ib(type=bool, default=False)
@ -66,7 +69,8 @@ class User:
self._permissions = perm_mdl.PolicyPermissions( self._permissions = perm_mdl.PolicyPermissions(
perm_mdl.merge_policies([ perm_mdl.merge_policies([
group.policy for group in self.groups])) group.policy for group in self.groups]),
self.perm_lookup)
return self._permissions return self._permissions

View File

@ -1,15 +1,18 @@
"""Permissions for Home Assistant.""" """Permissions for Home Assistant."""
import logging import logging
from typing import ( # noqa: F401 from typing import ( # noqa: F401
cast, Any, Callable, Dict, List, Mapping, Set, Tuple, Union) cast, Any, Callable, Dict, List, Mapping, Set, Tuple, Union,
TYPE_CHECKING)
import voluptuous as vol import voluptuous as vol
from .const import CAT_ENTITIES from .const import CAT_ENTITIES
from .models import PermissionLookup
from .types import PolicyType from .types import PolicyType
from .entities import ENTITY_POLICY_SCHEMA, compile_entities from .entities import ENTITY_POLICY_SCHEMA, compile_entities
from .merge import merge_policies # noqa from .merge import merge_policies # noqa
POLICY_SCHEMA = vol.Schema({ POLICY_SCHEMA = vol.Schema({
vol.Optional(CAT_ENTITIES): ENTITY_POLICY_SCHEMA vol.Optional(CAT_ENTITIES): ENTITY_POLICY_SCHEMA
}) })
@ -39,13 +42,16 @@ class AbstractPermissions:
class PolicyPermissions(AbstractPermissions): class PolicyPermissions(AbstractPermissions):
"""Handle permissions.""" """Handle permissions."""
def __init__(self, policy: PolicyType) -> None: def __init__(self, policy: PolicyType,
perm_lookup: PermissionLookup) -> None:
"""Initialize the permission class.""" """Initialize the permission class."""
self._policy = policy self._policy = policy
self._perm_lookup = perm_lookup
def _entity_func(self) -> Callable[[str, str], bool]: def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access.""" """Return a function that can test entity access."""
return compile_entities(self._policy.get(CAT_ENTITIES)) return compile_entities(self._policy.get(CAT_ENTITIES),
self._perm_lookup)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Equals check.""" """Equals check."""

View File

@ -5,6 +5,7 @@ from typing import Callable, List, Union # noqa: F401
import voluptuous as vol import voluptuous as vol
from .const import SUBCAT_ALL, POLICY_READ, POLICY_CONTROL, POLICY_EDIT from .const import SUBCAT_ALL, POLICY_READ, POLICY_CONTROL, POLICY_EDIT
from .models import PermissionLookup
from .types import CategoryType, ValueType from .types import CategoryType, ValueType
SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({ SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({
@ -14,6 +15,7 @@ SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({
})) }))
ENTITY_DOMAINS = 'domains' ENTITY_DOMAINS = 'domains'
ENTITY_DEVICE_IDS = 'device_ids'
ENTITY_ENTITY_IDS = 'entity_ids' ENTITY_ENTITY_IDS = 'entity_ids'
ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({ ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({
@ -22,6 +24,7 @@ ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({
ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({ ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({
vol.Optional(SUBCAT_ALL): SINGLE_ENTITY_SCHEMA, vol.Optional(SUBCAT_ALL): SINGLE_ENTITY_SCHEMA,
vol.Optional(ENTITY_DEVICE_IDS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_DOMAINS): ENTITY_VALUES_SCHEMA, vol.Optional(ENTITY_DOMAINS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_ENTITY_IDS): ENTITY_VALUES_SCHEMA, vol.Optional(ENTITY_ENTITY_IDS): ENTITY_VALUES_SCHEMA,
})) }))
@ -36,7 +39,7 @@ def _entity_allowed(schema: ValueType, key: str) \
return schema.get(key) return schema.get(key)
def compile_entities(policy: CategoryType) \ def compile_entities(policy: CategoryType, perm_lookup: PermissionLookup) \
-> Callable[[str, str], bool]: -> Callable[[str, str], bool]:
"""Compile policy into a function that tests policy.""" """Compile policy into a function that tests policy."""
# None, Empty Dict, False # None, Empty Dict, False
@ -57,6 +60,7 @@ def compile_entities(policy: CategoryType) \
assert isinstance(policy, dict) assert isinstance(policy, dict)
domains = policy.get(ENTITY_DOMAINS) domains = policy.get(ENTITY_DOMAINS)
device_ids = policy.get(ENTITY_DEVICE_IDS)
entity_ids = policy.get(ENTITY_ENTITY_IDS) entity_ids = policy.get(ENTITY_ENTITY_IDS)
all_entities = policy.get(SUBCAT_ALL) all_entities = policy.get(SUBCAT_ALL)
@ -84,6 +88,29 @@ def compile_entities(policy: CategoryType) \
funcs.append(allowed_entity_id_dict) funcs.append(allowed_entity_id_dict)
if isinstance(device_ids, bool):
def allowed_device_id_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed device_id."""
return device_ids
funcs.append(allowed_device_id_bool)
elif device_ids is not None:
def allowed_device_id_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed device_id."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
if entity_entry is None or entity_entry.device_id is None:
return None
return _entity_allowed(
device_ids.get(entity_entry.device_id), key # type: ignore
)
funcs.append(allowed_device_id_dict)
if isinstance(domains, bool): if isinstance(domains, bool):
def allowed_domain_bool(entity_id: str, key: str) \ def allowed_domain_bool(entity_id: str, key: str) \
-> Union[None, bool]: -> Union[None, bool]:

View File

@ -0,0 +1,17 @@
"""Models for permissions."""
from typing import TYPE_CHECKING
import attr
if TYPE_CHECKING:
# pylint: disable=unused-import
from homeassistant.helpers import ( # noqa
entity_registry as ent_reg,
)
@attr.s(slots=True)
class PermissionLookup:
"""Class to hold data for permission lookups."""
entity_registry = attr.ib(type='ent_reg.EntityRegistry')

View File

@ -10,6 +10,7 @@ timer.
from collections import OrderedDict from collections import OrderedDict
from itertools import chain from itertools import chain
import logging import logging
from typing import Optional
import weakref import weakref
import attr import attr
@ -85,6 +86,11 @@ class EntityRegistry:
"""Check if an entity_id is currently registered.""" """Check if an entity_id is currently registered."""
return entity_id in self.entities return entity_id in self.entities
@callback
def async_get(self, entity_id: str) -> Optional[RegistryEntry]:
"""Get EntityEntry for an entity_id."""
return self.entities.get(entity_id)
@callback @callback
def async_get_entity_id(self, domain: str, platform: str, unique_id: str): def async_get_entity_id(self, domain: str, platform: str, unique_id: str):
"""Check if an entity_id is currently registered.""" """Check if an entity_id is currently registered."""

View File

@ -4,12 +4,16 @@ import voluptuous as vol
from homeassistant.auth.permissions.entities import ( from homeassistant.auth.permissions.entities import (
compile_entities, ENTITY_POLICY_SCHEMA) compile_entities, ENTITY_POLICY_SCHEMA)
from homeassistant.auth.permissions.models import PermissionLookup
from homeassistant.helpers.entity_registry import RegistryEntry
from tests.common import mock_registry
def test_entities_none(): def test_entities_none():
"""Test entity ID policy.""" """Test entity ID policy."""
policy = None policy = None
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is False assert compiled('light.kitchen', 'read') is False
@ -17,7 +21,7 @@ def test_entities_empty():
"""Test entity ID policy.""" """Test entity ID policy."""
policy = {} policy = {}
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is False assert compiled('light.kitchen', 'read') is False
@ -32,7 +36,7 @@ def test_entities_true():
"""Test entity ID policy.""" """Test entity ID policy."""
policy = True policy = True
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'read') is True
@ -42,7 +46,7 @@ def test_entities_domains_true():
'domains': True 'domains': True
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'read') is True
@ -54,7 +58,7 @@ def test_entities_domains_domain_true():
} }
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'read') is True
assert compiled('switch.kitchen', 'read') is False assert compiled('switch.kitchen', 'read') is False
@ -76,7 +80,7 @@ def test_entities_entity_ids_true():
'entity_ids': True 'entity_ids': True
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'read') is True
@ -97,7 +101,7 @@ def test_entities_entity_ids_entity_id_true():
} }
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'read') is True
assert compiled('switch.kitchen', 'read') is False assert compiled('switch.kitchen', 'read') is False
@ -123,7 +127,7 @@ def test_entities_control_only():
} }
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is False assert compiled('light.kitchen', 'control') is False
assert compiled('light.kitchen', 'edit') is False assert compiled('light.kitchen', 'edit') is False
@ -140,7 +144,7 @@ def test_entities_read_control():
} }
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is True assert compiled('light.kitchen', 'control') is True
assert compiled('light.kitchen', 'edit') is False assert compiled('light.kitchen', 'edit') is False
@ -152,7 +156,7 @@ def test_entities_all_allow():
'all': True 'all': True
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is True assert compiled('light.kitchen', 'control') is True
assert compiled('switch.kitchen', 'read') is True assert compiled('switch.kitchen', 'read') is True
@ -166,7 +170,7 @@ def test_entities_all_read():
} }
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is True assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is False assert compiled('light.kitchen', 'control') is False
assert compiled('switch.kitchen', 'read') is True assert compiled('switch.kitchen', 'read') is True
@ -180,8 +184,40 @@ def test_entities_all_control():
} }
} }
ENTITY_POLICY_SCHEMA(policy) ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy) compiled = compile_entities(policy, None)
assert compiled('light.kitchen', 'read') is False assert compiled('light.kitchen', 'read') is False
assert compiled('light.kitchen', 'control') is True assert compiled('light.kitchen', 'control') is True
assert compiled('switch.kitchen', 'read') is False assert compiled('switch.kitchen', 'read') is False
assert compiled('switch.kitchen', 'control') is True assert compiled('switch.kitchen', 'control') is True
def test_entities_device_id_boolean(hass):
"""Test entity ID policy applying control on device id."""
registry = mock_registry(hass, {
'test_domain.allowed': RegistryEntry(
entity_id='test_domain.allowed',
unique_id='1234',
platform='test_platform',
device_id='mock-allowed-dev-id'
),
'test_domain.not_allowed': RegistryEntry(
entity_id='test_domain.not_allowed',
unique_id='5678',
platform='test_platform',
device_id='mock-not-allowed-dev-id'
),
})
policy = {
'device_ids': {
'mock-allowed-dev-id': {
'read': True,
}
}
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy, PermissionLookup(registry))
assert compiled('test_domain.allowed', 'read') is True
assert compiled('test_domain.allowed', 'control') is False
assert compiled('test_domain.not_allowed', 'read') is False
assert compiled('test_domain.not_allowed', 'control') is False

View File

@ -8,7 +8,7 @@ def test_admin_policy():
# Make sure it's valid # Make sure it's valid
POLICY_SCHEMA(system_policies.ADMIN_POLICY) POLICY_SCHEMA(system_policies.ADMIN_POLICY)
perms = PolicyPermissions(system_policies.ADMIN_POLICY) perms = PolicyPermissions(system_policies.ADMIN_POLICY, None)
assert perms.check_entity('light.kitchen', 'read') assert perms.check_entity('light.kitchen', 'read')
assert perms.check_entity('light.kitchen', 'control') assert perms.check_entity('light.kitchen', 'control')
assert perms.check_entity('light.kitchen', 'edit') assert perms.check_entity('light.kitchen', 'edit')
@ -19,7 +19,7 @@ def test_read_only_policy():
# Make sure it's valid # Make sure it's valid
POLICY_SCHEMA(system_policies.READ_ONLY_POLICY) POLICY_SCHEMA(system_policies.READ_ONLY_POLICY)
perms = PolicyPermissions(system_policies.READ_ONLY_POLICY) perms = PolicyPermissions(system_policies.READ_ONLY_POLICY, None)
assert perms.check_entity('light.kitchen', 'read') assert perms.check_entity('light.kitchen', 'read')
assert not perms.check_entity('light.kitchen', 'control') assert not perms.check_entity('light.kitchen', 'control')
assert not perms.check_entity('light.kitchen', 'edit') assert not perms.check_entity('light.kitchen', 'edit')

View File

@ -5,7 +5,12 @@ from homeassistant.auth import models, permissions
def test_owner_fetching_owner_permissions(): def test_owner_fetching_owner_permissions():
"""Test we fetch the owner permissions for an owner user.""" """Test we fetch the owner permissions for an owner user."""
group = models.Group(name="Test Group", policy={}) group = models.Group(name="Test Group", policy={})
owner = models.User(name="Test User", groups=[group], is_owner=True) owner = models.User(
name="Test User",
perm_lookup=None,
groups=[group],
is_owner=True
)
assert owner.permissions is permissions.OwnerPermissions assert owner.permissions is permissions.OwnerPermissions
@ -25,7 +30,11 @@ def test_permissions_merged():
} }
} }
}) })
user = models.User(name="Test User", groups=[group, group2]) user = models.User(
name="Test User",
perm_lookup=None,
groups=[group, group2]
)
# Make sure we cache instance # Make sure we cache instance
assert user.permissions is user.permissions assert user.permissions is user.permissions

View File

@ -384,6 +384,7 @@ class MockUser(auth_models.User):
'name': name, 'name': name,
'system_generated': system_generated, 'system_generated': system_generated,
'groups': groups or [], 'groups': groups or [],
'perm_lookup': None,
} }
if id is not None: if id is not None:
kwargs['id'] = id kwargs['id'] = id
@ -401,7 +402,8 @@ class MockUser(auth_models.User):
def mock_policy(self, policy): def mock_policy(self, policy):
"""Mock a policy for a user.""" """Mock a policy for a user."""
self._permissions = auth_permissions.PolicyPermissions(policy) self._permissions = auth_permissions.PolicyPermissions(
policy, self.perm_lookup)
async def register_auth_provider(hass, config): async def register_auth_provider(hass, config):

View File

@ -232,7 +232,7 @@ async def test_call_context_target_all(hass, mock_service_platform_call,
'light.kitchen': True 'light.kitchen': True
} }
} }
})))): }, None)))):
await service.entity_service_call(hass, [ await service.entity_service_call(hass, [
Mock(entities=mock_entities) Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service', ], Mock(), ha.ServiceCall('test_domain', 'test_service',
@ -253,7 +253,7 @@ async def test_call_context_target_specific(hass, mock_service_platform_call,
'light.kitchen': True 'light.kitchen': True
} }
} }
})))): }, None)))):
await service.entity_service_call(hass, [ await service.entity_service_call(hass, [
Mock(entities=mock_entities) Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service', { ], Mock(), ha.ServiceCall('test_domain', 'test_service', {
@ -271,7 +271,7 @@ async def test_call_context_target_specific_no_auth(
with pytest.raises(exceptions.Unauthorized) as err: with pytest.raises(exceptions.Unauthorized) as err:
with patch('homeassistant.auth.AuthManager.async_get_user', with patch('homeassistant.auth.AuthManager.async_get_user',
return_value=mock_coro(Mock( return_value=mock_coro(Mock(
permissions=PolicyPermissions({})))): permissions=PolicyPermissions({}, None)))):
await service.entity_service_call(hass, [ await service.entity_service_call(hass, [
Mock(entities=mock_entities) Mock(entities=mock_entities)
], Mock(), ha.ServiceCall('test_domain', 'test_service', { ], Mock(), ha.ServiceCall('test_domain', 'test_service', {