From 9e95e8671edea1d98ed3964c79f0f699cbd77cfa Mon Sep 17 00:00:00 2001 From: pvizeli Date: Fri, 7 Apr 2017 11:07:23 +0200 Subject: [PATCH] Add options validation --- hassio/api/homeassistant.py | 10 ++++++++-- hassio/api/host.py | 10 ++++++++-- hassio/api/supervisor.py | 16 +++++++++++++--- hassio/api/util.py | 13 +++++++++++++ setup.py | 1 + 5 files changed, 43 insertions(+), 7 deletions(-) diff --git a/hassio/api/homeassistant.py b/hassio/api/homeassistant.py index f4ad3841f..cd55a2453 100644 --- a/hassio/api/homeassistant.py +++ b/hassio/api/homeassistant.py @@ -2,11 +2,17 @@ import asyncio import logging -from .util import api_process, json_loads +import voluptuous as vol + +from .util import api_process, api_validate from ..const import ATTR_VERSION, ATTR_CURRENT _LOGGER = logging.getLogger(__name__) +SCHEMA_VERSION = vol.Schema({ + vol.Optional(ATTR_VERSION): vol.Coerce(str), +}) + class APIHomeAssistant(object): """Handle rest api for homeassistant functions.""" @@ -30,7 +36,7 @@ class APIHomeAssistant(object): @api_process async def update(self, request): """Update host OS.""" - body = await request.json(loads=json_loads) + body = await api_validate(SCHEMA_VERSION, request) version = body.get(ATTR_VERSION, self.config.current_homeassistant) if self.dock_hass.in_progress: diff --git a/hassio/api/host.py b/hassio/api/host.py index 8dc3a2217..23f0fb453 100644 --- a/hassio/api/host.py +++ b/hassio/api/host.py @@ -1,13 +1,19 @@ """Init file for HassIO host rest api.""" import logging -from .util import api_process_hostcontroll, api_process, json_loads +import voluptuous as vol + +from .util import api_process_hostcontroll, api_process, api_validate from ..const import ATTR_VERSION _LOGGER = logging.getLogger(__name__) UNKNOWN = 'unknown' +SCHEMA_VERSION = vol.Schema({ + vol.Optional(ATTR_VERSION): vol.Coerce(str), +}) + class APIHost(object): """Handle rest api for host functions.""" @@ -46,7 +52,7 @@ class APIHost(object): @api_process_hostcontroll async def update(self, request): """Update host OS.""" - body = await request.json(loads=json_loads) + body = await api_validate(SCHEMA_VERSION, request) version = body.get(ATTR_VERSION) if version == self.host_controll.version: diff --git a/hassio/api/supervisor.py b/hassio/api/supervisor.py index 7c57d12ed..94de47920 100644 --- a/hassio/api/supervisor.py +++ b/hassio/api/supervisor.py @@ -1,11 +1,21 @@ """Init file for HassIO supervisor rest api.""" import logging -from .util import api_process, api_process_hostcontroll, json_loads +import voluptuous as vol + +from .util import api_process, api_process_hostcontroll, api_validate from ..const import ATTR_VERSION, ATTR_CURRENT, ATTR_BETA, HASSIO_VERSION _LOGGER = logging.getLogger(__name__) +SCHEMA_OPTIONS = vol.Schema({ + vol.Optional(ATTR_BETA): vol.Boolean(), +}) + +SCHEMA_VERSION = vol.Schema({ + vol.Optional(ATTR_VERSION): vol.Coerce(str), +}) + class APISupervisor(object): """Handle rest api for supervisor functions.""" @@ -35,7 +45,7 @@ class APISupervisor(object): @api_process async def options(self, request): """Set supervisor options.""" - body = await request.json(loads=json_loads) + body = await api_validate(SCHEMA_OPTIONS, request) if ATTR_BETA in body: self.config.upstream_beta = body[ATTR_BETA] @@ -45,7 +55,7 @@ class APISupervisor(object): @api_process_hostcontroll async def update(self, request): """Update host OS.""" - body = await request.json(loads=json_loads) + body = await api_validate(SCHEMA_VERSION, request) version = body.get(ATTR_VERSION, self.config.current_hassio) if version == HASSIO_VERSION: diff --git a/hassio/api/util.py b/hassio/api/util.py index 50bf94ef2..fd75d5ec2 100644 --- a/hassio/api/util.py +++ b/hassio/api/util.py @@ -4,6 +4,8 @@ import logging from aiohttp import web from aiohttp.web_exceptions import HTTPServiceUnavailable +import voluptuous as vol +from voluptuous.humanize import humanize_error from ..const import ( JSON_RESULT, JSON_DATA, JSON_MESSAGE, RESULT_OK, RESULT_ERROR) @@ -74,3 +76,14 @@ def api_return_ok(data=None): JSON_RESULT: RESULT_OK, JSON_DATA: data or {}, }) + + +async def api_validate(schema, request): + """Validate request data with schema.""" + data = await request.json(loads=json_loads) + try: + schema(data) + except vol.Invalid as ex: + raise RuntimeError(humanize_error(data, ex)) from None + + return data diff --git a/setup.py b/setup.py index 7927bbf33..31a39092c 100644 --- a/setup.py +++ b/setup.py @@ -36,5 +36,6 @@ setup( 'aiohttp', 'docker', 'colorlog', + 'voluptuous', ] )