mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 14:17:45 +00:00
Add is_admin checks to cloud APIs (#97804)
This commit is contained in:
parent
3859d2e2a6
commit
5e020ea354
@ -24,7 +24,7 @@ from homeassistant.components.alexa import (
|
|||||||
)
|
)
|
||||||
from homeassistant.components.google_assistant import helpers as google_helpers
|
from homeassistant.components.google_assistant import helpers as google_helpers
|
||||||
from homeassistant.components.homeassistant import exposed_entities
|
from homeassistant.components.homeassistant import exposed_entities
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView, require_admin
|
||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
|
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -128,7 +128,6 @@ def _handle_cloud_errors(
|
|||||||
try:
|
try:
|
||||||
result = await handler(view, request, *args, **kwargs)
|
result = await handler(view, request, *args, **kwargs)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
status, msg = _process_cloud_exception(err, request.path)
|
status, msg = _process_cloud_exception(err, request.path)
|
||||||
return view.json_message(
|
return view.json_message(
|
||||||
@ -188,6 +187,7 @@ class GoogleActionsSyncView(HomeAssistantView):
|
|||||||
url = "/api/cloud/google_actions/sync"
|
url = "/api/cloud/google_actions/sync"
|
||||||
name = "api:cloud:google_actions/sync"
|
name = "api:cloud:google_actions/sync"
|
||||||
|
|
||||||
|
@require_admin
|
||||||
@_handle_cloud_errors
|
@_handle_cloud_errors
|
||||||
async def post(self, request: web.Request) -> web.Response:
|
async def post(self, request: web.Request) -> web.Response:
|
||||||
"""Trigger a Google Actions sync."""
|
"""Trigger a Google Actions sync."""
|
||||||
@ -204,6 +204,7 @@ class CloudLoginView(HomeAssistantView):
|
|||||||
url = "/api/cloud/login"
|
url = "/api/cloud/login"
|
||||||
name = "api:cloud:login"
|
name = "api:cloud:login"
|
||||||
|
|
||||||
|
@require_admin
|
||||||
@_handle_cloud_errors
|
@_handle_cloud_errors
|
||||||
@RequestDataValidator(
|
@RequestDataValidator(
|
||||||
vol.Schema({vol.Required("email"): str, vol.Required("password"): str})
|
vol.Schema({vol.Required("email"): str, vol.Required("password"): str})
|
||||||
@ -244,6 +245,7 @@ class CloudLogoutView(HomeAssistantView):
|
|||||||
url = "/api/cloud/logout"
|
url = "/api/cloud/logout"
|
||||||
name = "api:cloud:logout"
|
name = "api:cloud:logout"
|
||||||
|
|
||||||
|
@require_admin
|
||||||
@_handle_cloud_errors
|
@_handle_cloud_errors
|
||||||
async def post(self, request: web.Request) -> web.Response:
|
async def post(self, request: web.Request) -> web.Response:
|
||||||
"""Handle logout request."""
|
"""Handle logout request."""
|
||||||
@ -262,6 +264,7 @@ class CloudRegisterView(HomeAssistantView):
|
|||||||
url = "/api/cloud/register"
|
url = "/api/cloud/register"
|
||||||
name = "api:cloud:register"
|
name = "api:cloud:register"
|
||||||
|
|
||||||
|
@require_admin
|
||||||
@_handle_cloud_errors
|
@_handle_cloud_errors
|
||||||
@RequestDataValidator(
|
@RequestDataValidator(
|
||||||
vol.Schema(
|
vol.Schema(
|
||||||
@ -305,6 +308,7 @@ class CloudResendConfirmView(HomeAssistantView):
|
|||||||
url = "/api/cloud/resend_confirm"
|
url = "/api/cloud/resend_confirm"
|
||||||
name = "api:cloud:resend_confirm"
|
name = "api:cloud:resend_confirm"
|
||||||
|
|
||||||
|
@require_admin
|
||||||
@_handle_cloud_errors
|
@_handle_cloud_errors
|
||||||
@RequestDataValidator(vol.Schema({vol.Required("email"): str}))
|
@RequestDataValidator(vol.Schema({vol.Required("email"): str}))
|
||||||
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
||||||
@ -324,6 +328,7 @@ class CloudForgotPasswordView(HomeAssistantView):
|
|||||||
url = "/api/cloud/forgot_password"
|
url = "/api/cloud/forgot_password"
|
||||||
name = "api:cloud:forgot_password"
|
name = "api:cloud:forgot_password"
|
||||||
|
|
||||||
|
@require_admin
|
||||||
@_handle_cloud_errors
|
@_handle_cloud_errors
|
||||||
@RequestDataValidator(vol.Schema({vol.Required("email"): str}))
|
@RequestDataValidator(vol.Schema({vol.Required("email"): str}))
|
||||||
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
||||||
|
@ -52,6 +52,7 @@ from .const import ( # noqa: F401
|
|||||||
KEY_HASS_USER,
|
KEY_HASS_USER,
|
||||||
)
|
)
|
||||||
from .cors import setup_cors
|
from .cors import setup_cors
|
||||||
|
from .decorators import require_admin # noqa: F401
|
||||||
from .forwarded import async_setup_forwarded
|
from .forwarded import async_setup_forwarded
|
||||||
from .headers import setup_headers
|
from .headers import setup_headers
|
||||||
from .request_context import current_request, setup_request_context
|
from .request_context import current_request, setup_request_context
|
||||||
|
31
homeassistant/components/http/decorators.py
Normal file
31
homeassistant/components/http/decorators.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
"""Decorators for the Home Assistant API."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Concatenate, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
from aiohttp.web import Request, Response
|
||||||
|
|
||||||
|
from homeassistant.exceptions import Unauthorized
|
||||||
|
|
||||||
|
from .view import HomeAssistantView
|
||||||
|
|
||||||
|
_HomeAssistantViewT = TypeVar("_HomeAssistantViewT", bound=HomeAssistantView)
|
||||||
|
_P = ParamSpec("_P")
|
||||||
|
|
||||||
|
|
||||||
|
def require_admin(
|
||||||
|
func: Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]
|
||||||
|
) -> Callable[Concatenate[_HomeAssistantViewT, Request, _P], Awaitable[Response]]:
|
||||||
|
"""Home Assistant API decorator to require user to be an admin."""
|
||||||
|
|
||||||
|
async def with_admin(
|
||||||
|
self: _HomeAssistantViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
|
||||||
|
) -> Response:
|
||||||
|
"""Check admin and call function."""
|
||||||
|
if not request["hass_user"].is_admin:
|
||||||
|
raise Unauthorized()
|
||||||
|
|
||||||
|
return await func(self, request, *args, **kwargs)
|
||||||
|
|
||||||
|
return with_admin
|
@ -1,6 +1,7 @@
|
|||||||
"""Tests for the HTTP API for the cloud component."""
|
"""Tests for the HTTP API for the cloud component."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -24,7 +25,7 @@ from . import mock_cloud, mock_cloud_prefs
|
|||||||
|
|
||||||
from tests.components.google_assistant import MockConfig
|
from tests.components.google_assistant import MockConfig
|
||||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||||
|
|
||||||
SUBSCRIPTION_INFO_URL = "https://api-test.hass.io/payments/subscription_info"
|
SUBSCRIPTION_INFO_URL = "https://api-test.hass.io/payments/subscription_info"
|
||||||
|
|
||||||
@ -1207,3 +1208,28 @@ async def test_tts_info(
|
|||||||
|
|
||||||
assert response["success"]
|
assert response["success"]
|
||||||
assert response["result"] == {"languages": [["en-US", "male"], ["en-US", "female"]]}
|
assert response["result"] == {"languages": [["en-US", "male"], ["en-US", "female"]]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("endpoint", "data"),
|
||||||
|
[
|
||||||
|
("/api/cloud/forgot_password", {"email": "fake@example.com"}),
|
||||||
|
("/api/cloud/google_actions/sync", None),
|
||||||
|
("/api/cloud/login", {"email": "fake@example.com", "password": "secret"}),
|
||||||
|
("/api/cloud/logout", None),
|
||||||
|
("/api/cloud/register", {"email": "fake@example.com", "password": "secret"}),
|
||||||
|
("/api/cloud/resend_confirm", {"email": "fake@example.com"}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_api_calls_require_admin(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
hass_read_only_access_token: str,
|
||||||
|
endpoint: str,
|
||||||
|
data: dict[str, Any] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Test cloud APIs endpoints do not work as a normal user."""
|
||||||
|
client = await hass_client(hass_read_only_access_token)
|
||||||
|
resp = await client.post(endpoint, json=data)
|
||||||
|
|
||||||
|
assert resp.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
Loading…
x
Reference in New Issue
Block a user