From 65ad39f5be0c080976984bfc07a3c1e83f854d55 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 7 May 2025 09:30:40 +0200 Subject: [PATCH] Modify require_admin decorator to take parameters for Unauthorized (#144346) --- .../components/config/config_entries.py | 36 +++++-------------- homeassistant/components/http/decorators.py | 8 +++-- .../components/repairs/websocket_api.py | 7 ++-- 3 files changed, 17 insertions(+), 34 deletions(-) diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index 6e2d4a5da49..d20d4de881f 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -165,9 +165,7 @@ class ConfigManagerFlowIndexView( """Not implemented.""" raise aiohttp.web_exceptions.HTTPMethodNotAllowed("GET", ["POST"]) - @require_admin( - error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add") - ) + @require_admin(perm_category=CAT_CONFIG_ENTRIES, permission="add") @RequestDataValidator( vol.Schema( { @@ -218,16 +216,12 @@ class ConfigManagerFlowResourceView( url = "/api/config/config_entries/flow/{flow_id}" name = "api:config:config_entries:flow:resource" - @require_admin( - error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add") - ) + @require_admin(perm_category=CAT_CONFIG_ENTRIES, permission="add") async def get(self, request: web.Request, /, flow_id: str) -> web.Response: """Get the current state of a data_entry_flow.""" return await super().get(request, flow_id) - @require_admin( - error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="add") - ) + @require_admin(perm_category=CAT_CONFIG_ENTRIES, permission="add") async def post(self, request: web.Request, flow_id: str) -> web.Response: """Handle a POST request.""" return await super().post(request, flow_id) @@ -262,9 +256,7 @@ class OptionManagerFlowIndexView( url = "/api/config/config_entries/options/flow" name = "api:config:config_entries:option:flow" - @require_admin( - error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) - ) + @require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) async def post(self, request: web.Request) -> web.Response: """Handle a POST request. @@ -281,16 +273,12 @@ class OptionManagerFlowResourceView( url = "/api/config/config_entries/options/flow/{flow_id}" name = "api:config:config_entries:options:flow:resource" - @require_admin( - error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) - ) + @require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) async def get(self, request: web.Request, /, flow_id: str) -> web.Response: """Get the current state of a data_entry_flow.""" return await super().get(request, flow_id) - @require_admin( - error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) - ) + @require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) async def post(self, request: web.Request, flow_id: str) -> web.Response: """Handle a POST request.""" return await super().post(request, flow_id) @@ -304,9 +292,7 @@ class SubentryManagerFlowIndexView( url = "/api/config/config_entries/subentries/flow" name = "api:config:config_entries:subentries:flow" - @require_admin( - error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) - ) + @require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) @RequestDataValidator( vol.Schema( { @@ -341,16 +327,12 @@ class SubentryManagerFlowResourceView( url = "/api/config/config_entries/subentries/flow/{flow_id}" name = "api:config:config_entries:subentries:flow:resource" - @require_admin( - error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) - ) + @require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) async def get(self, request: web.Request, /, flow_id: str) -> web.Response: """Get the current state of a data_entry_flow.""" return await super().get(request, flow_id) - @require_admin( - error=Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) - ) + @require_admin(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) async def post(self, request: web.Request, flow_id: str) -> web.Response: """Handle a POST request.""" return await super().post(request, flow_id) diff --git a/homeassistant/components/http/decorators.py b/homeassistant/components/http/decorators.py index 1adc21be09f..19a0a5d1c55 100644 --- a/homeassistant/components/http/decorators.py +++ b/homeassistant/components/http/decorators.py @@ -27,7 +27,8 @@ def require_admin[ ]( _func: None = None, *, - error: Unauthorized | None = None, + perm_category: str | None = None, + permission: str | None = None, ) -> Callable[ [_FuncType[_HomeAssistantViewT, _P, _ResponseT]], _FuncType[_HomeAssistantViewT, _P, _ResponseT], @@ -51,7 +52,8 @@ def require_admin[ ]( _func: _FuncType[_HomeAssistantViewT, _P, _ResponseT] | None = None, *, - error: Unauthorized | None = None, + perm_category: str | None = None, + permission: str | None = None, ) -> ( Callable[ [_FuncType[_HomeAssistantViewT, _P, _ResponseT]], @@ -76,7 +78,7 @@ def require_admin[ """Check admin and call function.""" user: User = request["hass_user"] if not user.is_admin: - raise error or Unauthorized() + raise Unauthorized(perm_category=perm_category, permission=permission) return await func(self, request, *args, **kwargs) diff --git a/homeassistant/components/repairs/websocket_api.py b/homeassistant/components/repairs/websocket_api.py index 4875a8f6cfa..4117b0ee35b 100644 --- a/homeassistant/components/repairs/websocket_api.py +++ b/homeassistant/components/repairs/websocket_api.py @@ -14,7 +14,6 @@ from homeassistant.components import websocket_api from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.http.decorators import require_admin from homeassistant.core import HomeAssistant, callback -from homeassistant.exceptions import Unauthorized from homeassistant.helpers import issue_registry as ir from homeassistant.helpers.data_entry_flow import ( FlowManagerIndexView, @@ -114,7 +113,7 @@ class RepairsFlowIndexView(FlowManagerIndexView): url = "/api/repairs/issues/fix" name = "api:repairs:issues:fix" - @require_admin(error=Unauthorized(permission=POLICY_EDIT)) + @require_admin(permission=POLICY_EDIT) @RequestDataValidator( vol.Schema( { @@ -149,12 +148,12 @@ class RepairsFlowResourceView(FlowManagerResourceView): url = "/api/repairs/issues/fix/{flow_id}" name = "api:repairs:issues:fix:resource" - @require_admin(error=Unauthorized(permission=POLICY_EDIT)) + @require_admin(permission=POLICY_EDIT) async def get(self, request: web.Request, /, flow_id: str) -> web.Response: """Get the current state of a data_entry_flow.""" return await super().get(request, flow_id) - @require_admin(error=Unauthorized(permission=POLICY_EDIT)) + @require_admin(permission=POLICY_EDIT) async def post(self, request: web.Request, flow_id: str) -> web.Response: """Handle a POST request.""" return await super().post(request, flow_id)