diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index b685e0d67c7..43b8318abc5 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -1,7 +1,9 @@ """Service calling related helpers.""" import asyncio +from functools import wraps import logging from os import path +from typing import Callable import voluptuous as vol @@ -10,7 +12,7 @@ from homeassistant.const import ( ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ATTR_AREA_ID) import homeassistant.core as ha from homeassistant.exceptions import TemplateError, Unauthorized, UnknownUser -from homeassistant.helpers import template +from homeassistant.helpers import template, typing from homeassistant.loader import get_component, bind_hass from homeassistant.util.yaml import load_yaml import homeassistant.helpers.config_validation as cv @@ -335,3 +337,25 @@ async def _handle_service_platform_call(func, data, entities, context): assert not pending for future in done: future.result() # pop exception if have + + +@bind_hass +@ha.callback +def async_register_admin_service(hass: typing.HomeAssistantType, domain: str, + service: str, service_func: Callable, + schema: vol.Schema) -> None: + """Register a service that requires admin access.""" + @wraps(service_func) + async def admin_handler(call): + if call.context.user_id: + user = await hass.auth.async_get_user(call.context.user_id) + if user is None: + raise UnknownUser(context=call.context) + if not user.is_admin: + raise Unauthorized(context=call.context) + + await hass.async_add_job(service_func, call) + + hass.services.async_register( + domain, service, admin_handler, schema + ) diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 854ee9c74f6..a36785b6ba0 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -5,18 +5,18 @@ from copy import deepcopy import unittest from unittest.mock import Mock, patch +import voluptuous as vol import pytest # To prevent circular import when running just this file import homeassistant.components # noqa from homeassistant import core as ha, loader, exceptions from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ENTITY_ID -from homeassistant.helpers import service, template from homeassistant.setup import async_setup_component import homeassistant.helpers.config_validation as cv from homeassistant.auth.permissions import PolicyPermissions from homeassistant.helpers import ( - device_registry as dev_reg, entity_registry as ent_reg) + service, template, device_registry as dev_reg, entity_registry as ent_reg) from tests.common import ( get_test_home_assistant, mock_service, mock_coro, mock_registry, mock_device_registry) @@ -395,3 +395,37 @@ async def test_call_with_omit_entity_id(hass, mock_service_platform_call, mock_entities['light.kitchen'], mock_entities['light.living_room']] assert ('Not passing an entity ID to a service to target ' 'all entities is deprecated') in caplog.text + + +async def test_register_admin_service(hass, hass_read_only_user, + hass_admin_user): + """Test the register admin service.""" + calls = [] + + async def mock_service(call): + calls.append(call) + + hass.helpers.service.async_register_admin_service( + 'test', 'test', mock_service, vol.Schema({}) + ) + + with pytest.raises(exceptions.UnknownUser): + await hass.services.async_call( + 'test', 'test', {}, blocking=True, context=ha.Context( + user_id='non-existing' + )) + assert len(calls) == 0 + + with pytest.raises(exceptions.Unauthorized): + await hass.services.async_call( + 'test', 'test', {}, blocking=True, context=ha.Context( + user_id=hass_read_only_user.id + )) + assert len(calls) == 0 + + await hass.services.async_call( + 'test', 'test', {}, blocking=True, context=ha.Context( + user_id=hass_admin_user.id + )) + assert len(calls) == 1 + assert calls[0].context.user_id == hass_admin_user.id