diff --git a/homeassistant/components/api/__init__.py b/homeassistant/components/api/__init__.py index a9b7fc08273..8a5e1f0b0e0 100644 --- a/homeassistant/components/api/__init__.py +++ b/homeassistant/components/api/__init__.py @@ -123,7 +123,7 @@ class APIEventStream(HomeAssistantView): name = "api:stream" @require_admin - async def get(self, request): + async def get(self, request: web.Request) -> web.StreamResponse: """Provide a streaming interface for the event bus.""" hass: HomeAssistant = request.app["hass"] stop_obj = object() @@ -464,7 +464,7 @@ class APIErrorLog(HomeAssistantView): name = "api:error_log" @require_admin - async def get(self, request): + async def get(self, request: web.Request) -> web.FileResponse: """Retrieve API error log.""" hass: HomeAssistant = request.app["hass"] return web.FileResponse(hass.data[DATA_LOGGING]) diff --git a/homeassistant/components/http/decorators.py b/homeassistant/components/http/decorators.py index 4d8ac5c2df5..b2e8e535fd2 100644 --- a/homeassistant/components/http/decorators.py +++ b/homeassistant/components/http/decorators.py @@ -5,16 +5,18 @@ from collections.abc import Callable, Coroutine from functools import wraps from typing import Any, Concatenate, ParamSpec, TypeVar, overload -from aiohttp.web import Request, Response +from aiohttp.web import Request, Response, StreamResponse +from homeassistant.auth.models import User from homeassistant.exceptions import Unauthorized from .view import HomeAssistantView _HomeAssistantViewT = TypeVar("_HomeAssistantViewT", bound=HomeAssistantView) +_ResponseT = TypeVar("_ResponseT", bound=Response | StreamResponse) _P = ParamSpec("_P") _FuncType = Callable[ - Concatenate[_HomeAssistantViewT, Request, _P], Coroutine[Any, Any, Response] + Concatenate[_HomeAssistantViewT, Request, _P], Coroutine[Any, Any, _ResponseT] ] @@ -23,30 +25,36 @@ def require_admin( _func: None = None, *, error: Unauthorized | None = None, -) -> Callable[[_FuncType[_HomeAssistantViewT, _P]], _FuncType[_HomeAssistantViewT, _P]]: +) -> Callable[ + [_FuncType[_HomeAssistantViewT, _P, _ResponseT]], + _FuncType[_HomeAssistantViewT, _P, _ResponseT], +]: ... @overload def require_admin( - _func: _FuncType[_HomeAssistantViewT, _P], -) -> _FuncType[_HomeAssistantViewT, _P]: + _func: _FuncType[_HomeAssistantViewT, _P, _ResponseT], +) -> _FuncType[_HomeAssistantViewT, _P, _ResponseT]: ... def require_admin( - _func: _FuncType[_HomeAssistantViewT, _P] | None = None, + _func: _FuncType[_HomeAssistantViewT, _P, _ResponseT] | None = None, *, error: Unauthorized | None = None, ) -> ( - Callable[[_FuncType[_HomeAssistantViewT, _P]], _FuncType[_HomeAssistantViewT, _P]] - | _FuncType[_HomeAssistantViewT, _P] + Callable[ + [_FuncType[_HomeAssistantViewT, _P, _ResponseT]], + _FuncType[_HomeAssistantViewT, _P, _ResponseT], + ] + | _FuncType[_HomeAssistantViewT, _P, _ResponseT] ): """Home Assistant API decorator to require user to be an admin.""" def decorator_require_admin( - func: _FuncType[_HomeAssistantViewT, _P], - ) -> _FuncType[_HomeAssistantViewT, _P]: + func: _FuncType[_HomeAssistantViewT, _P, _ResponseT], + ) -> _FuncType[_HomeAssistantViewT, _P, _ResponseT]: """Wrap the provided with_admin function.""" @wraps(func) @@ -55,9 +63,10 @@ def require_admin( request: Request, *args: _P.args, **kwargs: _P.kwargs, - ) -> Response: + ) -> _ResponseT: """Check admin and call function.""" - if not request["hass_user"].is_admin: + user: User = request["hass_user"] + if not user.is_admin: raise error or Unauthorized() return await func(self, request, *args, **kwargs)