From b0f68f1ef30cb350b6011460ff955c3032599db7 Mon Sep 17 00:00:00 2001 From: Robert Resch Date: Mon, 14 Aug 2023 15:07:20 +0200 Subject: [PATCH] Use @require_admin decorator (#98061) Co-authored-by: Robert Resch Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> --- homeassistant/components/api/__init__.py | 14 ++-- .../components/config/config_entries.py | 46 ++++++------- homeassistant/components/http/decorators.py | 66 +++++++++++++++---- .../components/media_source/local_source.py | 6 +- .../components/repairs/websocket_api.py | 17 ++--- homeassistant/components/zwave_js/api.py | 5 +- homeassistant/helpers/data_entry_flow.py | 2 +- .../components/config/test_config_entries.py | 46 +++++++++++++ 8 files changed, 136 insertions(+), 66 deletions(-) diff --git a/homeassistant/components/api/__init__.py b/homeassistant/components/api/__init__.py index b465a6b7037..f264806ad47 100644 --- a/homeassistant/components/api/__init__.py +++ b/homeassistant/components/api/__init__.py @@ -11,7 +11,7 @@ import voluptuous as vol from homeassistant.auth.permissions.const import POLICY_READ from homeassistant.bootstrap import DATA_LOGGING -from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http import HomeAssistantView, require_admin from homeassistant.const import ( EVENT_HOMEASSISTANT_STOP, MATCH_ALL, @@ -110,10 +110,9 @@ class APIEventStream(HomeAssistantView): url = URL_API_STREAM name = "api:stream" + @require_admin async def get(self, request): """Provide a streaming interface for the event bus.""" - if not request["hass_user"].is_admin: - raise Unauthorized() hass = request.app["hass"] stop_obj = object() to_write = asyncio.Queue() @@ -278,10 +277,9 @@ class APIEventView(HomeAssistantView): url = "/api/events/{event_type}" name = "api:event" + @require_admin async def post(self, request, event_type): """Fire events.""" - if not request["hass_user"].is_admin: - raise Unauthorized() body = await request.text() try: event_data = json_loads(body) if body else None @@ -385,10 +383,9 @@ class APITemplateView(HomeAssistantView): url = URL_API_TEMPLATE name = "api:template" + @require_admin async def post(self, request): """Render a template.""" - if not request["hass_user"].is_admin: - raise Unauthorized() try: data = await request.json() tpl = _cached_template(data["template"], request.app["hass"]) @@ -405,10 +402,9 @@ class APIErrorLog(HomeAssistantView): url = URL_API_ERROR_LOG name = "api:error_log" + @require_admin async def get(self, request): """Retrieve API error log.""" - if not request["hass_user"].is_admin: - raise Unauthorized() return web.FileResponse(request.app["hass"].data[DATA_LOGGING]) diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index d58616ff38f..9691994512c 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -11,7 +11,7 @@ import voluptuous as vol from homeassistant import config_entries, data_entry_flow from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES, POLICY_EDIT from homeassistant.components import websocket_api -from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http import HomeAssistantView, require_admin from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import DependencyError, Unauthorized import homeassistant.helpers.config_validation as cv @@ -138,12 +138,11 @@ class ConfigManagerFlowIndexView(FlowManagerIndexView): """Not implemented.""" raise aiohttp.web_exceptions.HTTPMethodNotAllowed("GET", ["POST"]) - # pylint: disable=arguments-differ + @require_admin( + error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add") + ) async def post(self, request): """Handle a POST request.""" - if not request["hass_user"].is_admin: - raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add") - # pylint: disable=no-value-for-parameter try: return await super().post(request) @@ -164,19 +163,18 @@ class ConfigManagerFlowResourceView(FlowManagerResourceView): url = "/api/config/config_entries/flow/{flow_id}" name = "api:config:config_entries:flow:resource" - async def get(self, request, flow_id): + @require_admin( + error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add") + ) + async def get(self, request, /, flow_id): """Get the current state of a data_entry_flow.""" - if not request["hass_user"].is_admin: - raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add") - return await super().get(request, flow_id) - # pylint: disable=arguments-differ + @require_admin( + error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add") + ) async def post(self, request, flow_id): """Handle a POST request.""" - if not request["hass_user"].is_admin: - raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add") - # pylint: disable=no-value-for-parameter return await super().post(request, flow_id) @@ -206,15 +204,14 @@ class OptionManagerFlowIndexView(FlowManagerIndexView): url = "/api/config/config_entries/options/flow" name = "api:config:config_entries:option:flow" - # pylint: disable=arguments-differ + @require_admin( + error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) + ) async def post(self, request): """Handle a POST request. handler in request is entry_id. """ - if not request["hass_user"].is_admin: - raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) - # pylint: disable=no-value-for-parameter return await super().post(request) @@ -225,19 +222,18 @@ class OptionManagerFlowResourceView(FlowManagerResourceView): url = "/api/config/config_entries/options/flow/{flow_id}" name = "api:config:config_entries:options:flow:resource" - async def get(self, request, flow_id): + @require_admin( + error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) + ) + async def get(self, request, /, flow_id): """Get the current state of a data_entry_flow.""" - if not request["hass_user"].is_admin: - raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) - return await super().get(request, flow_id) - # pylint: disable=arguments-differ + @require_admin( + error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) + ) async def post(self, request, flow_id): """Handle a POST request.""" - if not request["hass_user"].is_admin: - raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) - # pylint: disable=no-value-for-parameter return await super().post(request, flow_id) diff --git a/homeassistant/components/http/decorators.py b/homeassistant/components/http/decorators.py index 45bd34fa49f..ce5b1b18c06 100644 --- a/homeassistant/components/http/decorators.py +++ b/homeassistant/components/http/decorators.py @@ -1,8 +1,9 @@ """Decorators for the Home Assistant API.""" from __future__ import annotations -from collections.abc import Awaitable, Callable -from typing import Concatenate, ParamSpec, TypeVar +from collections.abc import Callable, Coroutine +from functools import wraps +from typing import Any, Concatenate, ParamSpec, TypeVar, overload from aiohttp.web import Request, Response @@ -12,20 +13,61 @@ from .view import HomeAssistantView _HomeAssistantViewT = TypeVar("_HomeAssistantViewT", bound=HomeAssistantView) _P = ParamSpec("_P") +_FuncType = Callable[ + Concatenate[_HomeAssistantViewT, Request, _P], Coroutine[Any, Any, Response] +] + + +@overload +def require_admin( + _func: None = None, + *, + error: Unauthorized | None = None, +) -> Callable[[_FuncType[_HomeAssistantViewT, _P]], _FuncType[_HomeAssistantViewT, _P]]: + ... + + +@overload +def require_admin( + _func: _FuncType[_HomeAssistantViewT, _P], +) -> _FuncType[_HomeAssistantViewT, _P]: + ... def require_admin( - func: Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]] -) -> Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]: + _func: _FuncType[_HomeAssistantViewT, _P] | None = None, + *, + error: Unauthorized | None = None, +) -> ( + Callable[[_FuncType[_HomeAssistantViewT, _P]], _FuncType[_HomeAssistantViewT, _P]] + | _FuncType[_HomeAssistantViewT, _P] +): """Home Assistant API decorator to require user to be an admin.""" - async def with_admin( - self: _HomeAssistantViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs - ) -> Response: - """Check admin and call function.""" - if not request["hass_user"].is_admin: - raise Unauthorized() + def decorator_require_admin( + func: _FuncType[_HomeAssistantViewT, _P] + ) -> _FuncType[_HomeAssistantViewT, _P]: + """Wrap the provided with_admin function.""" - return await func(self, request, *args, **kwargs) + @wraps(func) + async def with_admin( + self: _HomeAssistantViewT, + request: Request, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> Response: + """Check admin and call function.""" + if not request["hass_user"].is_admin: + raise error or Unauthorized() - return with_admin + return await func(self, request, *args, **kwargs) + + return with_admin + + # See if we're being called as @require_admin or @require_admin(). + if _func is None: + # We're called with brackets. + return decorator_require_admin + + # We're called as @require_admin without brackets. + return decorator_require_admin(_func) diff --git a/homeassistant/components/media_source/local_source.py b/homeassistant/components/media_source/local_source.py index 89437a6b2e0..ac6623a3af8 100644 --- a/homeassistant/components/media_source/local_source.py +++ b/homeassistant/components/media_source/local_source.py @@ -12,9 +12,9 @@ from aiohttp.web_request import FileField import voluptuous as vol from homeassistant.components import http, websocket_api +from homeassistant.components.http import require_admin from homeassistant.components.media_player import BrowseError, MediaClass from homeassistant.core import HomeAssistant, callback -from homeassistant.exceptions import Unauthorized from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES @@ -254,11 +254,9 @@ class UploadMediaView(http.HomeAssistantView): } ) + @require_admin async def post(self, request: web.Request) -> web.Response: """Handle upload.""" - if not request["hass_user"].is_admin: - raise Unauthorized() - # Increase max payload request._client_max_size = MAX_UPLOAD_SIZE # pylint: disable=protected-access diff --git a/homeassistant/components/repairs/websocket_api.py b/homeassistant/components/repairs/websocket_api.py index c5408054318..0c6230e4c35 100644 --- a/homeassistant/components/repairs/websocket_api.py +++ b/homeassistant/components/repairs/websocket_api.py @@ -12,6 +12,7 @@ from homeassistant import data_entry_flow from homeassistant.auth.permissions.const import POLICY_EDIT from homeassistant.components import websocket_api from homeassistant.components.http.data_validator import RequestDataValidator +from homeassistant.components.http.decorators import require_admin from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import Unauthorized from homeassistant.helpers.data_entry_flow import ( @@ -88,6 +89,7 @@ class RepairsFlowIndexView(FlowManagerIndexView): url = "/api/repairs/issues/fix" name = "api:repairs:issues:fix" + @require_admin(error=Unauthorized(permission=POLICY_EDIT)) @RequestDataValidator( vol.Schema( { @@ -99,9 +101,6 @@ class RepairsFlowIndexView(FlowManagerIndexView): ) async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response: """Handle a POST request.""" - if not request["hass_user"].is_admin: - raise Unauthorized(permission=POLICY_EDIT) - try: result = await self._flow_mgr.async_init( data["handler"], @@ -125,18 +124,12 @@ class RepairsFlowResourceView(FlowManagerResourceView): url = "/api/repairs/issues/fix/{flow_id}" name = "api:repairs:issues:fix:resource" - async def get(self, request: web.Request, flow_id: str) -> web.Response: + @require_admin(error=Unauthorized(permission=POLICY_EDIT)) + async def get(self, request: web.Request, /, flow_id: str) -> web.Response: """Get the current state of a data_entry_flow.""" - if not request["hass_user"].is_admin: - raise Unauthorized(permission=POLICY_EDIT) - return await super().get(request, flow_id) - # pylint: disable=arguments-differ + @require_admin(error=Unauthorized(permission=POLICY_EDIT)) async def post(self, request: web.Request, flow_id: str) -> web.Response: """Handle a POST request.""" - if not request["hass_user"].is_admin: - raise Unauthorized(permission=POLICY_EDIT) - - # pylint: disable=no-value-for-parameter return await super().post(request, flow_id) diff --git a/homeassistant/components/zwave_js/api.py b/homeassistant/components/zwave_js/api.py index 6d2461df3e4..6781ccacdc7 100644 --- a/homeassistant/components/zwave_js/api.py +++ b/homeassistant/components/zwave_js/api.py @@ -55,6 +55,7 @@ from zwave_js_server.model.utils import ( from zwave_js_server.util.node import async_set_config_parameter from homeassistant.components import websocket_api +from homeassistant.components.http import require_admin from homeassistant.components.http.view import HomeAssistantView from homeassistant.components.websocket_api import ( ERR_INVALID_FORMAT, @@ -65,7 +66,6 @@ from homeassistant.components.websocket_api import ( ) from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.core import HomeAssistant, callback -from homeassistant.exceptions import Unauthorized from homeassistant.helpers import config_validation as cv from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.device_registry as dr @@ -2149,10 +2149,9 @@ class FirmwareUploadView(HomeAssistantView): super().__init__() self._dev_reg = dev_reg + @require_admin async def post(self, request: web.Request, device_id: str) -> web.Response: """Handle upload.""" - if not request["hass_user"].is_admin: - raise Unauthorized() hass = request.app["hass"] try: diff --git a/homeassistant/helpers/data_entry_flow.py b/homeassistant/helpers/data_entry_flow.py index e3e4b4f0de8..aa4ef36b251 100644 --- a/homeassistant/helpers/data_entry_flow.py +++ b/homeassistant/helpers/data_entry_flow.py @@ -90,7 +90,7 @@ class FlowManagerIndexView(_BaseFlowManagerView): class FlowManagerResourceView(_BaseFlowManagerView): """View to interact with the flow manager.""" - async def get(self, request: web.Request, flow_id: str) -> web.Response: + async def get(self, request: web.Request, /, flow_id: str) -> web.Response: """Get the current state of a data_entry_flow.""" try: result = await self._flow_mgr.async_configure(flow_id) diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py index bf94e36a9b4..4684b4148b1 100644 --- a/tests/components/config/test_config_entries.py +++ b/tests/components/config/test_config_entries.py @@ -825,6 +825,52 @@ async def test_options_flow(hass: HomeAssistant, client) -> None: } +@pytest.mark.parametrize( + ("endpoint", "method"), + [ + ("/api/config/config_entries/options/flow", "post"), + ("/api/config/config_entries/options/flow/1", "get"), + ("/api/config/config_entries/options/flow/1", "post"), + ], +) +async def test_options_flow_unauth( + hass: HomeAssistant, client, hass_admin_user: MockUser, endpoint: str, method: str +) -> None: + """Test unauthorized on options flow.""" + + class TestFlow(core_ce.ConfigFlow): + @staticmethod + @callback + def async_get_options_flow(config_entry): + class OptionsFlowHandler(data_entry_flow.FlowHandler): + async def async_step_init(self, user_input=None): + schema = OrderedDict() + schema[vol.Required("enabled")] = bool + return self.async_show_form( + step_id="user", + data_schema=schema, + description_placeholders={"enabled": "Set to true to be true"}, + ) + + return OptionsFlowHandler() + + mock_integration(hass, MockModule("test")) + mock_entity_platform(hass, "config_flow.test", None) + MockConfigEntry( + domain="test", + entry_id="test1", + source="bla", + ).add_to_hass(hass) + entry = hass.config_entries.async_entries()[0] + + hass_admin_user.groups = [] + + with patch.dict(HANDLERS, {"test": TestFlow}): + resp = await getattr(client, method)(endpoint, json={"handler": entry.entry_id}) + + assert resp.status == HTTPStatus.UNAUTHORIZED + + async def test_two_step_options_flow(hass: HomeAssistant, client) -> None: """Test we can finish a two step options flow.""" mock_integration(