Add a service require_admin wrapper (#21953)

* Add a service require_admin wrapper

* Allow it to be used as a decorator

* Lint

* Add comment

* Add docstring

* Update syntax
This commit is contained in:
Paulus Schoutsen 2019-03-12 22:09:50 -07:00 committed by GitHub
parent bf839687ad
commit c15f433c3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 3 deletions

View File

@ -1,7 +1,9 @@
"""Service calling related helpers.""" """Service calling related helpers."""
import asyncio import asyncio
from functools import wraps
import logging import logging
from os import path from os import path
from typing import Callable
import voluptuous as vol import voluptuous as vol
@ -10,7 +12,7 @@ from homeassistant.const import (
ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID) ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID)
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser
from homeassistant.helpers import template from homeassistant.helpers import template, typing
from homeassistant.loader import get_component, bind_hass from homeassistant.loader import get_component, bind_hass
from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml import load_yaml
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -335,3 +337,25 @@ async def _handle_service_platform_call(func, data, entities, context):
assert not pending assert not pending
for future in done: for future in done:
future.result() # pop exception if have future.result() # pop exception if have
@bind_hass
@ha.callback
def async_register_admin_service(hass: typing.HomeAssistantType, domain: str,
service: str, service_func: Callable,
schema: vol.Schema) -> None:
"""Register a service that requires admin access."""
@wraps(service_func)
async def admin_handler(call):
if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(context=call.context)
if not user.is_admin:
raise Unauthorized(context=call.context)
await hass.async_add_job(service_func, call)
hass.services.async_register(
domain, service, admin_handler, schema
)

View File

@ -5,18 +5,18 @@ from copy import deepcopy
import unittest import unittest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import voluptuous as vol
import pytest import pytest
# To prevent circular import when running just this file # To prevent circular import when running just this file
import homeassistant.components # noqa import homeassistant.components # noqa
from homeassistant import core as ha, loader, exceptions from homeassistant import core as ha, loader, exceptions
from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID
from homeassistant.helpers import service, template
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.auth.permissions import PolicyPermissions from homeassistant.auth.permissions import PolicyPermissions
from homeassistant.helpers import ( from homeassistant.helpers import (
device_registry as dev_reg, entity_registry as ent_reg) service, template, device_registry as dev_reg, entity_registry as ent_reg)
from tests.common import ( from tests.common import (
get_test_home_assistant, mock_service, mock_coro, mock_registry, get_test_home_assistant, mock_service, mock_coro, mock_registry,
mock_device_registry) mock_device_registry)
@ -395,3 +395,37 @@ async def test_call_with_omit_entity_id(hass, mock_service_platform_call,
mock_entities['light.kitchen'], mock_entities['light.living_room']] mock_entities['light.kitchen'], mock_entities['light.living_room']]
assert ('Not passing an entity ID to a service to target ' assert ('Not passing an entity ID to a service to target '
'all entities is deprecated') in caplog.text 'all entities is deprecated') in caplog.text
async def test_register_admin_service(hass, hass_read_only_user,
hass_admin_user):
"""Test the register admin service."""
calls = []
async def mock_service(call):
calls.append(call)
hass.helpers.service.async_register_admin_service(
'test', 'test', mock_service, vol.Schema({})
)
with pytest.raises(exceptions.UnknownUser):
await hass.services.async_call(
'test', 'test', {}, blocking=True, context=ha.Context(
user_id='non-existing'
))
assert len(calls) == 0
with pytest.raises(exceptions.Unauthorized):
await hass.services.async_call(
'test', 'test', {}, blocking=True, context=ha.Context(
user_id=hass_read_only_user.id
))
assert len(calls) == 0
await hass.services.async_call(
'test', 'test', {}, blocking=True, context=ha.Context(
user_id=hass_admin_user.id
))
assert len(calls) == 1
assert calls[0].context.user_id == hass_admin_user.id