diff --git a/homeassistant/components/cloud/http_api.py b/homeassistant/components/cloud/http_api.py index af966e180eb..f7f327f2f2c 100644 --- a/homeassistant/components/cloud/http_api.py +++ b/homeassistant/components/cloud/http_api.py @@ -6,8 +6,9 @@ import logging import async_timeout import voluptuous as vol -from homeassistant.components.http import ( - HomeAssistantView, RequestDataValidator) +from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http.data_validator import ( + RequestDataValidator) from . import auth_api from .const import DOMAIN, REQUEST_TIMEOUT diff --git a/homeassistant/components/conversation.py b/homeassistant/components/conversation.py index c1dd89d31cd..9f325f3eb89 100644 --- a/homeassistant/components/conversation.py +++ b/homeassistant/components/conversation.py @@ -12,6 +12,8 @@ import voluptuous as vol from homeassistant import core from homeassistant.components import http +from homeassistant.components.http.data_validator import ( + RequestDataValidator) from homeassistant.helpers import config_validation as cv from homeassistant.helpers import intent @@ -148,7 +150,7 @@ class ConversationProcessView(http.HomeAssistantView): url = '/api/conversation/process' name = "api:conversation:process" - @http.RequestDataValidator(vol.Schema({ + @RequestDataValidator(vol.Schema({ vol.Required('text'): str, })) @asyncio.coroutine diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 33f97395945..22f8c90dfb1 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -5,7 +5,6 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/components/http/ """ import asyncio -from functools import wraps from ipaddress import ip_network import json import logging @@ -415,14 +414,13 @@ def request_handler_factory(view, handler): if not request.app['hass'].is_running: return web.Response(status=503) - remote_addr = get_real_ip(request) authenticated = request.get(KEY_AUTHENTICATED, False) if view.requires_auth and not authenticated: raise HTTPUnauthorized() _LOGGER.info('Serving %s to %s (auth: %s)', - request.path, remote_addr, authenticated) + request.path, get_real_ip(request), authenticated) result = handler(request, **request.match_info) @@ -449,41 +447,3 @@ def request_handler_factory(view, handler): return web.Response(body=result, status=status_code) return handle - - -class RequestDataValidator: - """Decorator that will validate the incoming data. - - Takes in a voluptuous schema and adds 'post_data' as - keyword argument to the function call. - - Will return a 400 if no JSON provided or doesn't match schema. - """ - - def __init__(self, schema): - """Initialize the decorator.""" - self._schema = schema - - def __call__(self, method): - """Decorate a function.""" - @asyncio.coroutine - @wraps(method) - def wrapper(view, request, *args, **kwargs): - """Wrap a request handler with data validation.""" - try: - data = yield from request.json() - except ValueError: - _LOGGER.error('Invalid JSON received.') - return view.json_message('Invalid JSON.', 400) - - try: - kwargs['data'] = self._schema(data) - except vol.Invalid as err: - _LOGGER.error('Data does not match schema: %s', err) - return view.json_message( - 'Message format incorrect: {}'.format(err), 400) - - result = yield from method(view, request, *args, **kwargs) - return result - - return wrapper diff --git a/homeassistant/components/http/data_validator.py b/homeassistant/components/http/data_validator.py new file mode 100644 index 00000000000..528c0a598e3 --- /dev/null +++ b/homeassistant/components/http/data_validator.py @@ -0,0 +1,51 @@ +"""Decorator for view methods to help with data validation.""" +import asyncio +from functools import wraps +import logging + +import voluptuous as vol + +_LOGGER = logging.getLogger(__name__) + + +class RequestDataValidator: + """Decorator that will validate the incoming data. + + Takes in a voluptuous schema and adds 'post_data' as + keyword argument to the function call. + + Will return a 400 if no JSON provided or doesn't match schema. + """ + + def __init__(self, schema, allow_empty=False): + """Initialize the decorator.""" + self._schema = schema + self._allow_empty = allow_empty + + def __call__(self, method): + """Decorate a function.""" + @asyncio.coroutine + @wraps(method) + def wrapper(view, request, *args, **kwargs): + """Wrap a request handler with data validation.""" + data = None + try: + data = yield from request.json() + except ValueError: + if not self._allow_empty or \ + (yield from request.content.read()) != b'': + _LOGGER.error('Invalid JSON received.') + return view.json_message('Invalid JSON.', 400) + data = {} + + try: + kwargs['data'] = self._schema(data) + except vol.Invalid as err: + _LOGGER.error('Data does not match schema: %s', err) + return view.json_message( + 'Message format incorrect: {}'.format(err), 400) + + result = yield from method(view, request, *args, **kwargs) + return result + + return wrapper diff --git a/homeassistant/components/http/util.py b/homeassistant/components/http/util.py index 1a5a3d98a22..359c20f4fa1 100644 --- a/homeassistant/components/http/util.py +++ b/homeassistant/components/http/util.py @@ -10,7 +10,7 @@ def get_real_ip(request): if KEY_REAL_IP in request: return request[KEY_REAL_IP] - if (request.app[KEY_USE_X_FORWARDED_FOR] and + if (request.app.get(KEY_USE_X_FORWARDED_FOR) and HTTP_HEADER_X_FORWARDED_FOR in request.headers): request[KEY_REAL_IP] = ip_address( request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0]) diff --git a/homeassistant/components/shopping_list.py b/homeassistant/components/shopping_list.py index 31259325c04..416fdd3f6d0 100644 --- a/homeassistant/components/shopping_list.py +++ b/homeassistant/components/shopping_list.py @@ -10,6 +10,8 @@ import voluptuous as vol from homeassistant.const import HTTP_NOT_FOUND, HTTP_BAD_REQUEST from homeassistant.core import callback from homeassistant.components import http +from homeassistant.components.http.data_validator import ( + RequestDataValidator) from homeassistant.helpers import intent import homeassistant.helpers.config_validation as cv @@ -199,7 +201,7 @@ class CreateShoppingListItemView(http.HomeAssistantView): url = '/api/shopping_list/item' name = "api:shopping_list:item" - @http.RequestDataValidator(vol.Schema({ + @RequestDataValidator(vol.Schema({ vol.Required('name'): str, })) @asyncio.coroutine diff --git a/tests/components/http/test_data_validator.py b/tests/components/http/test_data_validator.py new file mode 100644 index 00000000000..f00be4fc6f9 --- /dev/null +++ b/tests/components/http/test_data_validator.py @@ -0,0 +1,77 @@ +"""Test data validator decorator.""" +import asyncio +from unittest.mock import Mock + +from aiohttp import web +import voluptuous as vol + +from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http.data_validator import RequestDataValidator + + +@asyncio.coroutine +def get_client(test_client, validator): + """Generate a client that hits a view decorated with validator.""" + app = web.Application() + app['hass'] = Mock(is_running=True) + + class TestView(HomeAssistantView): + url = '/' + name = 'test' + requires_auth = False + + @asyncio.coroutine + @validator + def post(self, request, data): + """Test method.""" + return b'' + + TestView().register(app.router) + client = yield from test_client(app) + return client + + +@asyncio.coroutine +def test_validator(test_client): + """Test the validator.""" + client = yield from get_client( + test_client, RequestDataValidator(vol.Schema({ + vol.Required('test'): str + }))) + + resp = yield from client.post('/', json={ + 'test': 'bla' + }) + assert resp.status == 200 + + resp = yield from client.post('/', json={ + 'test': 100 + }) + assert resp.status == 400 + + resp = yield from client.post('/') + assert resp.status == 400 + + +@asyncio.coroutine +def test_validator_allow_empty(test_client): + """Test the validator with empty data.""" + client = yield from get_client( + test_client, RequestDataValidator(vol.Schema({ + # Although we allow empty, our schema should still be able + # to validate an empty dict. + vol.Optional('test'): str + }), allow_empty=True)) + + resp = yield from client.post('/', json={ + 'test': 'bla' + }) + assert resp.status == 200 + + resp = yield from client.post('/', json={ + 'test': 100 + }) + assert resp.status == 400 + + resp = yield from client.post('/') + assert resp.status == 200