PyLint is a lot happier about the code now.

This commit is contained in:
Paulus Schoutsen 2013-10-07 23:55:19 -07:00
parent 9ac8cf7a27
commit 32b357f5e1
7 changed files with 462 additions and 251 deletions

View File

@ -50,7 +50,9 @@ def ensure_list(parameter):
def matcher(subject, pattern): def matcher(subject, pattern):
""" Returns True if subject matches the pattern. """ Returns True if subject matches the pattern.
Pattern is either a list of allowed subjects or a '*'. """
Pattern is either a list of allowed subjects or a '*'.
"""
return '*' in pattern or subject in pattern return '*' in pattern or subject in pattern
def track_state_change(eventbus, category, from_state, to_state, action): def track_state_change(eventbus, category, from_state, to_state, action):
@ -66,21 +68,28 @@ def track_state_change(eventbus, category, from_state, to_state, action):
matcher(event.data['old_state'].state, from_state) and \ matcher(event.data['old_state'].state, from_state) and \
matcher(event.data['new_state'].state, to_state): matcher(event.data['new_state'].state, to_state):
action(event.data['category'], event.data['old_state'], event.data['new_state']) action(event.data['category'],
event.data['old_state'],
event.data['new_state'])
eventbus.listen(EVENT_STATE_CHANGED, listener) eventbus.listen(EVENT_STATE_CHANGED, listener)
def track_time_change(eventbus, action, year='*', month='*', day='*', hour='*', minute='*', second='*', point_in_time=None, listen_once=False): # pylint: disable=too-many-arguments
def track_time_change(eventbus, action,
year='*', month='*', day='*',
hour='*', minute='*', second='*',
point_in_time=None, listen_once=False):
""" Adds a listener that will listen for a specified or matching time. """ """ Adds a listener that will listen for a specified or matching time. """
year, month, day = ensure_list(year), ensure_list(month), ensure_list(day) year, month, day = ensure_list(year), ensure_list(month), ensure_list(day)
hour, minute, second = ensure_list(hour), ensure_list(minute), ensure_list(second) hour, minute = ensure_list(hour), ensure_list(minute)
second = ensure_list(second)
def listener(event): def listener(event):
""" Listens for matching time_changed events. """ """ Listens for matching time_changed events. """
assert isinstance(event, Event), "event needs to be of Event type" assert isinstance(event, Event), "event needs to be of Event type"
if (point_in_time is not None and event.data['now'] > point_in_time) or \ if (point_in_time and event.data['now'] > point_in_time) or \
(point_in_time is None and \ (not point_in_time and \
matcher(event.data['now'].year, year) and \ matcher(event.data['now'].year, year) and \
matcher(event.data['now'].month, month) and \ matcher(event.data['now'].month, month) and \
matcher(event.data['now'].day, day) and \ matcher(event.data['now'].day, day) and \
@ -88,7 +97,8 @@ def track_time_change(eventbus, action, year='*', month='*', day='*', hour='*',
matcher(event.data['now'].minute, minute) and \ matcher(event.data['now'].minute, minute) and \
matcher(event.data['now'].second, second)): matcher(event.data['now'].second, second)):
# point_in_time are exact points in time so we always remove it after fire # point_in_time are exact points in time
# so we always remove it after fire
event.remove_listener = listen_once or point_in_time is not None event.remove_listener = listen_once or point_in_time is not None
action(event.data['now']) action(event.data['now'])
@ -96,7 +106,7 @@ def track_time_change(eventbus, action, year='*', month='*', day='*', hour='*',
eventbus.listen(EVENT_TIME_CHANGED, listener) eventbus.listen(EVENT_TIME_CHANGED, listener)
class EventBus(object): class EventBus(object):
""" Class provides an eventbus. Allows code to listen for events and fire them. """ """ Class that allows code to listen for- and fire events. """
def __init__(self): def __init__(self):
self.listeners = defaultdict(list) self.listeners = defaultdict(list)
@ -105,19 +115,22 @@ class EventBus(object):
def fire(self, event): def fire(self, event):
""" Fire an event. """ """ Fire an event. """
assert isinstance(event, Event), "event needs to be an instance of Event" assert isinstance(event, Event), \
"event needs to be an instance of Event"
def run(): def run():
""" We dont want the eventbus to be blocking - run in a thread. """ """ We dont want the eventbus to be blocking - run in a thread. """
self.lock.acquire() self.lock.acquire()
self.logger.info("EventBus:Event {}: {}".format(event.event_type, event.data)) self.logger.info("EventBus:Event {}: {}".format(
event.event_type, event.data))
for callback in chain(self.listeners[ALL_EVENTS], self.listeners[event.event_type]): for callback in chain(self.listeners[ALL_EVENTS],
self.listeners[event.event_type]):
try: try:
callback(event) callback(event)
except: except Exception: #pylint: disable=broad-except
self.logger.exception("EventBus:Exception in listener") self.logger.exception("EventBus:Exception in listener")
if event.remove_listener: if event.remove_listener:
@ -139,13 +152,16 @@ class EventBus(object):
def listen(self, event_type, callback): def listen(self, event_type, callback):
""" Listen for all events or events of a specific type. """ Listen for all events or events of a specific type.
To listen to all events specify the constant ``ALL_EVENTS`` as event_type. """ To listen to all events specify the constant ``ALL_EVENTS``
as event_type.
"""
self.lock.acquire() self.lock.acquire()
self.listeners[event_type].append(callback) self.listeners[event_type].append(callback)
self.lock.release() self.lock.release()
# pylint: disable=too-few-public-methods
class Event(object): class Event(object):
""" An event to be sent over the eventbus. """ """ An event to be sent over the eventbus. """
@ -182,7 +198,10 @@ class StateMachine(object):
if old_state.state != new_state: if old_state.state != new_state:
self.states[category] = State(new_state, datetime.now()) self.states[category] = State(new_state, datetime.now())
self.eventbus.fire(Event(EVENT_STATE_CHANGED, {'category':category, 'old_state':old_state, 'new_state':self.states[category]})) self.eventbus.fire(Event(EVENT_STATE_CHANGED,
{'category':category,
'old_state':old_state,
'new_state':self.states[category]}))
self.lock.release() self.lock.release()
@ -193,19 +212,27 @@ class StateMachine(object):
return self.get_state(category).state == state return self.get_state(category).state == state
def get_state(self, category): def get_state(self, category):
""" Returns a tuple (state,last_changed) describing the state of the specified category. """ """ Returns a tuple (state,last_changed) describing
the state of the specified category. """
self._validate_category(category) self._validate_category(category)
return self.states[category] return self.states[category]
def get_states(self): def get_states(self):
""" Returns a list of tuples (category, state, last_changed) sorted by category. """ """ Returns a list of tuples (category, state, last_changed)
return [(category, self.states[category].state, self.states[category].last_changed) for category in sorted(self.states.keys(), key=lambda key: key.lower())] sorted by category case-insensitive. """
return [(category,
self.states[category].state,
self.states[category].last_changed)
for category in
sorted(self.states.keys(), key=lambda key: key.lower())]
def _validate_category(self, category): def _validate_category(self, category):
""" Helper function to throw an exception when the category does not exist. """ """ Helper function to throw an exception
when the category does not exist. """
if category not in self.states: if category not in self.states:
raise CategoryDoesNotExistException("Category {} does not exist.".format(category)) raise CategoryDoesNotExistException(
"Category {} does not exist.".format(category))
class Timer(threading.Thread): class Timer(threading.Thread):
""" Timer will sent out an event every TIMER_INTERVAL seconds. """ """ Timer will sent out an event every TIMER_INTERVAL seconds. """

View File

@ -2,7 +2,8 @@
homeassistant.actors homeassistant.actors
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
This module provides actors that will react to events happening within homeassistant. This module provides actors that will react
to events happening within homeassistant.
""" """
@ -20,13 +21,14 @@ from .packages.pychromecast import pychromecast
from . import track_state_change from . import track_state_change
from .util import sanitize_filename from .util import sanitize_filename
from .observers import (STATE_CATEGORY_SUN, SUN_STATE_BELOW_HORIZON, SUN_STATE_ABOVE_HORIZON, from .observers import (
STATE_CATEGORY_ALL_DEVICES, DEVICE_STATE_HOME, DEVICE_STATE_NOT_HOME, STATE_CATEGORY_SUN, SUN_STATE_BELOW_HORIZON, SUN_STATE_ABOVE_HORIZON,
STATE_CATEGORY_NEXT_SUN_SETTING, track_time_change) STATE_CATEGORY_ALL_DEVICES, DEVICE_STATE_HOME, DEVICE_STATE_NOT_HOME,
STATE_CATEGORY_NEXT_SUN_SETTING, track_time_change)
LIGHT_TRANSITION_TIME = timedelta(minutes=15) LIGHT_TRANSITION_TIME = timedelta(minutes=15)
HUE_MAX_TRANSITION_TIME = 9000 HUE_MAX_TRANSITION_TIME = 9000 # 900 seconds = 15 minutes
EVENT_DOWNLOAD_FILE = "download_file" EVENT_DOWNLOAD_FILE = "download_file"
EVENT_BROWSE_URL = "browse_url" EVENT_BROWSE_URL = "browse_url"
@ -34,13 +36,17 @@ EVENT_CHROMECAST_YOUTUBE_VIDEO = "chromecast.play_youtube_video"
EVENT_TURN_LIGHT_ON = "turn_light_on" EVENT_TURN_LIGHT_ON = "turn_light_on"
EVENT_TURN_LIGHT_OFF = "turn_light_off" EVENT_TURN_LIGHT_OFF = "turn_light_off"
def _hue_process_transition_time(transition_seconds): def _hue_process_transition_time(transition_seconds):
""" Transition time is in 1/10th seconds and cannot exceed MAX_TRANSITION_TIME. """ """ Transition time is in 1/10th seconds
and cannot exceed MAX_TRANSITION_TIME. """
return min(HUE_MAX_TRANSITION_TIME, transition_seconds * 10) return min(HUE_MAX_TRANSITION_TIME, transition_seconds * 10)
# pylint: disable=too-few-public-methods
class LightTrigger(object): class LightTrigger(object):
""" Class to turn on lights based on available devices and state of the sun. """ """ Class to turn on lights based on state of devices and the sun
or triggered by light events. """
def __init__(self, eventbus, statemachine, device_tracker, light_control): def __init__(self, eventbus, statemachine, device_tracker, light_control):
self.eventbus = eventbus self.eventbus = eventbus
@ -51,97 +57,133 @@ class LightTrigger(object):
# Track home coming of each seperate device # Track home coming of each seperate device
for category in device_tracker.device_state_categories(): for category in device_tracker.device_state_categories():
track_state_change(eventbus, category, DEVICE_STATE_NOT_HOME, DEVICE_STATE_HOME, self._handle_device_state_change) track_state_change(eventbus, category,
DEVICE_STATE_NOT_HOME, DEVICE_STATE_HOME,
self._handle_device_state_change)
# Track when all devices are gone to shut down lights # Track when all devices are gone to shut down lights
track_state_change(eventbus, STATE_CATEGORY_ALL_DEVICES, DEVICE_STATE_HOME, DEVICE_STATE_NOT_HOME, self._handle_device_state_change) track_state_change(eventbus, STATE_CATEGORY_ALL_DEVICES,
DEVICE_STATE_HOME, DEVICE_STATE_NOT_HOME,
self._handle_device_state_change)
# Track every time sun rises so we can schedule a time-based pre-sun set event # Track every time sun rises so we can schedule a time-based
track_state_change(eventbus, STATE_CATEGORY_SUN, SUN_STATE_BELOW_HORIZON, SUN_STATE_ABOVE_HORIZON, self._handle_sun_rising) # pre-sun set event
track_state_change(eventbus, STATE_CATEGORY_SUN,
SUN_STATE_BELOW_HORIZON, SUN_STATE_ABOVE_HORIZON,
self._handle_sun_rising)
# If the sun is already above horizon schedule the time-based pre-sun set event # If the sun is already above horizon
# schedule the time-based pre-sun set event
if statemachine.is_state(STATE_CATEGORY_SUN, SUN_STATE_ABOVE_HORIZON): if statemachine.is_state(STATE_CATEGORY_SUN, SUN_STATE_ABOVE_HORIZON):
self._handle_sun_rising(None, None, None) self._handle_sun_rising(None, None, None)
def handle_light_event(event):
""" Hande a turn light on or off event. """
light_id = event.data.get("light_id", None)
transition_seconds = event.data.get("transition_seconds", None)
if event.event_type == EVENT_TURN_LIGHT_ON:
self.light_control.turn_light_on(light_id, transition_seconds)
else:
self.light_control.turn_light_off(light_id, transition_seconds)
# Listen for light on and light off events # Listen for light on and light off events
eventbus.listen(EVENT_TURN_LIGHT_ON, lambda event: self.light_control.turn_light_on(event.data.get("light_id", None), eventbus.listen(EVENT_TURN_LIGHT_ON, handle_light_event)
event.data.get("transition_seconds", None))) eventbus.listen(EVENT_TURN_LIGHT_OFF, handle_light_event)
eventbus.listen(EVENT_TURN_LIGHT_OFF, lambda event: self.light_control.turn_light_off(event.data.get("light_id", None),
event.data.get("transition_seconds", None)))
# pylint: disable=unused-argument
def _handle_sun_rising(self, category, old_state, new_state): def _handle_sun_rising(self, category, old_state, new_state):
"""The moment sun sets we want to have all the lights on. """The moment sun sets we want to have all the lights on.
We will schedule to have each light start after one another We will schedule to have each light start after one another
and slowly transition in.""" and slowly transition in."""
start_point = self._start_point_turn_light_before_sun_set() start_point = self._time_for_light_before_sun_set()
def turn_on(light_id): def turn_on(light_id):
""" Lambda can keep track of function parameters, not from local parameters """ Lambda can keep track of function parameters but not local
If we put the lambda directly in the below statement only the last light parameters. If we put the lambda directly in the below statement
would be turned on.. """ only the last light would be turned on.. """
return lambda now: self._turn_light_on_before_sunset(light_id) return lambda now: self._turn_light_on_before_sunset(light_id)
for index, light_id in enumerate(self.light_control.light_ids): for index, light_id in enumerate(self.light_control.light_ids):
track_time_change(self.eventbus, turn_on(light_id), track_time_change(self.eventbus, turn_on(light_id),
point_in_time=start_point + index * LIGHT_TRANSITION_TIME) point_in_time=start_point +
index * LIGHT_TRANSITION_TIME)
def _turn_light_on_before_sunset(self, light_id=None): def _turn_light_on_before_sunset(self, light_id=None):
""" Helper function to turn on lights slowly if there are devices home and the light is not on yet. """ """ Helper function to turn on lights slowlyif there
if self.statemachine.is_state(STATE_CATEGORY_ALL_DEVICES, DEVICE_STATE_HOME) and not self.light_control.is_light_on(light_id): are devices home and the light is not on yet. """
self.light_control.turn_light_on(light_id, LIGHT_TRANSITION_TIME.seconds) if self.statemachine.is_state(STATE_CATEGORY_ALL_DEVICES,
DEVICE_STATE_HOME) and not self.light_control.is_light_on(light_id):
self.light_control.turn_light_on(light_id,
LIGHT_TRANSITION_TIME.seconds)
def _handle_device_state_change(self, category, old_state, new_state): def _handle_device_state_change(self, category, old_state, new_state):
""" Function to handle tracked device state changes. """ """ Function to handle tracked device state changes. """
lights_are_on = self.light_control.is_light_on() lights_are_on = self.light_control.is_light_on()
light_needed = not lights_are_on and self.statemachine.is_state(STATE_CATEGORY_SUN, SUN_STATE_BELOW_HORIZON) light_needed = (not lights_are_on and
self.statemachine.is_state(STATE_CATEGORY_SUN,
SUN_STATE_BELOW_HORIZON))
# Specific device came home ? # Specific device came home ?
if category != STATE_CATEGORY_ALL_DEVICES and new_state.state == DEVICE_STATE_HOME: if (category != STATE_CATEGORY_ALL_DEVICES and
new_state.state == DEVICE_STATE_HOME):
# These variables are needed for the elif check # These variables are needed for the elif check
now = datetime.now() now = datetime.now()
start_point = self._start_point_turn_light_before_sun_set() start_point = self._time_for_light_before_sun_set()
# Do we need lights? # Do we need lights?
if light_needed: if light_needed:
self.logger.info("Home coming event for {}. Turning lights on".format(category))
self.logger.info(
"Home coming event for {}. Turning lights on".
format(category))
self.light_control.turn_light_on() self.light_control.turn_light_on()
# Are we in the time span were we would turn on the lights if someone would be home? # Are we in the time span were we would turn on the lights
# Check this by seeing if current time is later then the start point # if someone would be home?
# Check this by seeing if current time is later then the point
# in time when we would start putting the lights on.
elif now > start_point and now < self._next_sun_setting(): elif now > start_point and now < self._next_sun_setting():
# If this is the case check for every light if it would be on # Check for every light if it would be on if someone was home
# if someone was home when the fading in started and turn it on # when the fading in started and turn it on if so
for index, light_id in enumerate(self.light_control.light_ids): for index, light_id in enumerate(self.light_control.light_ids):
if now > start_point + index * LIGHT_TRANSITION_TIME: if now > start_point + index * LIGHT_TRANSITION_TIME:
self.light_control.turn_light_on(light_id) self.light_control.turn_light_on(light_id)
else: else:
# If this one was not the case then the following IFs are not True # If this light didn't happen to be turned on yet so
# as their points are even further in time, break # will all the following then, break.
break break
# Did all devices leave the house? # Did all devices leave the house?
elif category == STATE_CATEGORY_ALL_DEVICES and new_state.state == DEVICE_STATE_NOT_HOME and lights_are_on: elif (category == STATE_CATEGORY_ALL_DEVICES and
self.logger.info("Everyone has left but lights are on. Turning lights off") new_state.state == DEVICE_STATE_NOT_HOME and lights_are_on):
self.logger.info(("Everyone has left but lights are on. "
"Turning lights off"))
self.light_control.turn_light_off() self.light_control.turn_light_off()
def _next_sun_setting(self): def _next_sun_setting(self):
""" Returns the datetime object representing the next sun setting. """ """ Returns the datetime object representing the next sun setting. """
return dateutil.parser.parse(self.statemachine.get_state(STATE_CATEGORY_NEXT_SUN_SETTING).state) return dateutil.parser.parse(
self.statemachine.get_state(STATE_CATEGORY_NEXT_SUN_SETTING).state)
def _start_point_turn_light_before_sun_set(self): def _time_for_light_before_sun_set(self):
""" Helper method to calculate the point in time we have to start fading in lights """ Helper method to calculate the point in time we have to start
so that all the lights are on the moment the sun sets. """ fading in lights so that all the lights are on the moment the sun
return self._next_sun_setting() - LIGHT_TRANSITION_TIME * len(self.light_control.light_ids) sets.
"""
return (self._next_sun_setting() -
LIGHT_TRANSITION_TIME * len(self.light_control.light_ids))
class HueLightControl(object): class HueLightControl(object):
@ -169,8 +211,9 @@ class HueLightControl(object):
command = {'on': True, 'xy': [0.5119, 0.4147], 'bri':164} command = {'on': True, 'xy': [0.5119, 0.4147], 'bri':164}
if transition_seconds is not None: if transition_seconds:
command['transitiontime'] = _hue_process_transition_time(transition_seconds) command['transitiontime'] = _hue_process_transition_time(
transition_seconds)
self.bridge.set_light(light_id, command) self.bridge.set_light(light_id, command)
@ -182,8 +225,9 @@ class HueLightControl(object):
command = {'on': False} command = {'on': False}
if transition_seconds is not None: if transition_seconds:
command['transitiontime'] = _hue_process_transition_time(transition_seconds) command['transitiontime'] = _hue_process_transition_time(
transition_seconds)
self.bridge.set_light(light_id, command) self.bridge.set_light(light_id, command)
@ -194,18 +238,24 @@ def setup_file_downloader(eventbus, download_path):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if not os.path.isdir(download_path): if not os.path.isdir(download_path):
logger.error("Download path {} does not exist. File Downloader not active.".format(download_path))
logger.error(
"Download path {} does not exist. File Downloader not active.".
format(download_path))
return return
def download_file(event): def download_file(event):
""" Downloads file specified in the url. """ """ Downloads file specified in the url. """
try: try:
req = requests.get(event.data['url'], stream=True) req = requests.get(event.data['url'], stream=True)
if req.status_code == 200: if req.status_code == 200:
filename = None filename = None
if 'content-disposition' in req.headers: if 'content-disposition' in req.headers:
match = re.findall(r"filename=(\S+)", req.headers['content-disposition']) match = re.findall(r"filename=(\S+)",
req.headers['content-disposition'])
if len(match) > 0: if len(match) > 0:
filename = match[0].strip("'\" ") filename = match[0].strip("'\" ")
@ -219,36 +269,47 @@ def setup_file_downloader(eventbus, download_path):
# Remove stuff to ruin paths # Remove stuff to ruin paths
filename = sanitize_filename(filename) filename = sanitize_filename(filename)
path, ext = os.path.splitext(os.path.join(download_path, filename)) path, ext = os.path.splitext(os.path.join(download_path,
filename))
# If file exist append a number. We test filename, filename_2, filename_3 etc.. # If file exist append a number. We test filename, filename_2..
tries = 0 tries = 0
while True: while True:
tries += 1 tries += 1
final_path = path + ("" if tries == 1 else "_{}".format(tries)) + ext name_suffix = "" if tries == 1 else "_{}".format(tries)
final_path = path + name_suffix + ext
if not os.path.isfile(final_path): if not os.path.isfile(final_path):
break break
logger.info("FileDownloader:{} -> {}".format(event.data['url'], final_path)) logger.info("FileDownloader:{} -> {}".format(
event.data['url'], final_path))
with open(final_path, 'wb') as fil: with open(final_path, 'wb') as fil:
for chunk in req.iter_content(1024): for chunk in req.iter_content(1024):
fil.write(chunk) fil.write(chunk)
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
logger.exception("FileDownloader:ConnectionError occured for {}".format(event.data['url'])) logger.exception("FileDownloader:ConnectionError occured for {}".
format(event.data['url']))
eventbus.listen(EVENT_DOWNLOAD_FILE, download_file) eventbus.listen(EVENT_DOWNLOAD_FILE, download_file)
def setup_webbrowser(eventbus): def setup_webbrowser(eventbus):
""" Listen for browse_url events and opens the url in the default webbrowser. """ """ Listen for browse_url events and open
eventbus.listen(EVENT_BROWSE_URL, lambda event: webbrowser.open(event.data['url'])) the url in the default webbrowser. """
eventbus.listen(EVENT_BROWSE_URL,
lambda event: webbrowser.open(event.data['url']))
def setup_chromecast(eventbus, host): def setup_chromecast(eventbus, host):
""" Listen for chromecast events. """ """ Listen for chromecast events. """
eventbus.listen("start_fireplace", lambda event: pychromecast.play_youtube_video(host, "eyU3bRy2x44")) eventbus.listen("start_fireplace",
eventbus.listen("start_epic_sax", lambda event: pychromecast.play_youtube_video(host, "kxopViU98Xo")) lambda event: pychromecast.play_youtube_video(host, "eyU3bRy2x44"))
eventbus.listen(EVENT_CHROMECAST_YOUTUBE_VIDEO, lambda event: pychromecast.play_youtube_video(host, event.data['video']))
eventbus.listen("start_epic_sax",
lambda event: pychromecast.play_youtube_video(host, "kxopViU98Xo"))
eventbus.listen(EVENT_CHROMECAST_YOUTUBE_VIDEO,
lambda event: pychromecast.play_youtube_video(host, event.data['video']))

View File

@ -44,11 +44,16 @@ MESSAGE_STATUS_UNAUTHORIZED = "UNAUTHORIZED"
class HTTPInterface(threading.Thread): class HTTPInterface(threading.Thread):
""" Provides an HTTP interface for Home Assistant. """ """ Provides an HTTP interface for Home Assistant. """
def __init__(self, eventbus, statemachine, api_password, server_port=SERVER_PORT, server_host=None): # pylint: disable=too-many-arguments
def __init__(self, eventbus, statemachine, api_password,
server_port=None, server_host=None):
threading.Thread.__init__(self) threading.Thread.__init__(self)
if not server_port:
server_port = SERVER_PORT
# If no server host is given, accept all incoming requests # If no server host is given, accept all incoming requests
if server_host is None: if not server_host:
server_host = '0.0.0.0' server_host = '0.0.0.0'
self.server = HTTPServer((server_host, server_port), RequestHandler) self.server = HTTPServer((server_host, server_port), RequestHandler)
@ -78,7 +83,8 @@ class HTTPInterface(threading.Thread):
# Trigger a fake request to get the server to quit # Trigger a fake request to get the server to quit
try: try:
requests.get("http://127.0.0.1:{}".format(SERVER_PORT), timeout=0.001) requests.get("http://127.0.0.1:{}".format(SERVER_PORT),
timeout=0.001)
except requests.exceptions.RequestException: except requests.exceptions.RequestException:
pass pass
@ -86,7 +92,7 @@ class RequestHandler(BaseHTTPRequestHandler):
""" Handles incoming HTTP requests """ """ Handles incoming HTTP requests """
#Handler for the GET requests #Handler for the GET requests
def do_GET(self): def do_GET(self): # pylint: disable=invalid-name
""" Handle incoming GET requests. """ """ Handle incoming GET requests. """
write = lambda txt: self.wfile.write(txt+"\n") write = lambda txt: self.wfile.write(txt+"\n")
@ -94,19 +100,21 @@ class RequestHandler(BaseHTTPRequestHandler):
get_data = parse_qs(url.query) get_data = parse_qs(url.query)
api_password = get_data.get('api_password', [''])[0]
if url.path == "/": if url.path == "/":
if self._verify_api_password(get_data.get('api_password', [''])[0], False): if self._verify_api_password(api_password, False):
self.send_response(200) self.send_response(200)
self.send_header('Content-type','text/html') self.send_header('Content-type','text/html')
self.end_headers() self.end_headers()
write("<html>") write(("<html>"
write("<head><title>Home Assistant</title></head>") "<head><title>Home Assistant</title></head>"
write("<body>") "<body>"))
# Flash message support # Flash message support
if self.server.flash_message is not None: if self.server.flash_message:
write("<h3>{}</h3>".format(self.server.flash_message)) write("<h3>{}</h3>".format(self.server.flash_message))
self.server.flash_message = None self.server.flash_message = None
@ -114,20 +122,28 @@ class RequestHandler(BaseHTTPRequestHandler):
# Describe state machine: # Describe state machine:
categories = [] categories = []
write("<table>") write(("<table><tr>"
write("<tr><th>Name</th><th>State</th><th>Last Changed</th></tr>") "<th>Name</th><th>State</th>"
"<th>Last Changed</th></tr>"))
for category, state, last_changed in \
self.server.statemachine.get_states():
for category, state, last_changed in self.server.statemachine.get_states():
categories.append(category) categories.append(category)
write("<tr><td>{}</td><td>{}</td><td>{}</td></tr>".format(category, state, last_changed.strftime("%H:%M:%S %d-%m-%Y"))) write("<tr><td>{}</td><td>{}</td><td>{}</td></tr>".
format(category, state,
last_changed.strftime("%H:%M:%S %d-%m-%Y")))
write("</table>") write("</table>")
# Small form to change the state # Small form to change the state
write("<br />Change state:<br />") write(("<br />Change state:<br />"
write("<form action='state/change' method='POST'>") "<form action='state/change' method='POST'>"))
write("<input type='hidden' name='api_password' value='{}' />".format(self.server.api_password))
write("<input type='hidden' name='api_password' value='{}' />".
format(self.server.api_password))
write("<select name='category'>") write("<select name='category'>")
for category in categories: for category in categories:
@ -135,22 +151,26 @@ class RequestHandler(BaseHTTPRequestHandler):
write("</select>") write("</select>")
write("<input name='new_state' />") write(("<input name='new_state' />"
write("<input type='submit' value='set state' />") "<input type='submit' value='set state' />"
write("</form>") "</form>"))
# Describe event bus: # Describe event bus:
for category in self.server.eventbus.listeners: for category in self.server.eventbus.listeners:
write("Event {}: {} listeners<br />".format(category, len(self.server.eventbus.listeners[category]))) write("Event {}: {} listeners<br />".format(category,
len(self.server.eventbus.listeners[category])))
# Form to allow firing events # Form to allow firing events
write("<br /><br />") write(("<br />"
write("<form action='event/fire' method='POST'>") "<form action='event/fire' method='POST'>"))
write("<input type='hidden' name='api_password' value='{}' />".format(self.server.api_password))
write("Event name: <input name='event_name' /><br />") write("<input type='hidden' name='api_password' value='{}' />".
write("Event data (json): <input name='event_data' /><br />") format(self.server.api_password))
write("<input type='submit' value='fire event' />")
write("</form>") write(("Event name: <input name='event_name' /><br />"
"Event data (json): <input name='event_data' /><br />"
"<input type='submit' value='fire event' />"
"</form>"))
write("</body></html>") write("</body></html>")
@ -158,8 +178,7 @@ class RequestHandler(BaseHTTPRequestHandler):
else: else:
self.send_response(404) self.send_response(404)
def do_POST(self): # pylint: disable=invalid-name, too-many-branches
def do_POST(self):
""" Handle incoming POST requests. """ """ Handle incoming POST requests. """
length = int(self.headers['Content-Length']) length = int(self.headers['Content-Length'])
@ -181,46 +200,63 @@ class RequestHandler(BaseHTTPRequestHandler):
try: try:
changed = [] changed = []
for category, new_state in zip(post_data['category'], post_data['new_state']): for category, new_state in zip(post_data['category'],
post_data['new_state']):
self.server.statemachine.set_state(category, new_state) self.server.statemachine.set_state(category, new_state)
changed.append("{}={}".format(category, new_state)) changed.append("{}={}".format(category, new_state))
self._message(use_json, "States changed: {}".format( ", ".join(changed) ) ) self._message(use_json, "States changed: {}".
format( ", ".join(changed) ) )
except KeyError: except KeyError:
# If category or new_state don't exist in post data # If category or new_state don't exist in post data
self._message(use_json, "Invalid state received.", MESSAGE_STATUS_ERROR) self._message(use_json, "Invalid state received.",
MESSAGE_STATUS_ERROR)
# Action to fire an event # Action to fire an event
elif action == "event/fire": elif action == "event/fire":
if self._verify_api_password(given_api_password, use_json): if self._verify_api_password(given_api_password, use_json):
try: try:
event_name = post_data['event_name'][0] event_name = post_data['event_name'][0]
event_data = None if 'event_data' not in post_data or post_data['event_data'][0] == "" else json.loads(post_data['event_data'][0])
if (not 'event_data' in post_data or
post_data['event_data'][0] == ""):
event_data = None
else:
event_data = json.loads(post_data['event_data'][0])
self.server.eventbus.fire(Event(event_name, event_data)) self.server.eventbus.fire(Event(event_name, event_data))
self._message(use_json, "Event {} fired.".format(event_name)) self._message(use_json, "Event {} fired.".
format(event_name))
except ValueError: except ValueError:
# If JSON decode error # If JSON decode error
self._message(use_json, "Invalid event received (1).", MESSAGE_STATUS_ERROR) self._message(use_json, "Invalid event received (1).",
MESSAGE_STATUS_ERROR)
except KeyError: except KeyError:
# If "event_name" not in post_data # If "event_name" not in post_data
self._message(use_json, "Invalid event received (2).", MESSAGE_STATUS_ERROR) self._message(use_json, "Invalid event received (2).",
MESSAGE_STATUS_ERROR)
else: else:
self.send_response(404) self.send_response(404)
def _verify_api_password(self, api_password, use_json): def _verify_api_password(self, api_password, use_json):
""" Helper method to verify the API password and take action if incorrect. """ """ Helper method to verify the API password
and take action if incorrect. """
if api_password == self.server.api_password: if api_password == self.server.api_password:
return True return True
elif use_json: elif use_json:
self._message(True, "API password missing or incorrect.", MESSAGE_STATUS_UNAUTHORIZED) self._message(True, "API password missing or incorrect.",
MESSAGE_STATUS_UNAUTHORIZED)
else: else:
self.send_response(200) self.send_response(200)
@ -229,14 +265,14 @@ class RequestHandler(BaseHTTPRequestHandler):
write = lambda txt: self.wfile.write(txt+"\n") write = lambda txt: self.wfile.write(txt+"\n")
write("<html>") write(("<html>"
write("<head><title>Home Assistant</title></head>") "<head><title>Home Assistant</title></head>"
write("<body>") "<body>"
write("<form action='/' method='GET'>") "<form action='/' method='GET'>"
write("API password: <input name='api_password' />") "API password: <input name='api_password' />"
write("<input type='submit' value='submit' />") "<input type='submit' value='submit' />"
write("</form>") "</form>"
write("</body></html>") "</body></html>"))
return False return False
@ -250,7 +286,8 @@ class RequestHandler(BaseHTTPRequestHandler):
else: else:
self.server.logger.error(log_message) self.server.logger.error(log_message)
response_code = 401 if status == MESSAGE_STATUS_UNAUTHORIZED else 400 response_code = (401 if status == MESSAGE_STATUS_UNAUTHORIZED
else 400)
if use_json: if use_json:
self.send_response(response_code) self.send_response(response_code)
@ -263,5 +300,6 @@ class RequestHandler(BaseHTTPRequestHandler):
self.server.flash_message = message self.server.flash_message = message
self.send_response(301) self.send_response(301)
self.send_header("Location", "/?api_password={}".format(self.server.api_password)) self.send_header("Location", "/?api_password={}".
format(self.server.api_password))
self.end_headers() self.end_headers()

View File

@ -46,11 +46,12 @@ KNOWN_DEVICES_FILE = "known_devices.csv"
def track_sun(eventbus, statemachine, latitude, longitude): def track_sun(eventbus, statemachine, latitude, longitude):
""" Tracks the state of the sun. """ """ Tracks the state of the sun. """
sun = ephem.Sun() sun = ephem.Sun() # pylint: disable=no-member
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def update_sun_state(now): def update_sun_state(now): # pylint: disable=unused-argument
""" Method to update the current state of the sun and time the next update. """ """ Method to update the current state of the sun and
set time of next setting and rising. """
observer = ephem.Observer() observer = ephem.Observer()
observer.lat = latitude observer.lat = latitude
observer.long = longitude observer.long = longitude
@ -66,14 +67,18 @@ def track_sun(eventbus, statemachine, latitude, longitude):
new_state = SUN_STATE_BELOW_HORIZON new_state = SUN_STATE_BELOW_HORIZON
next_change = next_rising next_change = next_rising
logger.info("Sun:{}. Next change: {}".format(new_state, next_change.strftime("%H:%M"))) logger.info("Sun:{}. Next change: {}".
format(new_state, next_change.strftime("%H:%M")))
statemachine.set_state(STATE_CATEGORY_SUN, new_state) statemachine.set_state(STATE_CATEGORY_SUN, new_state)
statemachine.set_state(STATE_CATEGORY_NEXT_SUN_RISING, next_rising.isoformat()) statemachine.set_state(STATE_CATEGORY_NEXT_SUN_RISING,
statemachine.set_state(STATE_CATEGORY_NEXT_SUN_SETTING, next_setting.isoformat()) next_rising.isoformat())
statemachine.set_state(STATE_CATEGORY_NEXT_SUN_SETTING,
next_setting.isoformat())
# +10 seconds to be sure that the change has occured # +10 seconds to be sure that the change has occured
track_time_change(eventbus, update_sun_state, point_in_time=next_change + timedelta(seconds=10)) track_time_change(eventbus, update_sun_state,
point_in_time=next_change + timedelta(seconds=10))
update_sun_state(None) update_sun_state(None)
@ -110,47 +115,62 @@ class DeviceTracker(object):
row['track'] = True if row['track'] == '1' else False row['track'] = True if row['track'] == '1' else False
self.known_devices[device] = row
# If we track this device setup tracking variables # If we track this device setup tracking variables
if row['track']: if row['track']:
self.known_devices[device]['last_seen'] = default_last_seen row['last_seen'] = default_last_seen
# Make sure that each device is mapped to a unique category name # Make sure that each device is mapped
name = row['name'] if row['name'] else "unnamed_device" # to a unique category name
name = row['name']
if not name:
name = "unnamed_device"
tries = 0 tries = 0
suffix = ""
while True: while True:
tries += 1 tries += 1
category = STATE_CATEGORY_DEVICE_FORMAT.format(name if tries == 1 else "{}_{}".format(name, tries)) if tries > 1:
suffix = "_{}".format(tries)
category = STATE_CATEGORY_DEVICE_FORMAT.format(
name + suffix)
if category not in used_categories: if category not in used_categories:
break break
self.known_devices[device]['category'] = category row['category'] = category
used_categories.append(category) used_categories.append(category)
self.known_devices[device] = row
except KeyError: except KeyError:
self.invalid_known_devices_file = False self.invalid_known_devices_file = False
self.logger.warning("Invalid {} found. We won't update it with new found devices.".format(KNOWN_DEVICES_FILE)) self.logger.warning(("Invalid {} found. "
"We won't update it with new found devices.").
format(KNOWN_DEVICES_FILE))
if len(self.device_state_categories()) == 0: if len(self.device_state_categories()) == 0:
self.logger.warning("No devices to track. Please update {}.".format(KNOWN_DEVICES_FILE)) self.logger.warning("No devices to track. Please update {}.".
format(KNOWN_DEVICES_FILE))
track_time_change(eventbus, lambda time: self.update_devices(device_scanner.scan_devices())) track_time_change(eventbus,
lambda time: self.update_devices(device_scanner.scan_devices()))
def device_state_categories(self): def device_state_categories(self):
""" Returns a list containing all categories that are maintained for devices. """ """ Returns a list containing all categories
return [self.known_devices[device]['category'] for device in self.known_devices if self.known_devices[device]['track']] that are maintained for devices. """
return [self.known_devices[device]['category'] for device
in self.known_devices if self.known_devices[device]['track']]
def update_devices(self, found_devices): def update_devices(self, found_devices):
""" Keep track of devices that are home, all that are not will be marked not home. """ """ Update device states based on the found devices. """
self.lock.acquire() self.lock.acquire()
temp_tracking_devices = [device for device in self.known_devices if self.known_devices[device]['track']] temp_tracking_devices = [device for device in self.known_devices
if self.known_devices[device]['track']]
for device in found_devices: for device in found_devices:
# Are we tracking this device? # Are we tracking this device?
@ -158,65 +178,92 @@ class DeviceTracker(object):
temp_tracking_devices.remove(device) temp_tracking_devices.remove(device)
self.known_devices[device]['last_seen'] = datetime.now() self.known_devices[device]['last_seen'] = datetime.now()
self.statemachine.set_state(self.known_devices[device]['category'], DEVICE_STATE_HOME)
self.statemachine.set_state(
self.known_devices[device]['category'], DEVICE_STATE_HOME)
# For all devices we did not find, set state to NH # For all devices we did not find, set state to NH
# But only if they have been gone for longer then the error time span # But only if they have been gone for longer then the error time span
# Because we do not want to have stuff happening when the device does # Because we do not want to have stuff happening when the device does
# not show up for 1 scan beacuse of reboot etc # not show up for 1 scan beacuse of reboot etc
for device in temp_tracking_devices: for device in temp_tracking_devices:
if datetime.now() - self.known_devices[device]['last_seen'] > TIME_SPAN_FOR_ERROR_IN_SCANNING: if (datetime.now() - self.known_devices[device]['last_seen'] >
self.statemachine.set_state(self.known_devices[device]['category'], DEVICE_STATE_NOT_HOME) TIME_SPAN_FOR_ERROR_IN_SCANNING):
self.statemachine.set_state(
self.known_devices[device]['category'],
DEVICE_STATE_NOT_HOME)
# Get the currently used statuses # Get the currently used statuses
states_of_devices = [self.statemachine.get_state(category).state for category in self.device_state_categories()] states_of_devices = [self.statemachine.get_state(category).state
for category in self.device_state_categories()]
# Update the all devices category # Update the all devices category
all_devices_state = DEVICE_STATE_HOME if DEVICE_STATE_HOME in states_of_devices else DEVICE_STATE_NOT_HOME all_devices_state = (DEVICE_STATE_HOME if DEVICE_STATE_HOME
in states_of_devices else DEVICE_STATE_NOT_HOME)
self.statemachine.set_state(STATE_CATEGORY_ALL_DEVICES, all_devices_state) self.statemachine.set_state(STATE_CATEGORY_ALL_DEVICES,
all_devices_state)
# If we come along any unknown devices we will write them to the known devices file # If we come along any unknown devices we will write them to the
# but only if we did not encounter an invalid known devices file # known devices file but only if we did not encounter an invalid
# known devices file
if not self.invalid_known_devices_file: if not self.invalid_known_devices_file:
unknown_devices = [device for device in found_devices if device not in self.known_devices]
unknown_devices = [device for device in found_devices
if device not in self.known_devices]
if len(unknown_devices) > 0: if len(unknown_devices) > 0:
try: try:
# If file does not exist we will write the header too # If file does not exist we will write the header too
should_write_header = not os.path.isfile(KNOWN_DEVICES_FILE) is_new_file = not os.path.isfile(KNOWN_DEVICES_FILE)
with open(KNOWN_DEVICES_FILE, 'a') as outp: with open(KNOWN_DEVICES_FILE, 'a') as outp:
self.logger.info("DeviceTracker:Found {} new devices, updating {}".format(len(unknown_devices), KNOWN_DEVICES_FILE)) self.logger.info(("DeviceTracker:Found {} new devices,"
" updating {}").format(len(unknown_devices),
KNOWN_DEVICES_FILE))
writer = csv.writer(outp) writer = csv.writer(outp)
if should_write_header: if is_new_file:
writer.writerow(("device", "name", "track")) writer.writerow(("device", "name", "track"))
for device in unknown_devices: for device in unknown_devices:
# See if the device scanner knows the name # See if the device scanner knows the name
temp_name = self.device_scanner.get_device_name(device) temp_name = self.device_scanner.get_device_name(
device)
name = temp_name if temp_name else "unknown_device" name = temp_name if temp_name else "unknown_device"
writer.writerow((device, name, 0)) writer.writerow((device, name, 0))
self.known_devices[device] = {'name':name, 'track': False} self.known_devices[device] = {'name':name,
'track': False}
except IOError: except IOError:
self.logger.exception("DeviceTracker:Error updating {} with {} new devices".format(KNOWN_DEVICES_FILE, len(unknown_devices))) self.logger.exception(("DeviceTracker:Error updating {}"
"with {} new devices").format(KNOWN_DEVICES_FILE,
len(unknown_devices)))
self.lock.release() self.lock.release()
class TomatoDeviceScanner(object): class TomatoDeviceScanner(object):
""" This class queries a wireless router running Tomato firmware for connected devices. """ This class queries a wireless router running Tomato firmware
for connected devices.
A description of the Tomato API can be found on A description of the Tomato API can be found on
http://paulusschoutsen.nl/blog/2013/10/tomato-api-documentation/ """ http://paulusschoutsen.nl/blog/2013/10/tomato-api-documentation/
"""
def __init__(self, host, username, password, http_id): def __init__(self, host, username, password, http_id):
self.req = requests.Request('POST', 'http://{}/update.cgi'.format(host), self.req = requests.Request('POST',
data={'_http_id':http_id, 'exec':'devlist'}, 'http://{}/update.cgi'.format(host),
auth=requests.auth.HTTPBasicAuth(username, password)).prepare() data={'_http_id': http_id,
'exec': 'devlist'},
auth=requests.auth.HTTPBasicAuth(
username, password)).prepare()
self.parse_api_pattern = re.compile(r"(?P<param>\w*) = (?P<value>.*);")
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.lock = threading.Lock() self.lock = threading.Lock()
@ -225,7 +272,8 @@ class TomatoDeviceScanner(object):
self.last_results = {"wldev": [], "dhcpd_lease": []} self.last_results = {"wldev": [], "dhcpd_lease": []}
def scan_devices(self): def scan_devices(self):
""" Scans for new devices and returns a list containing found device ids. """ """ Scans for new devices and return a
list containing found device ids. """
self._update_tomato_info() self._update_tomato_info()
@ -238,9 +286,14 @@ class TomatoDeviceScanner(object):
if not self.date_updated: if not self.date_updated:
self._update_tomato_info() self._update_tomato_info()
filter_named = [item[0] for item in self.last_results['dhcpd_lease'] if item[2] == device] filter_named = [item[0] for item in self.last_results['dhcpd_lease']
if item[2] == device]
return None if len(filter_named) == 0 or filter_named[0] == "" else filter_named[0]
if len(filter_named) == 0 or filter_named[0] == "":
return None
else:
return filter_named[0]
def _update_tomato_info(self): def _update_tomato_info(self):
""" Ensures the information from the Tomato router is up to date. """ Ensures the information from the Tomato router is up to date.
@ -249,37 +302,48 @@ class TomatoDeviceScanner(object):
self.lock.acquire() self.lock.acquire()
# if date_updated is None or the date is too old we scan for new data # if date_updated is None or the date is too old we scan for new data
if not self.date_updated or datetime.now() - self.date_updated > TOMATO_MIN_TIME_BETWEEN_SCANS: if (not self.date_updated or datetime.now() - self.date_updated >
TOMATO_MIN_TIME_BETWEEN_SCANS):
self.logger.info("Tomato:Scanning") self.logger.info("Tomato:Scanning")
try: try:
response = requests.Session().send(self.req) response = requests.Session().send(self.req)
# Calling and parsing the Tomato api here. We only need the wldev and dhcpd_lease values. # Calling and parsing the Tomato api here. We only need the
# See http://paulusschoutsen.nl/blog/2013/10/tomato-api-documentation/ for what's going on here. # wldev and dhcpd_lease values. For API description see:
# http://paulusschoutsen.nl/blog/2013/10/tomato-api-documentation/
if response.status_code == 200: if response.status_code == 200:
self.last_results = {param: json.loads(value.replace("'",'"'))
for param, value in re.findall(r"(?P<param>\w*) = (?P<value>.*);", response.text) for param, value in self.parse_api_pattern.findall(
if param in ["wldev","dhcpd_lease"]} response.text):
if param == 'wldev' or param == 'dhcpd_lease':
self.last_results[param] = json.loads(value.
replace("'",'"'))
self.date_updated = datetime.now() self.date_updated = datetime.now()
elif response.status_code == 401: elif response.status_code == 401:
# Authentication error # Authentication error
self.logger.exception("Tomato:Failed to authenticate, please check your username and password") self.logger.exception(("Tomato:Failed to authenticate, "
"please check your username and password"))
except requests.ConnectionError: except requests.ConnectionError:
# We get this if we could not connect to the router or an invalid http_id was supplied # We get this if we could not connect to the router or
self.logger.exception("Tomato:Failed to connect to the router or invalid http_id supplied") # an invalid http_id was supplied
self.logger.exception(("Tomato:Failed to connect to the router"
"or invalid http_id supplied"))
except ValueError: except ValueError:
# If json decoder could not parse the response # If json decoder could not parse the response
self.logger.exception("Tomato:Failed to parse response from router") self.logger.exception(("Tomato:Failed to parse response "
"from router"))
finally: finally:
self.lock.release() self.lock.release()
else: else:
# We acquired the lock before the IF check, release it before we return True # We acquired the lock before the IF check,
# release it before we return True
self.lock.release() self.lock.release()

View File

@ -0,0 +1,5 @@
"""
Not all external Git repositories that we depend on are
available as a package for pip. That is why we include
them here.
"""

@ -1 +1 @@
Subproject commit 6b8999574c8f70cb28686ef8d19f1e4bf0c8c056 Subproject commit 22af2589b840991220788b5e93921b89433cd02e

View File

@ -19,39 +19,13 @@ API_PASSWORD = "test1234"
HTTP_BASE_URL = "http://127.0.0.1:{}".format(SERVER_PORT) HTTP_BASE_URL = "http://127.0.0.1:{}".format(SERVER_PORT)
# pylint: disable=too-many-public-methods
class HomeAssistantTestCase(unittest.TestCase): class TestHTTPInterface(unittest.TestCase):
""" Base class for Home Assistant test cases. """
@classmethod
def setUpClass(cls):
cls.eventbus = EventBus()
cls.statemachine = StateMachine(cls.eventbus)
cls.init_ha = False
def start_ha(self):
""" Classes will have to call this from setUp()
after initializing their components. """
cls.eventbus.fire(Event(EVENT_START))
# Give objects time to startup
time.sleep(1)
cls.start_ha = start_ha
@classmethod
def tearDownClass(cls):
cls.eventbus.fire(Event(EVENT_SHUTDOWN))
time.sleep(1)
class TestHTTPInterface(HomeAssistantTestCase):
""" Test the HTTP debug interface and API. """ """ Test the HTTP debug interface and API. """
HTTP_init = False HTTP_init = False
def setUp(self): def setUp(self): # pylint: disable=invalid-name
""" Initialize the HTTP interface if not started yet. """ """ Initialize the HTTP interface if not started yet. """
if not TestHTTPInterface.HTTP_init: if not TestHTTPInterface.HTTP_init:
TestHTTPInterface.HTTP_init = True TestHTTPInterface.HTTP_init = True
@ -60,31 +34,52 @@ class TestHTTPInterface(HomeAssistantTestCase):
self.statemachine.set_state("test", "INIT_STATE") self.statemachine.set_state("test", "INIT_STATE")
self.start_ha() self.eventbus.fire(Event(EVENT_START))
# Give objects time to startup
time.sleep(1)
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
""" things to be run when tests are started. """
cls.eventbus = EventBus()
cls.statemachine = StateMachine(cls.eventbus)
@classmethod
def tearDownClass(cls): # pylint: disable=invalid-name
""" things to be run when tests are done. """
cls.eventbus.fire(Event(EVENT_SHUTDOWN))
time.sleep(1)
def test_debug_interface(self): def test_debug_interface(self):
""" Test if we can login by comparing not logged in screen to logged in screen. """ """ Test if we can login by comparing not logged in screen to
logged in screen. """
self.assertNotEqual(requests.get(HTTP_BASE_URL).text, self.assertNotEqual(requests.get(HTTP_BASE_URL).text,
requests.get("{}/?api_password={}".format(HTTP_BASE_URL, API_PASSWORD)).text) requests.get("{}/?api_password={}".format(
HTTP_BASE_URL, API_PASSWORD)).text)
def test_debug_state_change(self): def test_debug_state_change(self):
""" Test if the debug interface allows us to change a state. """ """ Test if the debug interface allows us to change a state. """
requests.post("{}/state/change".format(HTTP_BASE_URL), data={"category":"test", requests.post("{}/state/change".format(HTTP_BASE_URL),
"new_state":"debug_state_change", data={"category":"test",
"api_password":API_PASSWORD}) "new_state":"debug_state_change",
"api_password":API_PASSWORD})
self.assertEqual(self.statemachine.get_state("test").state, "debug_state_change") self.assertEqual(self.statemachine.get_state("test").state,
"debug_state_change")
def test_api_password(self): def test_api_password(self):
""" Test if we get access denied if we omit or provide a wrong api password. """ """ Test if we get access denied if we omit or provide
a wrong api password. """
req = requests.post("{}/api/state/change".format(HTTP_BASE_URL)) req = requests.post("{}/api/state/change".format(HTTP_BASE_URL))
self.assertEqual(req.status_code, 401) self.assertEqual(req.status_code, 401)
req = requests.post("{}/api/state/change".format(HTTP_BASE_URL, data={"api_password":"not the password"})) req = requests.post("{}/api/state/change".format(HTTP_BASE_URL,
data={"api_password":"not the password"}))
self.assertEqual(req.status_code, 401) self.assertEqual(req.status_code, 401)
@ -94,11 +89,13 @@ class TestHTTPInterface(HomeAssistantTestCase):
self.statemachine.set_state("test", "not_to_be_set_state") self.statemachine.set_state("test", "not_to_be_set_state")
requests.post("{}/api/state/change".format(HTTP_BASE_URL), data={"category":"test", requests.post("{}/api/state/change".format(HTTP_BASE_URL),
"new_state":"debug_state_change2", data={"category":"test",
"api_password":API_PASSWORD}) "new_state":"debug_state_change2",
"api_password":API_PASSWORD})
self.assertEqual(self.statemachine.get_state("test").state, "debug_state_change2") self.assertEqual(self.statemachine.get_state("test").state,
"debug_state_change2")
def test_api_multiple_state_change(self): def test_api_multiple_state_change(self):
""" Test if we can change multiple states in 1 request. """ """ Test if we can change multiple states in 1 request. """
@ -106,36 +103,49 @@ class TestHTTPInterface(HomeAssistantTestCase):
self.statemachine.set_state("test", "not_to_be_set_state") self.statemachine.set_state("test", "not_to_be_set_state")
self.statemachine.set_state("test2", "not_to_be_set_state") self.statemachine.set_state("test2", "not_to_be_set_state")
requests.post("{}/api/state/change".format(HTTP_BASE_URL), data={"category": ["test", "test2"], requests.post("{}/api/state/change".format(HTTP_BASE_URL),
"new_state": ["test_state_1", "test_state_2"], data={"category": ["test", "test2"],
"api_password":API_PASSWORD}) "new_state": ["test_state_1", "test_state_2"],
"api_password":API_PASSWORD})
self.assertEqual(self.statemachine.get_state("test").state, "test_state_1") self.assertEqual(self.statemachine.get_state("test").state,
self.assertEqual(self.statemachine.get_state("test2").state, "test_state_2") "test_state_1")
self.assertEqual(self.statemachine.get_state("test2").state,
"test_state_2")
# pylint: disable=invalid-name
def test_api_state_change_of_non_existing_category(self): def test_api_state_change_of_non_existing_category(self):
""" Test if the API allows us to change a state of a non existing category. """ """ Test if the API allows us to change a state of
a non existing category. """
req = requests.post("{}/api/state/change".format(HTTP_BASE_URL), data={"category":"test_category_that_does_not_exist", new_state = "debug_state_change"
"new_state":"debug_state_change",
"api_password":API_PASSWORD}) req = requests.post("{}/api/state/change".format(HTTP_BASE_URL),
data={"category":"test_category_that_does_not_exist",
"new_state":new_state,
"api_password":API_PASSWORD})
cur_state = (self.statemachine.
get_state("test_category_that_does_not_exist").state)
self.assertEqual(req.status_code, 200) self.assertEqual(req.status_code, 200)
self.assertEqual(self.statemachine.get_state("test_category_that_does_not_exist").state, "debug_state_change") self.assertEqual(cur_state, new_state)
# pylint: disable=invalid-name
def test_api_fire_event_with_no_data(self): def test_api_fire_event_with_no_data(self):
""" Test if the API allows us to fire an event. """ """ Test if the API allows us to fire an event. """
test_value = [] test_value = []
def listener(event): def listener(event): # pylint: disable=unused-argument
""" Helper method that will verify our event got called. """ """ Helper method that will verify our event got called. """
test_value.append(1) test_value.append(1)
self.eventbus.listen("test_event_no_data", listener) self.eventbus.listen("test_event_no_data", listener)
requests.post("{}/api/event/fire".format(HTTP_BASE_URL), data={"event_name":"test_event_no_data", requests.post("{}/api/event/fire".format(HTTP_BASE_URL),
"event_data":"", data={"event_name":"test_event_no_data",
"api_password":API_PASSWORD}) "event_data":"",
"api_password":API_PASSWORD})
# Allow the event to take place # Allow the event to take place
time.sleep(1) time.sleep(1)
@ -143,11 +153,12 @@ class TestHTTPInterface(HomeAssistantTestCase):
self.assertEqual(len(test_value), 1) self.assertEqual(len(test_value), 1)
# pylint: disable=invalid-name
def test_api_fire_event_with_data(self): def test_api_fire_event_with_data(self):
""" Test if the API allows us to fire an event. """ """ Test if the API allows us to fire an event. """
test_value = [] test_value = []
def listener(event): def listener(event): # pylint: disable=unused-argument
""" Helper method that will verify that our event got called and """ Helper method that will verify that our event got called and
that test if our data came through. """ that test if our data came through. """
if "test" in event.data: if "test" in event.data:
@ -155,9 +166,10 @@ class TestHTTPInterface(HomeAssistantTestCase):
self.eventbus.listen("test_event_with_data", listener) self.eventbus.listen("test_event_with_data", listener)
requests.post("{}/api/event/fire".format(HTTP_BASE_URL), data={"event_name":"test_event_with_data", requests.post("{}/api/event/fire".format(HTTP_BASE_URL),
"event_data":'{"test": 1}', data={"event_name":"test_event_with_data",
"api_password":API_PASSWORD}) "event_data":'{"test": 1}',
"api_password":API_PASSWORD})
# Allow the event to take place # Allow the event to take place
time.sleep(1) time.sleep(1)
@ -165,6 +177,7 @@ class TestHTTPInterface(HomeAssistantTestCase):
self.assertEqual(len(test_value), 1) self.assertEqual(len(test_value), 1)
# pylint: disable=invalid-name
def test_api_fire_event_with_no_params(self): def test_api_fire_event_with_no_params(self):
""" Test how the API respsonds when we specify no event attributes. """ """ Test how the API respsonds when we specify no event attributes. """
test_value = [] test_value = []
@ -177,7 +190,8 @@ class TestHTTPInterface(HomeAssistantTestCase):
self.eventbus.listen("test_event_with_data", listener) self.eventbus.listen("test_event_with_data", listener)
requests.post("{}/api/event/fire".format(HTTP_BASE_URL), data={"api_password":API_PASSWORD}) requests.post("{}/api/event/fire".format(HTTP_BASE_URL),
data={"api_password":API_PASSWORD})
# Allow the event to take place # Allow the event to take place
time.sleep(1) time.sleep(1)
@ -185,19 +199,21 @@ class TestHTTPInterface(HomeAssistantTestCase):
self.assertEqual(len(test_value), 0) self.assertEqual(len(test_value), 0)
# pylint: disable=invalid-name
def test_api_fire_event_with_invalid_json(self): def test_api_fire_event_with_invalid_json(self):
""" Test if the API allows us to fire an event. """ """ Test if the API allows us to fire an event. """
test_value = [] test_value = []
def listener(event): def listener(event): # pylint: disable=unused-argument
""" Helper method that will verify our event got called. """ """ Helper method that will verify our event got called. """
test_value.append(1) test_value.append(1)
self.eventbus.listen("test_event_with_bad_data", listener) self.eventbus.listen("test_event_with_bad_data", listener)
req = requests.post("{}/api/event/fire".format(HTTP_BASE_URL), data={"event_name":"test_event_with_bad_data", req = requests.post("{}/api/event/fire".format(HTTP_BASE_URL),
"event_data":'not json', data={"event_name":"test_event_with_bad_data",
"api_password":API_PASSWORD}) "event_data":'not json',
"api_password":API_PASSWORD})
# It shouldn't but if it fires, allow the event to take place # It shouldn't but if it fires, allow the event to take place