diff --git a/homeassistant/components/cloud/http_api.py b/homeassistant/components/cloud/http_api.py index 84c348236d4..00ef4455f3b 100644 --- a/homeassistant/components/cloud/http_api.py +++ b/homeassistant/components/cloud/http_api.py @@ -24,7 +24,7 @@ from homeassistant.components.alexa import ( ) from homeassistant.components.google_assistant import helpers as google_helpers from homeassistant.components.homeassistant import exposed_entities -from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http import HomeAssistantView, require_admin from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES from homeassistant.core import HomeAssistant @@ -128,7 +128,6 @@ def _handle_cloud_errors( try: result = await handler(view, request, *args, **kwargs) return result - except Exception as err: # pylint: disable=broad-except status, msg = _process_cloud_exception(err, request.path) return view.json_message( @@ -188,6 +187,7 @@ class GoogleActionsSyncView(HomeAssistantView): url = "/api/cloud/google_actions/sync" name = "api:cloud:google_actions/sync" + @require_admin @_handle_cloud_errors async def post(self, request: web.Request) -> web.Response: """Trigger a Google Actions sync.""" @@ -204,6 +204,7 @@ class CloudLoginView(HomeAssistantView): url = "/api/cloud/login" name = "api:cloud:login" + @require_admin @_handle_cloud_errors @RequestDataValidator( vol.Schema({vol.Required("email"): str, vol.Required("password"): str}) @@ -244,6 +245,7 @@ class CloudLogoutView(HomeAssistantView): url = "/api/cloud/logout" name = "api:cloud:logout" + @require_admin @_handle_cloud_errors async def post(self, request: web.Request) -> web.Response: """Handle logout request.""" @@ -262,6 +264,7 @@ class CloudRegisterView(HomeAssistantView): url = "/api/cloud/register" name = "api:cloud:register" + @require_admin @_handle_cloud_errors @RequestDataValidator( vol.Schema( @@ -305,6 +308,7 @@ class CloudResendConfirmView(HomeAssistantView): url = "/api/cloud/resend_confirm" name = "api:cloud:resend_confirm" + @require_admin @_handle_cloud_errors @RequestDataValidator(vol.Schema({vol.Required("email"): str})) async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response: @@ -324,6 +328,7 @@ class CloudForgotPasswordView(HomeAssistantView): url = "/api/cloud/forgot_password" name = "api:cloud:forgot_password" + @require_admin @_handle_cloud_errors @RequestDataValidator(vol.Schema({vol.Required("email"): str})) async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response: diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 48ad0cb8752..ff287efb083 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -52,6 +52,7 @@ from .const import ( # noqa: F401 KEY_HASS_USER, ) from .cors import setup_cors +from .decorators import require_admin # noqa: F401 from .forwarded import async_setup_forwarded from .headers import setup_headers from .request_context import current_request, setup_request_context diff --git a/homeassistant/components/http/decorators.py b/homeassistant/components/http/decorators.py new file mode 100644 index 00000000000..45bd34fa49f --- /dev/null +++ b/homeassistant/components/http/decorators.py @@ -0,0 +1,31 @@ +"""Decorators for the Home Assistant API.""" +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Concatenate, ParamSpec, TypeVar + +from aiohttp.web import Request, Response + +from homeassistant.exceptions import Unauthorized + +from .view import HomeAssistantView + +_HomeAssistantViewT = TypeVar("_HomeAssistantViewT", bound=HomeAssistantView) +_P = ParamSpec("_P") + + +def require_admin( + func: Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]] +) -> Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]: + """Home Assistant API decorator to require user to be an admin.""" + + async def with_admin( + self: _HomeAssistantViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs + ) -> Response: + """Check admin and call function.""" + if not request["hass_user"].is_admin: + raise Unauthorized() + + return await func(self, request, *args, **kwargs) + + return with_admin diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index ff79fd1ea77..fc6861f2b49 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -1,6 +1,7 @@ """Tests for the HTTP API for the cloud component.""" import asyncio from http import HTTPStatus +from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch import aiohttp @@ -24,7 +25,7 @@ from . import mock_cloud, mock_cloud_prefs from tests.components.google_assistant import MockConfig from tests.test_util.aiohttp import AiohttpClientMocker -from tests.typing import WebSocketGenerator +from tests.typing import ClientSessionGenerator, WebSocketGenerator SUBSCRIPTION_INFO_URL = "https://api-test.hass.io/payments/subscription_info" @@ -1207,3 +1208,28 @@ async def test_tts_info( assert response["success"] assert response["result"] == {"languages": [["en-US", "male"], ["en-US", "female"]]} + + +@pytest.mark.parametrize( + ("endpoint", "data"), + [ + ("/api/cloud/forgot_password", {"email": "fake@example.com"}), + ("/api/cloud/google_actions/sync", None), + ("/api/cloud/login", {"email": "fake@example.com", "password": "secret"}), + ("/api/cloud/logout", None), + ("/api/cloud/register", {"email": "fake@example.com", "password": "secret"}), + ("/api/cloud/resend_confirm", {"email": "fake@example.com"}), + ], +) +async def test_api_calls_require_admin( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + hass_read_only_access_token: str, + endpoint: str, + data: dict[str, Any] | None, +) -> None: + """Test cloud APIs endpoints do not work as a normal user.""" + client = await hass_client(hass_read_only_access_token) + resp = await client.post(endpoint, json=data) + + assert resp.status == HTTPStatus.UNAUTHORIZED