HTTP more robust and increased test coverage

This commit is contained in:
Paulus Schoutsen 2014-11-28 22:27:44 -08:00
parent 014abdba39
commit a4eb975b59
4 changed files with 232 additions and 88 deletions

View File

@ -108,6 +108,8 @@ CONF_SERVER_HOST = "server_host"
CONF_SERVER_PORT = "server_port" CONF_SERVER_PORT = "server_port"
CONF_DEVELOPMENT = "development" CONF_DEVELOPMENT = "development"
DATA_API_PASSWORD = 'api_password'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -118,7 +120,7 @@ def setup(hass, config):
_LOGGER): _LOGGER):
return False return False
api_password = config[DOMAIN]['api_password'] api_password = config[DOMAIN][CONF_API_PASSWORD]
# If no server host is given, accept all incoming requests # If no server host is given, accept all incoming requests
server_host = config[DOMAIN].get(CONF_SERVER_HOST, '0.0.0.0') server_host = config[DOMAIN].get(CONF_SERVER_HOST, '0.0.0.0')
@ -192,7 +194,6 @@ class RequestHandler(SimpleHTTPRequestHandler):
PATHS = [ # debug interface PATHS = [ # debug interface
('GET', URL_ROOT, '_handle_get_root'), ('GET', URL_ROOT, '_handle_get_root'),
('POST', URL_ROOT, '_handle_get_root'),
# /api - for validation purposes # /api - for validation purposes
('GET', rem.URL_API, '_handle_get_api'), ('GET', rem.URL_API, '_handle_get_api'),
@ -228,8 +229,10 @@ class RequestHandler(SimpleHTTPRequestHandler):
('DELETE', rem.URL_API_EVENT_FORWARD, ('DELETE', rem.URL_API_EVENT_FORWARD,
'_handle_delete_api_event_forward'), '_handle_delete_api_event_forward'),
# Statis files # Static files
('GET', re.compile(r'/static/(?P<file>[a-zA-Z\._\-0-9/]+)'), ('GET', re.compile(r'/static/(?P<file>[a-zA-Z\._\-0-9/]+)'),
'_handle_get_static'),
('HEAD', re.compile(r'/static/(?P<file>[a-zA-Z\._\-0-9/]+)'),
'_handle_get_static') '_handle_get_static')
] ]
@ -255,24 +258,22 @@ class RequestHandler(SimpleHTTPRequestHandler):
if content_length: if content_length:
body_content = self.rfile.read(content_length).decode("UTF-8") body_content = self.rfile.read(content_length).decode("UTF-8")
if self.use_json:
try: try:
data.update(json.loads(body_content)) data.update(json.loads(body_content))
except ValueError: except (TypeError, ValueError):
# TypeError is JSON object is not a dict
# ValueError if we could not parse JSON
_LOGGER.exception("Exception parsing JSON: %s", _LOGGER.exception("Exception parsing JSON: %s",
body_content) body_content)
self._message( self._json_message(
"Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY) "Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY)
return return
else:
data.update({key: value[-1] for key, value in
parse_qs(body_content).items()})
api_password = self.headers.get(rem.AUTH_HEADER) api_password = self.headers.get(rem.AUTH_HEADER)
if not api_password and 'api_password' in data: if not api_password and DATA_API_PASSWORD in data:
api_password = data['api_password'] api_password = data[DATA_API_PASSWORD]
if '_METHOD' in data: if '_METHOD' in data:
method = data.pop('_METHOD') method = data.pop('_METHOD')
@ -307,7 +308,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
# For API calls we need a valid password # For API calls we need a valid password
if self.use_json and api_password != self.server.api_password: if self.use_json and api_password != self.server.api_password:
self._message( self._json_message(
"API password missing or incorrect.", HTTP_UNAUTHORIZED) "API password missing or incorrect.", HTTP_UNAUTHORIZED)
else: else:
@ -315,9 +316,11 @@ class RequestHandler(SimpleHTTPRequestHandler):
elif path_matched_but_not_method: elif path_matched_but_not_method:
self.send_response(HTTP_METHOD_NOT_ALLOWED) self.send_response(HTTP_METHOD_NOT_ALLOWED)
self.end_headers()
else: else:
self.send_response(HTTP_NOT_FOUND) self.send_response(HTTP_NOT_FOUND)
self.end_headers()
def do_HEAD(self): # pylint: disable=invalid-name def do_HEAD(self): # pylint: disable=invalid-name
""" HEAD request handler. """ """ HEAD request handler. """
@ -377,7 +380,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
# pylint: disable=unused-argument # pylint: disable=unused-argument
def _handle_get_api(self, path_match, data): def _handle_get_api(self, path_match, data):
""" Renders the debug interface. """ """ Renders the debug interface. """
self._message("API running.") self._json_message("API running.")
# pylint: disable=unused-argument # pylint: disable=unused-argument
def _handle_get_api_states(self, path_match, data): def _handle_get_api_states(self, path_match, data):
@ -394,7 +397,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
if state: if state:
self._write_json(state) self._write_json(state)
else: else:
self._message("State does not exist.", HTTP_NOT_FOUND) self._json_message("State does not exist.", HTTP_NOT_FOUND)
def _handle_post_state_entity(self, path_match, data): def _handle_post_state_entity(self, path_match, data):
""" Handles updating the state of an entity. """ Handles updating the state of an entity.
@ -407,7 +410,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
try: try:
new_state = data['state'] new_state = data['state']
except KeyError: except KeyError:
self._message("state not specified", HTTP_BAD_REQUEST) self._json_message("state not specified", HTTP_BAD_REQUEST)
return return
attributes = data['attributes'] if 'attributes' in data else None attributes = data['attributes'] if 'attributes' in data else None
@ -417,8 +420,6 @@ class RequestHandler(SimpleHTTPRequestHandler):
# Write state # Write state
self.server.hass.states.set(entity_id, new_state, attributes) self.server.hass.states.set(entity_id, new_state, attributes)
# Return state if json, else redirect to main page
if self.use_json:
state = self.server.hass.states.get(entity_id) state = self.server.hass.states.get(entity_id)
status_code = HTTP_CREATED if is_new_state else HTTP_OK status_code = HTTP_CREATED if is_new_state else HTTP_OK
@ -427,9 +428,6 @@ class RequestHandler(SimpleHTTPRequestHandler):
state.as_dict(), state.as_dict(),
status_code=status_code, status_code=status_code,
location=rem.URL_API_STATES_ENTITY.format(entity_id)) location=rem.URL_API_STATES_ENTITY.format(entity_id))
else:
self._message(
"State of {} changed to {}".format(entity_id, new_state))
def _handle_get_api_events(self, path_match, data): def _handle_get_api_events(self, path_match, data):
""" Handles getting overview of event listeners. """ """ Handles getting overview of event listeners. """
@ -448,7 +446,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
event_type = path_match.group('event_type') event_type = path_match.group('event_type')
if event_data is not None and not isinstance(event_data, dict): if event_data is not None and not isinstance(event_data, dict):
self._message("event_data should be an object", self._json_message("event_data should be an object",
HTTP_UNPROCESSABLE_ENTITY) HTTP_UNPROCESSABLE_ENTITY)
event_origin = ha.EventOrigin.remote event_origin = ha.EventOrigin.remote
@ -464,7 +462,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
self.server.hass.bus.fire(event_type, event_data, event_origin) self.server.hass.bus.fire(event_type, event_data, event_origin)
self._message("Event {} fired.".format(event_type)) self._json_message("Event {} fired.".format(event_type))
def _handle_get_api_services(self, path_match, data): def _handle_get_api_services(self, path_match, data):
""" Handles getting overview of services. """ """ Handles getting overview of services. """
@ -485,7 +483,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
self.server.hass.call_service(domain, service, data) self.server.hass.call_service(domain, service, data)
self._message("Service {}/{} called.".format(domain, service)) self._json_message("Service {}/{} called.".format(domain, service))
# pylint: disable=invalid-name # pylint: disable=invalid-name
def _handle_post_api_event_forward(self, path_match, data): def _handle_post_api_event_forward(self, path_match, data):
@ -495,26 +493,31 @@ class RequestHandler(SimpleHTTPRequestHandler):
host = data['host'] host = data['host']
api_password = data['api_password'] api_password = data['api_password']
except KeyError: except KeyError:
self._message("No host or api_password received.", self._json_message("No host or api_password received.",
HTTP_BAD_REQUEST) HTTP_BAD_REQUEST)
return return
try: try:
port = int(data['port']) if 'port' in data else None port = int(data['port']) if 'port' in data else None
except ValueError: except ValueError:
self._message( self._json_message(
"Invalid value received for port", HTTP_UNPROCESSABLE_ENTITY) "Invalid value received for port", HTTP_UNPROCESSABLE_ENTITY)
return return
api = rem.API(host, api_password, port)
if not api.validate_api():
self._json_message(
"Unable to validate API", HTTP_UNPROCESSABLE_ENTITY)
return
if self.server.event_forwarder is None: if self.server.event_forwarder is None:
self.server.event_forwarder = \ self.server.event_forwarder = \
rem.EventForwarder(self.server.hass) rem.EventForwarder(self.server.hass)
api = rem.API(host, api_password, port)
self.server.event_forwarder.connect(api) self.server.event_forwarder.connect(api)
self._message("Event forwarding setup.") self._json_message("Event forwarding setup.")
def _handle_delete_api_event_forward(self, path_match, data): def _handle_delete_api_event_forward(self, path_match, data):
""" Handles deleting an event forwarding target. """ """ Handles deleting an event forwarding target. """
@ -522,14 +525,14 @@ class RequestHandler(SimpleHTTPRequestHandler):
try: try:
host = data['host'] host = data['host']
except KeyError: except KeyError:
self._message("No host received.", self._json_message("No host received.",
HTTP_BAD_REQUEST) HTTP_BAD_REQUEST)
return return
try: try:
port = int(data['port']) if 'port' in data else None port = int(data['port']) if 'port' in data else None
except ValueError: except ValueError:
self._message( self._json_message(
"Invalid value received for port", HTTP_UNPROCESSABLE_ENTITY) "Invalid value received for port", HTTP_UNPROCESSABLE_ENTITY)
return return
@ -538,7 +541,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
self.server.event_forwarder.disconnect(api) self.server.event_forwarder.disconnect(api)
self._message("Event forwarding cancelled.") self._json_message("Event forwarding cancelled.")
def _handle_get_static(self, path_match, data): def _handle_get_static(self, path_match, data):
""" Returns a static file. """ """ Returns a static file. """
@ -585,7 +588,10 @@ class RequestHandler(SimpleHTTPRequestHandler):
self.end_headers() self.end_headers()
if do_gzip: if self.command == 'HEAD':
return
elif do_gzip:
self.wfile.write(gzip_data) self.wfile.write(gzip_data)
else: else:
@ -599,22 +605,9 @@ class RequestHandler(SimpleHTTPRequestHandler):
if inp: if inp:
inp.close() inp.close()
def _message(self, message, status_code=HTTP_OK): def _json_message(self, message, status_code=HTTP_OK):
""" Helper method to return a message to the caller. """ """ Helper method to return a message to the caller. """
if self.use_json:
self._write_json({'message': message}, status_code=status_code) self._write_json({'message': message}, status_code=status_code)
else:
self.send_error(status_code, message)
def _redirect(self, location):
""" Helper method to redirect caller. """
self.send_response(HTTP_MOVED_PERMANENTLY)
self.send_header(
"Location", "{}?api_password={}".format(
location, self.server.api_password))
self.end_headers()
def _write_json(self, data=None, status_code=HTTP_OK, location=None): def _write_json(self, data=None, status_code=HTTP_OK, location=None):
""" Helper method to return JSON to the caller. """ """ Helper method to return JSON to the caller. """

View File

@ -34,6 +34,7 @@ URL_API_EVENT_FORWARD = "/api/event_forwarding"
METHOD_GET = "get" METHOD_GET = "get"
METHOD_POST = "post" METHOD_POST = "post"
METHOD_DELETE = "delete"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -94,6 +95,10 @@ class API(object):
_LOGGER.exception(error) _LOGGER.exception(error)
raise ha.HomeAssistantError(error) raise ha.HomeAssistantError(error)
def __repr__(self):
return "API({}, {}, {})".format(
self.host, self.api_password, self.port)
class HomeAssistant(ha.HomeAssistant): class HomeAssistant(ha.HomeAssistant):
""" Home Assistant that forwards work. """ """ Home Assistant that forwards work. """
@ -122,19 +127,23 @@ class HomeAssistant(ha.HomeAssistant):
import random import random
# pylint: disable=too-many-format-args # pylint: disable=too-many-format-args
random_password = '%030x'.format(random.randrange(16**30)) random_password = '{:30}'.format(random.randrange(16**30))
http.setup( http.setup(
self, {http.DOMAIN: {http.CONF_API_PASSWORD: random_password}}) self, {http.DOMAIN: {http.CONF_API_PASSWORD: random_password}})
ha.Timer(self) ha.Timer(self)
# Setup that events from remote_api get forwarded to local_api
connect_remote_events(self.remote_api, self.local_api)
self.bus.fire(ha.EVENT_HOMEASSISTANT_START, self.bus.fire(ha.EVENT_HOMEASSISTANT_START,
origin=ha.EventOrigin.remote) origin=ha.EventOrigin.remote)
# Setup that events from remote_api get forwarded to local_api
# Do this after we fire START, otherwise HTTP is not started
if not connect_remote_events(self.remote_api, self.local_api):
raise ha.HomeAssistantError((
'Could not setup event forwarding from api {} to '
'local api {}').format(self.remote_api, self.local_api))
def stop(self): def stop(self):
""" Stops Home Assistant and shuts down all threads. """ """ Stops Home Assistant and shuts down all threads. """
_LOGGER.info("Stopping") _LOGGER.info("Stopping")
@ -289,30 +298,51 @@ def validate_api(api):
def connect_remote_events(from_api, to_api): def connect_remote_events(from_api, to_api):
""" Sets up from_api to forward all events to to_api. """ """ Sets up from_api to forward all events to to_api. """
data = {'host': to_api.host, 'api_password': to_api.api_password} data = {
'host': to_api.host,
if to_api.port is not None: 'api_password': to_api.api_password,
data['port'] = to_api.port 'port': to_api.port
}
try: try:
from_api(METHOD_POST, URL_API_EVENT_FORWARD, data) req = from_api(METHOD_POST, URL_API_EVENT_FORWARD, data)
if req.status_code == 200:
return True
else:
_LOGGER.error(
"Error settign up event forwarding: %s - %s",
req.status_code, req.text)
return False
except ha.HomeAssistantError: except ha.HomeAssistantError:
pass _LOGGER.exception("Error setting up event forwarding")
return False
def disconnect_remote_events(from_api, to_api): def disconnect_remote_events(from_api, to_api):
""" Disconnects forwarding events from from_api to to_api. """ """ Disconnects forwarding events from from_api to to_api. """
data = {'host': to_api.host, '_METHOD': 'DELETE'} data = {
'host': to_api.host,
if to_api.port is not None: 'port': to_api.port
data['port'] = to_api.port }
try: try:
from_api(METHOD_POST, URL_API_EVENT_FORWARD, data) req = from_api(METHOD_DELETE, URL_API_EVENT_FORWARD, data)
if req.status_code == 200:
return True
else:
_LOGGER.error(
"Error removing event forwarding: %s - %s",
req.status_code, req.text)
return False
except ha.HomeAssistantError: except ha.HomeAssistantError:
pass _LOGGER.exception("Error removing an event forwarder")
return False
def get_event_listeners(api): def get_event_listeners(api):

View File

@ -52,30 +52,50 @@ def setUpModule(): # pylint: disable=invalid-name
def tearDownModule(): # pylint: disable=invalid-name def tearDownModule(): # pylint: disable=invalid-name
""" Stops the Home Assistant server. """ """ Stops the Home Assistant server. """
global hass
hass.stop() hass.stop()
class TestHTTP(unittest.TestCase): class TestHTTP(unittest.TestCase):
""" Test the HTTP debug interface and API. """ """ Test the HTTP debug interface and API. """
def test_get_frontend(self): def test_setup(self):
""" Test http.setup. """
self.assertFalse(http.setup(hass, {}))
self.assertFalse(http.setup(hass, {http.DOMAIN: {}}))
def test_frontend_and_static(self):
""" Tests if we can get the frontend. """ """ Tests if we can get the frontend. """
req = requests.get(_url("")) req = requests.get(_url(""))
self.assertEqual(200, req.status_code) self.assertEqual(200, req.status_code)
# Test we can retrieve frontend.js
frontendjs = re.search( frontendjs = re.search(
r'(?P<app>\/static\/frontend-[A-Za-z0-9]{32}.html)', r'(?P<app>\/static\/frontend-[A-Za-z0-9]{32}.html)',
req.text).groups(0)[0] req.text)
self.assertIsNotNone(frontendjs) self.assertIsNotNone(frontendjs)
req = requests.get(_url(frontendjs)) req = requests.head(_url(frontendjs.groups(0)[0]))
self.assertEqual(200, req.status_code) self.assertEqual(200, req.status_code)
# Test auto filling in api password
req = requests.get(
_url("?{}={}".format(http.DATA_API_PASSWORD, API_PASSWORD)))
self.assertEqual(200, req.status_code)
auth_text = re.search(r"auth='{}'".format(API_PASSWORD), req.text)
self.assertIsNotNone(auth_text)
# Test 404
self.assertEqual(404, requests.get(_url("/not-existing")).status_code)
# Test we cannot POST to /
self.assertEqual(405, requests.post(_url("")).status_code)
def test_api_password(self): def test_api_password(self):
""" Test if we get access denied if we omit or provide """ Test if we get access denied if we omit or provide
a wrong api password. """ a wrong api password. """
@ -127,8 +147,8 @@ class TestHTTP(unittest.TestCase):
hass.states.set("test.test", "not_to_be_set") hass.states.set("test.test", "not_to_be_set")
requests.post(_url(remote.URL_API_STATES_ENTITY.format("test.test")), requests.post(_url(remote.URL_API_STATES_ENTITY.format("test.test")),
data=json.dumps({"state": "debug_state_change2", data=json.dumps({"state": "debug_state_change2"}),
"api_password": API_PASSWORD})) headers=HA_HEADERS)
self.assertEqual("debug_state_change2", self.assertEqual("debug_state_change2",
hass.states.get("test.test").state) hass.states.get("test.test").state)
@ -143,8 +163,8 @@ class TestHTTP(unittest.TestCase):
req = requests.post( req = requests.post(
_url(remote.URL_API_STATES_ENTITY.format( _url(remote.URL_API_STATES_ENTITY.format(
"test_entity.that_does_not_exist")), "test_entity.that_does_not_exist")),
data=json.dumps({"state": new_state, data=json.dumps({'state': new_state}),
"api_password": API_PASSWORD})) headers=HA_HEADERS)
cur_state = (hass.states. cur_state = (hass.states.
get("test_entity.that_does_not_exist").state) get("test_entity.that_does_not_exist").state)
@ -152,6 +172,20 @@ class TestHTTP(unittest.TestCase):
self.assertEqual(201, req.status_code) self.assertEqual(201, req.status_code)
self.assertEqual(cur_state, new_state) self.assertEqual(cur_state, new_state)
# pylint: disable=invalid-name
def test_api_state_change_with_bad_data(self):
""" Test if API sends appropriate error if we omit state. """
new_state = "debug_state_change"
req = requests.post(
_url(remote.URL_API_STATES_ENTITY.format(
"test_entity.that_does_not_exist")),
data=json.dumps({}),
headers=HA_HEADERS)
self.assertEqual(400, req.status_code)
# pylint: disable=invalid-name # pylint: disable=invalid-name
def test_api_fire_event_with_no_data(self): def test_api_fire_event_with_no_data(self):
""" Test if the API allows us to fire an event. """ """ Test if the API allows us to fire an event. """
@ -214,6 +248,17 @@ class TestHTTP(unittest.TestCase):
self.assertEqual(422, req.status_code) self.assertEqual(422, req.status_code)
self.assertEqual(0, len(test_value)) self.assertEqual(0, len(test_value))
# Try now with valid but unusable JSON
req = requests.post(
_url(remote.URL_API_EVENTS_EVENT.format("test_event_bad_data")),
data=json.dumps([1, 2, 3]),
headers=HA_HEADERS)
hass._pool.block_till_done()
self.assertEqual(422, req.status_code)
self.assertEqual(0, len(test_value))
def test_api_get_event_listeners(self): def test_api_get_event_listeners(self):
""" Test if we can get the list of events being listened for. """ """ Test if we can get the list of events being listened for. """
req = requests.get(_url(remote.URL_API_EVENTS), req = requests.get(_url(remote.URL_API_EVENTS),
@ -279,3 +324,79 @@ class TestHTTP(unittest.TestCase):
hass._pool.block_till_done() hass._pool.block_till_done()
self.assertEqual(1, len(test_value)) self.assertEqual(1, len(test_value))
def test_api_event_forward(self):
""" Test setting up event forwarding. """
req = requests.post(
_url(remote.URL_API_EVENT_FORWARD),
headers=HA_HEADERS)
self.assertEqual(400, req.status_code)
req = requests.post(
_url(remote.URL_API_EVENT_FORWARD),
data=json.dumps({'host': '127.0.0.1'}),
headers=HA_HEADERS)
self.assertEqual(400, req.status_code)
req = requests.post(
_url(remote.URL_API_EVENT_FORWARD),
data=json.dumps({'api_password': 'bla-di-bla'}),
headers=HA_HEADERS)
self.assertEqual(400, req.status_code)
req = requests.post(
_url(remote.URL_API_EVENT_FORWARD),
data=json.dumps({
'api_password': 'bla-di-bla',
'host': '127.0.0.1',
'port': 'abcd'
}),
headers=HA_HEADERS)
self.assertEqual(422, req.status_code)
req = requests.post(
_url(remote.URL_API_EVENT_FORWARD),
data=json.dumps({
'api_password': 'bla-di-bla',
'host': '127.0.0.1',
'port': '8125'
}),
headers=HA_HEADERS)
self.assertEqual(422, req.status_code)
# Setup a real one
req = requests.post(
_url(remote.URL_API_EVENT_FORWARD),
data=json.dumps({
'api_password': API_PASSWORD,
'host': '127.0.0.1',
'port': SERVER_PORT
}),
headers=HA_HEADERS)
self.assertEqual(200, req.status_code)
# Delete it again..
req = requests.delete(
_url(remote.URL_API_EVENT_FORWARD),
data=json.dumps({}),
headers=HA_HEADERS)
self.assertEqual(400, req.status_code)
req = requests.delete(
_url(remote.URL_API_EVENT_FORWARD),
data=json.dumps({
'host': '127.0.0.1',
'port': 'abcd'
}),
headers=HA_HEADERS)
self.assertEqual(422, req.status_code)
req = requests.delete(
_url(remote.URL_API_EVENT_FORWARD),
data=json.dumps({
'host': '127.0.0.1',
'port': SERVER_PORT
}),
headers=HA_HEADERS)
self.assertEqual(200, req.status_code)

View File

@ -15,7 +15,7 @@ import homeassistant.components.http as http
API_PASSWORD = "test1234" API_PASSWORD = "test1234"
HTTP_BASE_URL = "http://127.0.0.1:{}".format(remote.SERVER_PORT) HTTP_BASE_URL = "http://127.0.0.1:8122"
HA_HEADERS = {remote.AUTH_HEADER: API_PASSWORD} HA_HEADERS = {remote.AUTH_HEADER: API_PASSWORD}