Add is_admin checks to cloud APIs (#97804)

This commit is contained in:
Franck Nijhof 2023-08-08 11:02:42 +02:00 committed by GitHub
parent 3859d2e2a6
commit 5e020ea354
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 66 additions and 3 deletions

View File

@ -24,7 +24,7 @@ from homeassistant.components.alexa import (
) )
from homeassistant.components.google_assistant import helpers as google_helpers from homeassistant.components.google_assistant import helpers as google_helpers
from homeassistant.components.homeassistant import exposed_entities from homeassistant.components.homeassistant import exposed_entities
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView, require_admin
from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -128,7 +128,6 @@ def _handle_cloud_errors(
try: try:
result = await handler(view, request, *args, **kwargs) result = await handler(view, request, *args, **kwargs)
return result return result
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
status, msg = _process_cloud_exception(err, request.path) status, msg = _process_cloud_exception(err, request.path)
return view.json_message( return view.json_message(
@ -188,6 +187,7 @@ class GoogleActionsSyncView(HomeAssistantView):
url = "/api/cloud/google_actions/sync" url = "/api/cloud/google_actions/sync"
name = "api:cloud:google_actions/sync" name = "api:cloud:google_actions/sync"
@require_admin
@_handle_cloud_errors @_handle_cloud_errors
async def post(self, request: web.Request) -> web.Response: async def post(self, request: web.Request) -> web.Response:
"""Trigger a Google Actions sync.""" """Trigger a Google Actions sync."""
@ -204,6 +204,7 @@ class CloudLoginView(HomeAssistantView):
url = "/api/cloud/login" url = "/api/cloud/login"
name = "api:cloud:login" name = "api:cloud:login"
@require_admin
@_handle_cloud_errors @_handle_cloud_errors
@RequestDataValidator( @RequestDataValidator(
vol.Schema({vol.Required("email"): str, vol.Required("password"): str}) vol.Schema({vol.Required("email"): str, vol.Required("password"): str})
@ -244,6 +245,7 @@ class CloudLogoutView(HomeAssistantView):
url = "/api/cloud/logout" url = "/api/cloud/logout"
name = "api:cloud:logout" name = "api:cloud:logout"
@require_admin
@_handle_cloud_errors @_handle_cloud_errors
async def post(self, request: web.Request) -> web.Response: async def post(self, request: web.Request) -> web.Response:
"""Handle logout request.""" """Handle logout request."""
@ -262,6 +264,7 @@ class CloudRegisterView(HomeAssistantView):
url = "/api/cloud/register" url = "/api/cloud/register"
name = "api:cloud:register" name = "api:cloud:register"
@require_admin
@_handle_cloud_errors @_handle_cloud_errors
@RequestDataValidator( @RequestDataValidator(
vol.Schema( vol.Schema(
@ -305,6 +308,7 @@ class CloudResendConfirmView(HomeAssistantView):
url = "/api/cloud/resend_confirm" url = "/api/cloud/resend_confirm"
name = "api:cloud:resend_confirm" name = "api:cloud:resend_confirm"
@require_admin
@_handle_cloud_errors @_handle_cloud_errors
@RequestDataValidator(vol.Schema({vol.Required("email"): str})) @RequestDataValidator(vol.Schema({vol.Required("email"): str}))
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response: async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
@ -324,6 +328,7 @@ class CloudForgotPasswordView(HomeAssistantView):
url = "/api/cloud/forgot_password" url = "/api/cloud/forgot_password"
name = "api:cloud:forgot_password" name = "api:cloud:forgot_password"
@require_admin
@_handle_cloud_errors @_handle_cloud_errors
@RequestDataValidator(vol.Schema({vol.Required("email"): str})) @RequestDataValidator(vol.Schema({vol.Required("email"): str}))
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response: async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:

View File

@ -52,6 +52,7 @@ from .const import ( # noqa: F401
KEY_HASS_USER, KEY_HASS_USER,
) )
from .cors import setup_cors from .cors import setup_cors
from .decorators import require_admin # noqa: F401
from .forwarded import async_setup_forwarded from .forwarded import async_setup_forwarded
from .headers import setup_headers from .headers import setup_headers
from .request_context import current_request, setup_request_context from .request_context import current_request, setup_request_context

View File

@ -0,0 +1,31 @@
"""Decorators for the Home Assistant API."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Concatenate, ParamSpec, TypeVar
from aiohttp.web import Request, Response
from homeassistant.exceptions import Unauthorized
from .view import HomeAssistantView
_HomeAssistantViewT = TypeVar("_HomeAssistantViewT", bound=HomeAssistantView)
_P = ParamSpec("_P")
def require_admin(
func: Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]
) -> Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]:
"""Home Assistant API decorator to require user to be an admin."""
async def with_admin(
self: _HomeAssistantViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
) -> Response:
"""Check admin and call function."""
if not request["hass_user"].is_admin:
raise Unauthorized()
return await func(self, request, *args, **kwargs)
return with_admin

View File

@ -1,6 +1,7 @@
"""Tests for the HTTP API for the cloud component.""" """Tests for the HTTP API for the cloud component."""
import asyncio import asyncio
from http import HTTPStatus from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch from unittest.mock import AsyncMock, MagicMock, Mock, patch
import aiohttp import aiohttp
@ -24,7 +25,7 @@ from . import mock_cloud, mock_cloud_prefs
from tests.components.google_assistant import MockConfig from tests.components.google_assistant import MockConfig
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import WebSocketGenerator from tests.typing import ClientSessionGenerator, WebSocketGenerator
SUBSCRIPTION_INFO_URL = "https://api-test.hass.io/payments/subscription_info" SUBSCRIPTION_INFO_URL = "https://api-test.hass.io/payments/subscription_info"
@ -1207,3 +1208,28 @@ async def test_tts_info(
assert response["success"] assert response["success"]
assert response["result"] == {"languages": [["en-US", "male"], ["en-US", "female"]]} assert response["result"] == {"languages": [["en-US", "male"], ["en-US", "female"]]}
@pytest.mark.parametrize(
("endpoint", "data"),
[
("/api/cloud/forgot_password", {"email": "fake@example.com"}),
("/api/cloud/google_actions/sync", None),
("/api/cloud/login", {"email": "fake@example.com", "password": "secret"}),
("/api/cloud/logout", None),
("/api/cloud/register", {"email": "fake@example.com", "password": "secret"}),
("/api/cloud/resend_confirm", {"email": "fake@example.com"}),
],
)
async def test_api_calls_require_admin(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
hass_read_only_access_token: str,
endpoint: str,
data: dict[str, Any] | None,
) -> None:
"""Test cloud APIs endpoints do not work as a normal user."""
client = await hass_client(hass_read_only_access_token)
resp = await client.post(endpoint, json=data)
assert resp.status == HTTPStatus.UNAUTHORIZED