Files
supervisor/supervisor/api/utils.py
dependabot[bot] dafc2cfec2 Bump pyupgrade from 2.26.0 to 2.26.0.post1 (#3131)
* Bump pyupgrade from 2.26.0 to 2.26.0.post1

Bumps [pyupgrade](https://github.com/asottile/pyupgrade) from 2.26.0 to 2.26.0.post1.
- [Release notes](https://github.com/asottile/pyupgrade/releases)
- [Commits](https://github.com/asottile/pyupgrade/commits)

---
updated-dependencies:
- dependency-name: pyupgrade
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update .pre-commit-config.yaml

* Fixes

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Pascal Vizeli <pvizeli@syshack.ch>
2021-09-21 18:22:56 +02:00

165 lines
4.8 KiB
Python

"""Init file for Supervisor util for RESTful API."""
from functools import wraps
import json
from typing import Any, Optional

from aiohttp import web
from aiohttp.hdrs import AUTHORIZATION
from aiohttp.web_exceptions import HTTPUnauthorized
from aiohttp.web_request import Request
import voluptuous as vol
from voluptuous.humanize import humanize_error

from ..const import (
    CONTENT_TYPE_BINARY,
    HEADER_TOKEN,
    HEADER_TOKEN_OLD,
    JSON_DATA,
    JSON_MESSAGE,
    JSON_RESULT,
    REQUEST_FROM,
    RESULT_ERROR,
    RESULT_OK,
)
from ..coresys import CoreSys
from ..exceptions import APIError, APIForbidden, DockerAPIError, HassioError
from ..utils import check_exception_chain, get_message_from_exception_chain
from ..utils.json import JSONEncoder
from ..utils.log_format import format_message
def excract_supervisor_token(request: web.Request) -> Optional[str]:
    """Return the Supervisor token carried by ``request``, or ``None``.

    Headers are consulted in this order, and the first non-empty value wins:

    1. ``HEADER_TOKEN`` - returned verbatim.
    2. ``HEADER_TOKEN_OLD`` - returned verbatim (old Supervisor fallback).
    3. ``Authorization`` - only the text after the last space is returned,
       so ``"Bearer <token>"`` yields ``"<token>"``.
    """
    headers = request.headers

    # Remove HEADER_TOKEN_OLD together with the old Supervisor fallback.
    for header_name in (HEADER_TOKEN, HEADER_TOKEN_OLD):
        token = headers.get(header_name)
        if token:
            return token

    # API access only
    auth_value = headers.get(AUTHORIZATION)
    if not auth_value:
        return None

    # Last space-separated word of the header; the whole value if it has no space.
    return auth_value.rpartition(" ")[2]
def json_loads(data: Any) -> dict[str, Any]:
    """Decode a JSON request body, treating an empty body as ``{}``.

    Any falsy ``data`` (for example ``''``, ``b''`` or ``None``) short-circuits
    to a new empty dict. Anything else is handed to ``json.loads``. A
    ``json.JSONDecodeError`` is re-raised as ``APIError("Invalid json")`` with
    the original error chained as its cause.
    """
    if data:
        try:
            return json.loads(data)
        except json.JSONDecodeError as err:
            raise APIError("Invalid json") from err
    return {}
def api_process(method):
    """Wrap an API handler so that its result becomes a JSON API response.

    The awaited return value of ``method`` is mapped as follows:

    * ``dict`` or ``list`` -> ``api_return_ok(data=answer)``
    * ``web.Response`` -> returned unchanged
    * ``False`` (the bool) -> ``api_return_error()``
    * anything else (e.g. ``None`` or ``True``) -> ``api_return_ok()``

    ``APIError``, ``APIForbidden`` and ``HassioError`` raised by the handler
    are converted into an error response via ``api_return_error(error=err)``.

    ``functools.wraps`` keeps the handler's ``__name__``/``__qualname__``/
    ``__doc__`` (and sets ``__wrapped__``) so that logs, tracebacks and
    stacked decorators still identify the real handler.
    """

    @wraps(method)
    async def wrap_api(api, *args, **kwargs):
        """Return API information."""
        try:
            answer = await method(api, *args, **kwargs)
        except (APIError, APIForbidden, HassioError) as err:
            return api_return_error(error=err)

        if isinstance(answer, (dict, list)):
            return api_return_ok(data=answer)
        if isinstance(answer, web.Response):
            return answer
        # Only an explicit boolean False signals failure; None/True mean success.
        if isinstance(answer, bool) and not answer:
            return api_return_error()
        return api_return_ok()

    return wrap_api
def require_home_assistant(method):
    """Ensure that the request comes from Home Assistant.

    Intended for handler methods whose first positional argument after
    ``self`` is the aiohttp ``Request``. The handler runs only when
    ``request[REQUEST_FROM]`` equals ``api.coresys.homeassistant``. Otherwise
    ``HTTPUnauthorized`` is raised, including when ``REQUEST_FROM`` was never
    set on the request.
    """

    @wraps(method)
    async def wrap_api(api, *args, **kwargs):
        """Return API information."""
        coresys: CoreSys = api.coresys
        request: Request = args[0]
        # Use .get() so a request without REQUEST_FROM is refused with 401
        # instead of escaping as an unhandled KeyError. The explicit None
        # check keeps this fail-closed regardless of the compared value.
        request_from = request.get(REQUEST_FROM)
        if request_from is None or request_from != coresys.homeassistant:
            raise HTTPUnauthorized()
        return await method(api, *args, **kwargs)

    return wrap_api
def api_process_raw(content):
    """Build a decorator that sends a handler's result as a raw response body.

    Args:
        content: Content type used for the response when the handler succeeds.

    On success, the handler's return value is used as the body with type
    ``content``. Errors are mapped as follows:

    * ``APIError`` / ``APIForbidden`` -> body ``str(err).encode()``, type
      ``CONTENT_TYPE_BINARY``
    * ``HassioError`` -> empty body, type ``CONTENT_TYPE_BINARY``

    No explicit ``status`` is passed to ``web.Response``, so error bodies use
    its default status as well.
    """

    def wrap_method(method):
        """Wrap function with raw output to rest api."""

        @wraps(method)
        async def wrap_api(api, *args, **kwargs):
            """Return api information."""
            try:
                msg_data = await method(api, *args, **kwargs)
            except (APIError, APIForbidden) as err:
                msg_data = str(err).encode()
                msg_type = CONTENT_TYPE_BINARY
            except HassioError:
                msg_data = b""
                msg_type = CONTENT_TYPE_BINARY
            else:
                msg_type = content
            return web.Response(body=msg_data, content_type=msg_type)

        return wrap_api

    return wrap_method
def api_return_error(
    error: Optional[Exception] = None, message: Optional[str] = None
) -> web.Response:
    """Build an HTTP 400 JSON response describing a failed API call.

    An explicit, non-empty ``message`` is used as-is. Otherwise, when
    ``error`` is given, the message is derived from its exception chain.
    If that chain contains a ``DockerAPIError``, the derived text is also
    passed through ``format_message``. If no message results at all, a
    generic fallback text is sent.
    """
    text = message
    if not text and error:
        text = get_message_from_exception_chain(error)
        if check_exception_chain(error, DockerAPIError):
            text = format_message(text)

    payload = {
        JSON_RESULT: RESULT_ERROR,
        JSON_MESSAGE: text or "Unknown error, see supervisor",
    }
    return web.json_response(
        payload,
        status=400,
        dumps=lambda obj: json.dumps(obj, cls=JSONEncoder),
    )
def api_return_ok(data: Optional[dict[str, Any]] = None) -> web.Response:
    """Return an API ok answer.

    Args:
        data: Payload placed under ``JSON_DATA``. Only ``None`` is replaced
            by ``{}``. Any other value, including an empty list (which
            ``api_process`` forwards for list answers), is sent unchanged so
            that ``[]`` is not turned into ``{}``.

    Returns:
        A JSON response whose ``JSON_RESULT`` is ``RESULT_OK``, serialized with
        ``JSONEncoder``.
    """
    return web.json_response(
        {JSON_RESULT: RESULT_OK, JSON_DATA: {} if data is None else data},
        dumps=lambda x: json.dumps(x, cls=JSONEncoder),
    )
async def api_validate(
    schema: vol.Schema, request: web.Request, origin: Optional[list[str]] = None
) -> dict[str, Any]:
    """Parse the JSON body of ``request`` and validate it with ``schema``.

    The body is decoded with ``json_loads``, so an empty body validates as
    ``{}``. A ``vol.Invalid`` is turned into ``APIError`` carrying the
    humanized message, with exception context suppressed.

    Args:
        schema: Callable voluptuous schema applied to the decoded body.
        request: Incoming aiohttp request.
        origin: Keys whose validated value is replaced by the value exactly as
            it appeared in the raw body, when the key is present in the
            validated result.

    Returns:
        The validated data, with ``origin`` keys restored from the raw body.
    """
    raw: dict[str, Any] = await request.json(loads=json_loads)

    try:
        validated = schema(raw)
    except vol.Invalid as ex:
        raise APIError(humanize_error(raw, ex)) from None

    for key in origin or ():
        if key in validated:
            # Keep the client's original value rather than what the schema produced.
            validated[key] = raw[key]

    return validated