Added remote EventBus and StateMachine classes

This commit is contained in:
Paulus Schoutsen 2013-10-25 11:05:58 +01:00
parent 5ae08c6f0f
commit 867966234f
3 changed files with 319 additions and 18 deletions

View File

@ -27,6 +27,7 @@ Fires an 'event_name' event containing data from 'event_data'
import json import json
import threading import threading
import itertools
import logging import logging
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from urlparse import urlparse, parse_qs from urlparse import urlparse, parse_qs
@ -179,7 +180,8 @@ class RequestHandler(BaseHTTPRequestHandler):
else: else:
self.send_response(404) self.send_response(404)
def do_POST(self): # pylint: disable=invalid-name, too-many-branches # pylint: disable=invalid-name, too-many-branches, too-many-statements
def do_POST(self):
""" Handle incoming POST requests. """ """ Handle incoming POST requests. """
length = int(self.headers['Content-Length']) length = int(self.headers['Content-Length'])
@ -196,24 +198,69 @@ class RequestHandler(BaseHTTPRequestHandler):
given_api_password = post_data.get("api_password", [''])[0] given_api_password = post_data.get("api_password", [''])[0]
# Action to change the state # Action to change the state
if action == "state/change": if action == "state/categories":
if self._verify_api_password(given_api_password, use_json):
self._response(use_json, "State categories",
json_data=
{'categories': self.server.statemachine.categories})
elif action == "state/get":
if self._verify_api_password(given_api_password, use_json):
try:
category = post_data['category'][0]
state = self.server.statemachine.get_state(category)
self._response(use_json,
"State of {}".format(category),
json_data={'category': category,
'state': state.state,
'last_changed':
util.datetime_to_str(state.last_changed),
'attributes': state.attributes
})
except KeyError:
# If category or new_state don't exist in post data
self._response(use_json, "Invalid state received.",
MESSAGE_STATUS_ERROR)
elif action == "state/change":
if self._verify_api_password(given_api_password, use_json): if self._verify_api_password(given_api_password, use_json):
try: try:
changed = [] changed = []
for category, new_state in zip(post_data['category'], for idx, category, new_state in zip(itertools.count(),
post_data['new_state']): post_data['category'],
post_data['new_state']
):
self.server.statemachine.set_state(category, new_state) # See if we also received attributes for this state
try:
attributes = json.loads(
post_data['attributes'][idx])
except KeyError:
# Happens if key 'attributes' or idx does not exist
attributes = None
self.server.statemachine.set_state(category,
new_state,
attributes)
changed.append("{}={}".format(category, new_state)) changed.append("{}={}".format(category, new_state))
self._message(use_json, "States changed: {}". self._response(use_json, "States changed: {}".
format( ", ".join(changed) ) ) format( ", ".join(changed) ) )
except KeyError: except KeyError:
# If category or new_state don't exist in post data # If category or new_state don't exist in post data
self._message(use_json, "Invalid state received.", self._response(use_json, "Invalid parameters received.",
MESSAGE_STATUS_ERROR)
except ValueError:
# If json.loads doesn't understand the attributes
self._response(use_json, "Invalid state data received.",
MESSAGE_STATUS_ERROR) MESSAGE_STATUS_ERROR)
# Action to fire an event # Action to fire an event
@ -232,17 +279,17 @@ class RequestHandler(BaseHTTPRequestHandler):
self.server.eventbus.fire(event_name, event_data) self.server.eventbus.fire(event_name, event_data)
self._message(use_json, "Event {} fired.". self._response(use_json, "Event {} fired.".
format(event_name)) format(event_name))
except ValueError: except ValueError:
# If JSON decode error # If JSON decode error
self._message(use_json, "Invalid event received (1).", self._response(use_json, "Invalid event received (1).",
MESSAGE_STATUS_ERROR) MESSAGE_STATUS_ERROR)
except KeyError: except KeyError:
# If "event_name" not in post_data # If "event_name" not in post_data
self._message(use_json, "Invalid event received (2).", self._response(use_json, "Invalid event received (2).",
MESSAGE_STATUS_ERROR) MESSAGE_STATUS_ERROR)
else: else:
@ -256,7 +303,7 @@ class RequestHandler(BaseHTTPRequestHandler):
return True return True
elif use_json: elif use_json:
self._message(True, "API password missing or incorrect.", self._response(True, "API password missing or incorrect.",
MESSAGE_STATUS_UNAUTHORIZED) MESSAGE_STATUS_UNAUTHORIZED)
else: else:
@ -277,7 +324,8 @@ class RequestHandler(BaseHTTPRequestHandler):
return False return False
def _message(self, use_json, message, status=MESSAGE_STATUS_OK): def _response(self, use_json, message,
status=MESSAGE_STATUS_OK, json_data=None):
""" Helper method to show a message to the user. """ """ Helper method to show a message to the user. """
log_message = "{}: {}".format(status, message) log_message = "{}: {}".format(status, message)
@ -295,7 +343,11 @@ class RequestHandler(BaseHTTPRequestHandler):
self.send_header('Content-type','application/json') self.send_header('Content-type','application/json')
self.end_headers() self.end_headers()
self.wfile.write(json.dumps({'status': status, 'message':message})) json_data = json_data or {}
json_data['status'] = status
json_data['message'] = message
self.wfile.write(json.dumps(json_data))
else: else:
self.server.flash_message = message self.server.flash_message = message

147
homeassistant/remote.py Normal file
View File

@ -0,0 +1,147 @@
"""
homeassistant.remote
~~~~~~~~~~~~~~~~~~~~
A module containing drop in replacements for core parts that will interface
with a remote instance of home assistant.
"""
import threading
import logging
import json
import requests
import homeassistant
import homeassistant.httpinterface as httpinterface
import homeassistant.util as util
def _setup_call_api(host, port, base_path, api_password):
""" Helper method to setup a call api method. """
port = port or httpinterface.SERVER_PORT
base_url = "http://{}:{}/api/{}".format(host, port, base_path)
def _call_api(action, data=None):
""" Makes a call to the Home Assistant api. """
data = data or {}
data['api_password'] = api_password
return requests.post(base_url + action, data=data)
return _call_api
class EventBus(homeassistant.EventBus):
""" Drop-in replacement for a normal eventbus that will forward events to
a remote eventbus.
"""
def __init__(self, host, api_password, port=None):
homeassistant.EventBus.__init__(self)
self._call_api = _setup_call_api(host, port, "event/", api_password)
self.logger = logging.getLogger(__name__)
def fire(self, event_type, event_data=None):
""" Fire an event. """
if not event_data:
event_data = {}
data = {'event_name': event_type,
'event_data': json.dumps(event_data)}
try:
self._call_api("fire", data)
except requests.exceptions.ConnectionError:
self.logger.exception("EventBus:Error connecting to server")
def listen(self, event_type, listener):
""" Not implemented for remote eventbus.
Will throw NotImplementedError. """
raise NotImplementedError
def remove_listener(self, event_type, listener):
""" Not implemented for remote eventbus.
Will throw NotImplementedError. """
raise NotImplementedError
class StateMachine(homeassistant.StateMachine):
""" Drop-in replacement for a normal statemachine that communicates with a
remote statemachine.
"""
def __init__(self, host, api_password, port=None):
homeassistant.StateMachine.__init__(self, None)
self._call_api = _setup_call_api(host, port, "state/", api_password)
self.lock = threading.Lock()
self.logger = logging.getLogger(__name__)
@property
def categories(self):
""" List of categories which states are being tracked. """
try:
req = self._call_api("categories")
return req.json()['categories']
except requests.exceptions.ConnectionError:
self.logger.exception("StateMachine:Error connecting to server")
return []
except ValueError: # If req.json() can't parse the json
self.logger.exception("StateMachine:Got unexpected result")
return []
def set_state(self, category, new_state, attributes=None):
""" Set the state of a category, add category if it does not exist.
Attributes is an optional dict to specify attributes of this state. """
attributes = attributes or {}
self.lock.acquire()
data = {'category': category,
'new_state': new_state,
'attributes': json.dumps(attributes)}
try:
self._call_api('change', data)
except requests.exceptions.ConnectionError:
# Raise a Home Assistant error??
self.logger.exception("StateMachine:Error connecting to server")
finally:
self.lock.release()
def get_state(self, category):
""" Returns a tuple (state,last_changed) describing
the state of the specified category. """
try:
req = self._call_api("get", {'category': category})
data = req.json()
return homeassistant.State(data['state'],
util.str_to_datetime(data['last_changed']),
data['attributes'])
except requests.exceptions.ConnectionError:
self.logger.exception("StateMachine:Error connecting to server")
except ValueError: # If req.json() can't parse the json
self.logger.exception("StateMachine:Got unexpected result")
return []

View File

@ -12,7 +12,9 @@ import time
import requests import requests
import homeassistant as ha import homeassistant as ha
import homeassistant.remote as remote
import homeassistant.httpinterface as httpinterface import homeassistant.httpinterface as httpinterface
import homeassistant.util as util
API_PASSWORD = "test1234" API_PASSWORD = "test1234"
@ -45,6 +47,8 @@ class TestHTTPInterface(unittest.TestCase):
""" things to be run when tests are started. """ """ things to be run when tests are started. """
cls.eventbus = ha.EventBus() cls.eventbus = ha.EventBus()
cls.statemachine = ha.StateMachine(cls.eventbus) cls.statemachine = ha.StateMachine(cls.eventbus)
cls.remote_sm = remote.StateMachine("127.0.0.1", API_PASSWORD)
cls.remote_eb = remote.EventBus("127.0.0.1", API_PASSWORD)
def test_debug_interface(self): def test_debug_interface(self):
""" Test if we can login by comparing not logged in screen to """ Test if we can login by comparing not logged in screen to
@ -78,6 +82,35 @@ class TestHTTPInterface(unittest.TestCase):
self.assertEqual(req.status_code, 401) self.assertEqual(req.status_code, 401)
def test_api_list_state_categories(self):
""" Test if the debug interface allows us to list state categories. """
req = requests.post("{}/api/state/categories".format(HTTP_BASE_URL),
data={"api_password":API_PASSWORD})
data = req.json()
self.assertEqual(self.statemachine.categories,
data['categories'])
def test_api_get_state(self):
""" Test if the debug interface allows us to list state categories. """
req = requests.post("{}/api/state/get".format(HTTP_BASE_URL),
data={"api_password":API_PASSWORD,
"category": "test"})
data = req.json()
state = self.statemachine.get_state("test")
trunc_last_changed = state.last_changed.replace(microsecond=0)
self.assertEqual(data['category'], "test")
self.assertEqual(data['state'], state.state)
self.assertEqual(util.str_to_datetime(data['last_changed']),
trunc_last_changed)
self.assertEqual(data['attributes'], state.attributes)
def test_api_state_change(self): def test_api_state_change(self):
""" Test if we can change the state of a category that exists. """ """ Test if we can change the state of a category that exists. """
@ -91,6 +124,38 @@ class TestHTTPInterface(unittest.TestCase):
self.assertEqual(self.statemachine.get_state("test").state, self.assertEqual(self.statemachine.get_state("test").state,
"debug_state_change2") "debug_state_change2")
# pylint: disable=invalid-name
def test_remote_sm_list_state_categories(self):
""" Test if the debug interface allows us to list state categories. """
self.assertEqual(self.statemachine.categories,
self.remote_sm.categories)
def test_remote_sm_get_state(self):
""" Test if the debug interface allows us to list state categories. """
remote_state = self.remote_sm.get_state("test")
state = self.statemachine.get_state("test")
trunc_last_changed = state.last_changed.replace(microsecond=0)
self.assertEqual(remote_state.state, state.state)
self.assertEqual(remote_state.last_changed, trunc_last_changed)
self.assertEqual(remote_state.attributes, state.attributes)
def test_remote_sm_state_change(self):
""" Test if we can change the state of a category that exists. """
self.remote_sm.set_state("test", "set_remotely", {"test": 1})
state = self.statemachine.get_state("test")
self.assertEqual(state.state, "set_remotely")
self.assertEqual(state.attributes['test'], 1)
def test_api_multiple_state_change(self): def test_api_multiple_state_change(self):
""" Test if we can change multiple states in 1 request. """ """ Test if we can change multiple states in 1 request. """
@ -134,7 +199,7 @@ class TestHTTPInterface(unittest.TestCase):
""" Helper method that will verify our event got called. """ """ Helper method that will verify our event got called. """
test_value.append(1) test_value.append(1)
self.eventbus.listen("test_event_no_data", listener) self.eventbus.listen_once("test_event_no_data", listener)
requests.post("{}/api/event/fire".format(HTTP_BASE_URL), requests.post("{}/api/event/fire".format(HTTP_BASE_URL),
data={"event_name":"test_event_no_data", data={"event_name":"test_event_no_data",
@ -146,7 +211,6 @@ class TestHTTPInterface(unittest.TestCase):
self.assertEqual(len(test_value), 1) self.assertEqual(len(test_value), 1)
# pylint: disable=invalid-name # pylint: disable=invalid-name
def test_api_fire_event_with_data(self): def test_api_fire_event_with_data(self):
""" Test if the API allows us to fire an event. """ """ Test if the API allows us to fire an event. """
@ -158,7 +222,7 @@ class TestHTTPInterface(unittest.TestCase):
if "test" in event.data: if "test" in event.data:
test_value.append(1) test_value.append(1)
self.eventbus.listen("test_event_with_data", listener) self.eventbus.listen_once("test_event_with_data", listener)
requests.post("{}/api/event/fire".format(HTTP_BASE_URL), requests.post("{}/api/event/fire".format(HTTP_BASE_URL),
data={"event_name":"test_event_with_data", data={"event_name":"test_event_with_data",
@ -182,7 +246,7 @@ class TestHTTPInterface(unittest.TestCase):
if "test" in event.data: if "test" in event.data:
test_value.append(1) test_value.append(1)
self.eventbus.listen("test_event_with_data", listener) self.eventbus.listen_once("test_event_with_data", listener)
requests.post("{}/api/event/fire".format(HTTP_BASE_URL), requests.post("{}/api/event/fire".format(HTTP_BASE_URL),
data={"api_password":API_PASSWORD}) data={"api_password":API_PASSWORD})
@ -202,7 +266,7 @@ class TestHTTPInterface(unittest.TestCase):
""" Helper method that will verify our event got called. """ """ Helper method that will verify our event got called. """
test_value.append(1) test_value.append(1)
self.eventbus.listen("test_event_with_bad_data", listener) self.eventbus.listen_once("test_event_with_bad_data", listener)
req = requests.post("{}/api/event/fire".format(HTTP_BASE_URL), req = requests.post("{}/api/event/fire".format(HTTP_BASE_URL),
data={"event_name":"test_event_with_bad_data", data={"event_name":"test_event_with_bad_data",
@ -215,3 +279,41 @@ class TestHTTPInterface(unittest.TestCase):
self.assertEqual(req.status_code, 400) self.assertEqual(req.status_code, 400)
self.assertEqual(len(test_value), 0) self.assertEqual(len(test_value), 0)
# pylint: disable=invalid-name
def test_remote_eb_fire_event_with_no_data(self):
""" Test if the remote eventbus allows us to fire an event. """
test_value = []
def listener(event): # pylint: disable=unused-argument
""" Helper method that will verify our event got called. """
test_value.append(1)
self.eventbus.listen_once("test_event_no_data", listener)
self.remote_eb.fire("test_event_no_data")
# Allow the event to take place
time.sleep(1)
self.assertEqual(len(test_value), 1)
# pylint: disable=invalid-name
def test_remote_eb_fire_event_with_data(self):
""" Test if the remote eventbus allows us to fire an event. """
test_value = []
def listener(event): # pylint: disable=unused-argument
""" Helper method that will verify our event got called. """
if event.data["test"] == 1:
test_value.append(1)
self.eventbus.listen_once("test_event_with_data", listener)
self.remote_eb.fire("test_event_with_data", {"test": 1})
# Allow the event to take place
time.sleep(1)
self.assertEqual(len(test_value), 1)