From 5aa0158761408bc0dd941eab7d688189f4f009af Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 14 May 2016 21:18:46 -0700 Subject: [PATCH] Add url validators --- homeassistant/components/api.py | 2 +- homeassistant/components/camera/__init__.py | 4 +- homeassistant/components/history.py | 4 +- homeassistant/components/http.py | 111 +++++++++++++++----- homeassistant/components/logbook.py | 2 +- 5 files changed, 92 insertions(+), 31 deletions(-) diff --git a/homeassistant/components/api.py b/homeassistant/components/api.py index 6e1f3c0fe18..351bdf5dcd2 100644 --- a/homeassistant/components/api.py +++ b/homeassistant/components/api.py @@ -214,7 +214,7 @@ class APIStatesView(HomeAssistantView): class APIEntityStateView(HomeAssistantView): """View to handle EntityState requests.""" - url = "/api/states/" + url = "/api/states/" name = "api:entity-state" def get(self, request, entity_id): diff --git a/homeassistant/components/camera/__init__.py b/homeassistant/components/camera/__init__.py index 1806b5e66c6..1a6fa2cb956 100644 --- a/homeassistant/components/camera/__init__.py +++ b/homeassistant/components/camera/__init__.py @@ -147,7 +147,7 @@ class CameraView(HomeAssistantView): class CameraImageView(CameraView): """Camera view to serve an image.""" - url = "/api/camera_proxy/" + url = "/api/camera_proxy/" name = "api:camera:image" def get(self, request, entity_id): @@ -168,7 +168,7 @@ class CameraImageView(CameraView): class CameraMjpegStream(CameraView): """Camera View to serve an MJPEG stream.""" - url = "/api/camera_proxy_stream/" + url = "/api/camera_proxy_stream/" name = "api:camera:stream" def get(self, request, entity_id): diff --git a/homeassistant/components/history.py b/homeassistant/components/history.py index 4e1348e1fa9..7ede33f9e15 100644 --- a/homeassistant/components/history.py +++ b/homeassistant/components/history.py @@ -165,7 +165,7 @@ def setup(hass, config): class Last5StatesView(HomeAssistantView): """Handle last 5 state view requests.""" - url = '/api/history/entity//recent_states' + url = '/api/history/entity//recent_states' name = 'api:history:entity-recent-states' def get(self, request, entity_id): @@ -178,7 +178,7 @@ class HistoryPeriodView(HomeAssistantView): url = '/api/history/period' name = 'api:history:entity-recent-states' - extra_urls = ['/api/history/period/'] + extra_urls = ['/api/history/period/'] def get(self, request, date=None): """Return history over a period of time.""" diff --git a/homeassistant/components/http.py b/homeassistant/components/http.py index eab6d4a33eb..c6cf5839e7a 100644 --- a/homeassistant/components/http.py +++ b/homeassistant/components/http.py @@ -10,6 +10,8 @@ import homeassistant.core as ha import homeassistant.remote as rem from homeassistant import util from homeassistant.const import SERVER_PORT, HTTP_HEADER_HA_AUTH +from homeassistant.helpers.entity import valid_entity_id, split_entity_id +import homeassistant.util.dt as dt_util DOMAIN = "http" REQUIREMENTS = ("eventlet==0.18.4", "static3==0.6.1", "Werkzeug==0.11.5",) @@ -77,6 +79,85 @@ def setup(hass, config): # return app(environ, start_response) +def request_class(): + """Generate request class. + + Done in method because of imports.""" + from werkzeug.exceptions import BadRequest + from werkzeug.wrappers import BaseRequest, AcceptMixin + from werkzeug.utils import cached_property + + class Request(BaseRequest, AcceptMixin): + """Base class for incoming requests.""" + + @cached_property + def json(self): + """Get the result of json.loads if possible.""" + if not self.data: + return None + # elif 'json' not in self.environ.get('CONTENT_TYPE', ''): + # raise BadRequest('Not a JSON request') + try: + return json.loads(self.data.decode( + self.charset, self.encoding_errors)) + except (TypeError, ValueError): + raise BadRequest('Unable to read JSON request') + + return Request + + +def routing_map(hass): + """Generate empty routing map with HA validators.""" + from werkzeug.routing import Map, BaseConverter, ValidationError + + class EntityValidator(BaseConverter): + """Validate entity_id in urls.""" + regex = r"(\w+)\.(\w+)" + + def __init__(self, url_map, exist=True, domain=None): + """Initilalize entity validator.""" + super().__init__(url_map) + self._exist = exist + self._domain = domain + + def to_python(self, value): + """Validate entity id.""" + if self._exist and hass.states.get(value) is None: + raise ValidationError() + if self._domain is not None and \ + split_entity_id(value)[0] != self._domain: + raise ValidationError() + + return value + + def to_url(self, value): + """Convert entity_id for a url.""" + return value + + class DateValidator(BaseConverter): + """Validate dates in urls.""" + + regex = r'\d{4}-(0[1-9])|(1[012])-((0[1-9])|([12]\d)|(3[01]))' + + def to_python(self, value): + """Validate and convert date.""" + parsed = dt_util.parse_date(value) + + if value is None: + raise ValidationError() + + return parsed + + def to_url(self, value): + """Convert date to url value.""" + return value.isoformat() + + return Map(converters={ + 'entity': EntityValidator, + 'date': DateValidator, + }) + + class HomeAssistantWSGI(object): """WSGI server for Home Assistant.""" @@ -86,33 +167,13 @@ class HomeAssistantWSGI(object): def __init__(self, hass, development, api_password, ssl_certificate, ssl_key, server_host, server_port): """Initilalize the WSGI Home Assistant server.""" - from werkzeug.exceptions import BadRequest - from werkzeug.wrappers import BaseRequest, AcceptMixin - from werkzeug.routing import Map - from werkzeug.utils import cached_property from werkzeug.wrappers import Response - class Request(BaseRequest, AcceptMixin): - """Base class for incoming requests.""" - - @cached_property - def json(self): - """Get the result of json.loads if possible.""" - if not self.data: - return None - # elif 'json' not in self.environ.get('CONTENT_TYPE', ''): - # raise BadRequest('Not a JSON request') - try: - return json.loads(self.data.decode( - self.charset, self.encoding_errors)) - except (TypeError, ValueError): - raise BadRequest('Unable to read JSON request') - Response.mimetype = 'text/html' # pylint: disable=invalid-name - self.Request = Request - self.url_map = Map() + self.Request = request_class() + self.url_map = routing_map(hass) self.views = {} self.hass = hass self.extra_apps = {} @@ -340,13 +401,13 @@ class HomeAssistantView(object): from werkzeug.exceptions import NotFound if isinstance(fil, str): + if mimetype is None: + mimetype = mimetypes.guess_type(fil)[0] + try: fil = open(fil) except IOError: raise NotFound() - if mimetype is None: - mimetype = mimetypes.guess_type(fil)[0] - return self.Response(wrap_file(request.environ, fil), mimetype=mimetype, direct_passthrough=True) diff --git a/homeassistant/components/logbook.py b/homeassistant/components/logbook.py index 629fb236b3c..6bf5c8207fe 100644 --- a/homeassistant/components/logbook.py +++ b/homeassistant/components/logbook.py @@ -89,7 +89,7 @@ class LogbookView(HomeAssistantView): url = '/api/logbook' name = 'api:logbook' - extra_urls = ['/api/logbook/'] + extra_urls = ['/api/logbook/'] def get(self, request, date=None): """Retrieve logbook entries."""