mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 03:07:37 +00:00
Extract data validator to own file and add tests (#12401)
This commit is contained in:
parent
416f64fc70
commit
78c44180f4
@ -6,8 +6,9 @@ import logging
|
|||||||
import async_timeout
|
import async_timeout
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.http import (
|
from homeassistant.components.http import HomeAssistantView
|
||||||
HomeAssistantView, RequestDataValidator)
|
from homeassistant.components.http.data_validator import (
|
||||||
|
RequestDataValidator)
|
||||||
|
|
||||||
from . import auth_api
|
from . import auth_api
|
||||||
from .const import DOMAIN, REQUEST_TIMEOUT
|
from .const import DOMAIN, REQUEST_TIMEOUT
|
||||||
|
@ -12,6 +12,8 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant import core
|
from homeassistant import core
|
||||||
from homeassistant.components import http
|
from homeassistant.components import http
|
||||||
|
from homeassistant.components.http.data_validator import (
|
||||||
|
RequestDataValidator)
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import intent
|
||||||
|
|
||||||
@ -148,7 +150,7 @@ class ConversationProcessView(http.HomeAssistantView):
|
|||||||
url = '/api/conversation/process'
|
url = '/api/conversation/process'
|
||||||
name = "api:conversation:process"
|
name = "api:conversation:process"
|
||||||
|
|
||||||
@http.RequestDataValidator(vol.Schema({
|
@RequestDataValidator(vol.Schema({
|
||||||
vol.Required('text'): str,
|
vol.Required('text'): str,
|
||||||
}))
|
}))
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
@ -5,7 +5,6 @@ For more details about this component, please refer to the documentation at
|
|||||||
https://home-assistant.io/components/http/
|
https://home-assistant.io/components/http/
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
from functools import wraps
|
|
||||||
from ipaddress import ip_network
|
from ipaddress import ip_network
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@ -415,14 +414,13 @@ def request_handler_factory(view, handler):
|
|||||||
if not request.app['hass'].is_running:
|
if not request.app['hass'].is_running:
|
||||||
return web.Response(status=503)
|
return web.Response(status=503)
|
||||||
|
|
||||||
remote_addr = get_real_ip(request)
|
|
||||||
authenticated = request.get(KEY_AUTHENTICATED, False)
|
authenticated = request.get(KEY_AUTHENTICATED, False)
|
||||||
|
|
||||||
if view.requires_auth and not authenticated:
|
if view.requires_auth and not authenticated:
|
||||||
raise HTTPUnauthorized()
|
raise HTTPUnauthorized()
|
||||||
|
|
||||||
_LOGGER.info('Serving %s to %s (auth: %s)',
|
_LOGGER.info('Serving %s to %s (auth: %s)',
|
||||||
request.path, remote_addr, authenticated)
|
request.path, get_real_ip(request), authenticated)
|
||||||
|
|
||||||
result = handler(request, **request.match_info)
|
result = handler(request, **request.match_info)
|
||||||
|
|
||||||
@ -449,41 +447,3 @@ def request_handler_factory(view, handler):
|
|||||||
return web.Response(body=result, status=status_code)
|
return web.Response(body=result, status=status_code)
|
||||||
|
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
|
|
||||||
class RequestDataValidator:
|
|
||||||
"""Decorator that will validate the incoming data.
|
|
||||||
|
|
||||||
Takes in a voluptuous schema and adds 'post_data' as
|
|
||||||
keyword argument to the function call.
|
|
||||||
|
|
||||||
Will return a 400 if no JSON provided or doesn't match schema.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, schema):
|
|
||||||
"""Initialize the decorator."""
|
|
||||||
self._schema = schema
|
|
||||||
|
|
||||||
def __call__(self, method):
|
|
||||||
"""Decorate a function."""
|
|
||||||
@asyncio.coroutine
|
|
||||||
@wraps(method)
|
|
||||||
def wrapper(view, request, *args, **kwargs):
|
|
||||||
"""Wrap a request handler with data validation."""
|
|
||||||
try:
|
|
||||||
data = yield from request.json()
|
|
||||||
except ValueError:
|
|
||||||
_LOGGER.error('Invalid JSON received.')
|
|
||||||
return view.json_message('Invalid JSON.', 400)
|
|
||||||
|
|
||||||
try:
|
|
||||||
kwargs['data'] = self._schema(data)
|
|
||||||
except vol.Invalid as err:
|
|
||||||
_LOGGER.error('Data does not match schema: %s', err)
|
|
||||||
return view.json_message(
|
|
||||||
'Message format incorrect: {}'.format(err), 400)
|
|
||||||
|
|
||||||
result = yield from method(view, request, *args, **kwargs)
|
|
||||||
return result
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
51
homeassistant/components/http/data_validator.py
Normal file
51
homeassistant/components/http/data_validator.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""Decorator for view methods to help with data validation."""
|
||||||
|
import asyncio
|
||||||
|
from functools import wraps
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RequestDataValidator:
|
||||||
|
"""Decorator that will validate the incoming data.
|
||||||
|
|
||||||
|
Takes in a voluptuous schema and adds 'post_data' as
|
||||||
|
keyword argument to the function call.
|
||||||
|
|
||||||
|
Will return a 400 if no JSON provided or doesn't match schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, schema, allow_empty=False):
|
||||||
|
"""Initialize the decorator."""
|
||||||
|
self._schema = schema
|
||||||
|
self._allow_empty = allow_empty
|
||||||
|
|
||||||
|
def __call__(self, method):
|
||||||
|
"""Decorate a function."""
|
||||||
|
@asyncio.coroutine
|
||||||
|
@wraps(method)
|
||||||
|
def wrapper(view, request, *args, **kwargs):
|
||||||
|
"""Wrap a request handler with data validation."""
|
||||||
|
data = None
|
||||||
|
try:
|
||||||
|
data = yield from request.json()
|
||||||
|
except ValueError:
|
||||||
|
if not self._allow_empty or \
|
||||||
|
(yield from request.content.read()) != b'':
|
||||||
|
_LOGGER.error('Invalid JSON received.')
|
||||||
|
return view.json_message('Invalid JSON.', 400)
|
||||||
|
data = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs['data'] = self._schema(data)
|
||||||
|
except vol.Invalid as err:
|
||||||
|
_LOGGER.error('Data does not match schema: %s', err)
|
||||||
|
return view.json_message(
|
||||||
|
'Message format incorrect: {}'.format(err), 400)
|
||||||
|
|
||||||
|
result = yield from method(view, request, *args, **kwargs)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
@ -10,7 +10,7 @@ def get_real_ip(request):
|
|||||||
if KEY_REAL_IP in request:
|
if KEY_REAL_IP in request:
|
||||||
return request[KEY_REAL_IP]
|
return request[KEY_REAL_IP]
|
||||||
|
|
||||||
if (request.app[KEY_USE_X_FORWARDED_FOR] and
|
if (request.app.get(KEY_USE_X_FORWARDED_FOR) and
|
||||||
HTTP_HEADER_X_FORWARDED_FOR in request.headers):
|
HTTP_HEADER_X_FORWARDED_FOR in request.headers):
|
||||||
request[KEY_REAL_IP] = ip_address(
|
request[KEY_REAL_IP] = ip_address(
|
||||||
request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0])
|
request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0])
|
||||||
|
@ -10,6 +10,8 @@ import voluptuous as vol
|
|||||||
from homeassistant.const import HTTP_NOT_FOUND, HTTP_BAD_REQUEST
|
from homeassistant.const import HTTP_NOT_FOUND, HTTP_BAD_REQUEST
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.components import http
|
from homeassistant.components import http
|
||||||
|
from homeassistant.components.http.data_validator import (
|
||||||
|
RequestDataValidator)
|
||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import intent
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
|
||||||
@ -199,7 +201,7 @@ class CreateShoppingListItemView(http.HomeAssistantView):
|
|||||||
url = '/api/shopping_list/item'
|
url = '/api/shopping_list/item'
|
||||||
name = "api:shopping_list:item"
|
name = "api:shopping_list:item"
|
||||||
|
|
||||||
@http.RequestDataValidator(vol.Schema({
|
@RequestDataValidator(vol.Schema({
|
||||||
vol.Required('name'): str,
|
vol.Required('name'): str,
|
||||||
}))
|
}))
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
77
tests/components/http/test_data_validator.py
Normal file
77
tests/components/http/test_data_validator.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
"""Test data validator decorator."""
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components.http import HomeAssistantView
|
||||||
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def get_client(test_client, validator):
|
||||||
|
"""Generate a client that hits a view decorated with validator."""
|
||||||
|
app = web.Application()
|
||||||
|
app['hass'] = Mock(is_running=True)
|
||||||
|
|
||||||
|
class TestView(HomeAssistantView):
|
||||||
|
url = '/'
|
||||||
|
name = 'test'
|
||||||
|
requires_auth = False
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
@validator
|
||||||
|
def post(self, request, data):
|
||||||
|
"""Test method."""
|
||||||
|
return b''
|
||||||
|
|
||||||
|
TestView().register(app.router)
|
||||||
|
client = yield from test_client(app)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_validator(test_client):
|
||||||
|
"""Test the validator."""
|
||||||
|
client = yield from get_client(
|
||||||
|
test_client, RequestDataValidator(vol.Schema({
|
||||||
|
vol.Required('test'): str
|
||||||
|
})))
|
||||||
|
|
||||||
|
resp = yield from client.post('/', json={
|
||||||
|
'test': 'bla'
|
||||||
|
})
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
resp = yield from client.post('/', json={
|
||||||
|
'test': 100
|
||||||
|
})
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
|
resp = yield from client.post('/')
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_validator_allow_empty(test_client):
|
||||||
|
"""Test the validator with empty data."""
|
||||||
|
client = yield from get_client(
|
||||||
|
test_client, RequestDataValidator(vol.Schema({
|
||||||
|
# Although we allow empty, our schema should still be able
|
||||||
|
# to validate an empty dict.
|
||||||
|
vol.Optional('test'): str
|
||||||
|
}), allow_empty=True))
|
||||||
|
|
||||||
|
resp = yield from client.post('/', json={
|
||||||
|
'test': 'bla'
|
||||||
|
})
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
resp = yield from client.post('/', json={
|
||||||
|
'test': 100
|
||||||
|
})
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
|
resp = yield from client.post('/')
|
||||||
|
assert resp.status == 200
|
Loading…
x
Reference in New Issue
Block a user