Adjust require_admin decorator typing (#108306)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Marc Mueller 2024-01-19 01:12:14 +01:00 committed by GitHub
parent a21d5b5858
commit 25b7bb4a4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 14 deletions

View File

@ -123,7 +123,7 @@ class APIEventStream(HomeAssistantView):
name = "api:stream"
@require_admin
async def get(self, request):
async def get(self, request: web.Request) -> web.StreamResponse:
"""Provide a streaming interface for the event bus."""
hass: HomeAssistant = request.app["hass"]
stop_obj = object()
@ -464,7 +464,7 @@ class APIErrorLog(HomeAssistantView):
name = "api:error_log"
@require_admin
async def get(self, request):
async def get(self, request: web.Request) -> web.FileResponse:
"""Retrieve API error log."""
hass: HomeAssistant = request.app["hass"]
return web.FileResponse(hass.data[DATA_LOGGING])

View File

@ -5,16 +5,18 @@ from collections.abc import Callable, Coroutine
from functools import wraps
from typing import Any, Concatenate, ParamSpec, TypeVar, overload
from aiohttp.web import Request, Response
from aiohttp.web import Request, Response, StreamResponse
from homeassistant.auth.models import User
from homeassistant.exceptions import Unauthorized
from .view import HomeAssistantView
_HomeAssistantViewT = TypeVar("_HomeAssistantViewT", bound=HomeAssistantView)
_ResponseT = TypeVar("_ResponseT", bound=Response | StreamResponse)
_P = ParamSpec("_P")
_FuncType = Callable[
Concatenate[_HomeAssistantViewT, Request, _P], Coroutine[Any, Any, Response]
Concatenate[_HomeAssistantViewT, Request, _P], Coroutine[Any, Any, _ResponseT]
]
@ -23,30 +25,36 @@ def require_admin(
_func: None = None,
*,
error: Unauthorized | None = None,
) -> Callable[[_FuncType[_HomeAssistantViewT, _P]], _FuncType[_HomeAssistantViewT, _P]]:
) -> Callable[
[_FuncType[_HomeAssistantViewT, _P, _ResponseT]],
_FuncType[_HomeAssistantViewT, _P, _ResponseT],
]:
...
@overload
def require_admin(
_func: _FuncType[_HomeAssistantViewT, _P],
) -> _FuncType[_HomeAssistantViewT, _P]:
_func: _FuncType[_HomeAssistantViewT, _P, _ResponseT],
) -> _FuncType[_HomeAssistantViewT, _P, _ResponseT]:
...
def require_admin(
_func: _FuncType[_HomeAssistantViewT, _P] | None = None,
_func: _FuncType[_HomeAssistantViewT, _P, _ResponseT] | None = None,
*,
error: Unauthorized | None = None,
) -> (
Callable[[_FuncType[_HomeAssistantViewT, _P]], _FuncType[_HomeAssistantViewT, _P]]
| _FuncType[_HomeAssistantViewT, _P]
Callable[
[_FuncType[_HomeAssistantViewT, _P, _ResponseT]],
_FuncType[_HomeAssistantViewT, _P, _ResponseT],
]
| _FuncType[_HomeAssistantViewT, _P, _ResponseT]
):
"""Home Assistant API decorator to require user to be an admin."""
def decorator_require_admin(
func: _FuncType[_HomeAssistantViewT, _P],
) -> _FuncType[_HomeAssistantViewT, _P]:
func: _FuncType[_HomeAssistantViewT, _P, _ResponseT],
) -> _FuncType[_HomeAssistantViewT, _P, _ResponseT]:
"""Wrap the provided with_admin function."""
@wraps(func)
@ -55,9 +63,10 @@ def require_admin(
request: Request,
*args: _P.args,
**kwargs: _P.kwargs,
) -> Response:
) -> _ResponseT:
"""Check admin and call function."""
if not request["hass_user"].is_admin:
user: User = request["hass_user"]
if not user.is_admin:
raise error or Unauthorized()
return await func(self, request, *args, **kwargs)