Merge pull request #679 from balloob/bugfixes

Bugfixes
This commit is contained in:
Paulus Schoutsen 2015-11-28 23:22:33 -08:00
commit 45bd371cbf
26 changed files with 716 additions and 711 deletions

View File

@ -18,10 +18,10 @@ from homeassistant.bootstrap import ERROR_LOG_FILENAME
from homeassistant.const import ( from homeassistant.const import (
URL_API, URL_API_STATES, URL_API_EVENTS, URL_API_SERVICES, URL_API_STREAM, URL_API, URL_API_STATES, URL_API_EVENTS, URL_API_SERVICES, URL_API_STREAM,
URL_API_EVENT_FORWARD, URL_API_STATES_ENTITY, URL_API_COMPONENTS, URL_API_EVENT_FORWARD, URL_API_STATES_ENTITY, URL_API_COMPONENTS,
URL_API_CONFIG, URL_API_BOOTSTRAP, URL_API_ERROR_LOG, URL_API_CONFIG, URL_API_BOOTSTRAP, URL_API_ERROR_LOG, URL_API_LOG_OUT,
EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, MATCH_ALL,
HTTP_OK, HTTP_CREATED, HTTP_BAD_REQUEST, HTTP_NOT_FOUND, HTTP_OK, HTTP_CREATED, HTTP_BAD_REQUEST, HTTP_NOT_FOUND,
HTTP_UNPROCESSABLE_ENTITY, CONTENT_TYPE_TEXT_PLAIN) HTTP_UNPROCESSABLE_ENTITY)
DOMAIN = 'api' DOMAIN = 'api'
@ -36,10 +36,6 @@ _LOGGER = logging.getLogger(__name__)
def setup(hass, config): def setup(hass, config):
""" Register the API with the HTTP interface. """ """ Register the API with the HTTP interface. """
if 'http' not in hass.config.components:
_LOGGER.error('Dependency http is not loaded')
return False
# /api - for validation purposes # /api - for validation purposes
hass.http.register_path('GET', URL_API, _handle_get_api) hass.http.register_path('GET', URL_API, _handle_get_api)
@ -93,6 +89,8 @@ def setup(hass, config):
hass.http.register_path('GET', URL_API_ERROR_LOG, hass.http.register_path('GET', URL_API_ERROR_LOG,
_handle_get_api_error_log) _handle_get_api_error_log)
hass.http.register_path('POST', URL_API_LOG_OUT, _handle_post_api_log_out)
return True return True
@ -108,6 +106,7 @@ def _handle_get_api_stream(handler, path_match, data):
wfile = handler.wfile wfile = handler.wfile
write_lock = threading.Lock() write_lock = threading.Lock()
block = threading.Event() block = threading.Event()
session_id = None
restrict = data.get('restrict') restrict = data.get('restrict')
if restrict: if restrict:
@ -121,6 +120,7 @@ def _handle_get_api_stream(handler, path_match, data):
try: try:
wfile.write(msg.encode("UTF-8")) wfile.write(msg.encode("UTF-8"))
wfile.flush() wfile.flush()
handler.server.sessions.extend_validation(session_id)
except IOError: except IOError:
block.set() block.set()
@ -140,6 +140,7 @@ def _handle_get_api_stream(handler, path_match, data):
handler.send_response(HTTP_OK) handler.send_response(HTTP_OK)
handler.send_header('Content-type', 'text/event-stream') handler.send_header('Content-type', 'text/event-stream')
session_id = handler.set_session_cookie_header()
handler.end_headers() handler.end_headers()
hass.bus.listen(MATCH_ALL, forward_events) hass.bus.listen(MATCH_ALL, forward_events)
@ -347,9 +348,15 @@ def _handle_get_api_components(handler, path_match, data):
def _handle_get_api_error_log(handler, path_match, data): def _handle_get_api_error_log(handler, path_match, data):
""" Returns the logged errors for this session. """ """ Returns the logged errors for this session. """
error_path = handler.server.hass.config.path(ERROR_LOG_FILENAME) handler.write_file(handler.server.hass.config.path(ERROR_LOG_FILENAME),
with open(error_path, 'rb') as error_log: False)
handler.write_file_pointer(CONTENT_TYPE_TEXT_PLAIN, error_log)
def _handle_post_api_log_out(handler, path_match, data):
""" Log user out. """
handler.send_response(HTTP_OK)
handler.destroy_session()
handler.end_headers()
def _services_json(hass): def _services_json(hass):

View File

@ -80,19 +80,21 @@ def setup(hass, config):
def _proxy_camera_image(handler, path_match, data): def _proxy_camera_image(handler, path_match, data):
""" Proxies the camera image via the HA server. """ """ Proxies the camera image via the HA server. """
entity_id = path_match.group(ATTR_ENTITY_ID) entity_id = path_match.group(ATTR_ENTITY_ID)
camera = component.entities.get(entity_id)
camera = None if camera is None:
if entity_id in component.entities.keys():
camera = component.entities[entity_id]
if camera:
response = camera.camera_image()
if response is not None:
handler.wfile.write(response)
else:
handler.send_response(HTTP_NOT_FOUND)
else:
handler.send_response(HTTP_NOT_FOUND) handler.send_response(HTTP_NOT_FOUND)
handler.end_headers()
return
response = camera.camera_image()
if response is None:
handler.send_response(HTTP_NOT_FOUND)
handler.end_headers()
return
handler.wfile.write(response)
hass.http.register_path( hass.http.register_path(
'GET', 'GET',
@ -108,12 +110,9 @@ def setup(hass, config):
stream even with only a still image URL available. stream even with only a still image URL available.
""" """
entity_id = path_match.group(ATTR_ENTITY_ID) entity_id = path_match.group(ATTR_ENTITY_ID)
camera = component.entities.get(entity_id)
camera = None if camera is None:
if entity_id in component.entities.keys():
camera = component.entities[entity_id]
if not camera:
handler.send_response(HTTP_NOT_FOUND) handler.send_response(HTTP_NOT_FOUND)
handler.end_headers() handler.end_headers()
return return
@ -131,7 +130,6 @@ def setup(hass, config):
# MJPEG_START_HEADER.format() # MJPEG_START_HEADER.format()
while True: while True:
img_bytes = camera.camera_image() img_bytes = camera.camera_image()
if img_bytes is None: if img_bytes is None:
continue continue
@ -148,12 +146,12 @@ def setup(hass, config):
handler.request.sendall( handler.request.sendall(
bytes('--jpgboundary\r\n', 'utf-8')) bytes('--jpgboundary\r\n', 'utf-8'))
time.sleep(0.5)
except (requests.RequestException, IOError): except (requests.RequestException, IOError):
camera.is_streaming = False camera.is_streaming = False
camera.update_ha_state() camera.update_ha_state()
camera.is_streaming = False
hass.http.register_path( hass.http.register_path(
'GET', 'GET',
re.compile( re.compile(

View File

@ -4,8 +4,8 @@ homeassistant.components.camera.demo
Demo platform that has a fake camera. Demo platform that has a fake camera.
""" """
import os import os
from random import randint
from homeassistant.components.camera import Camera from homeassistant.components.camera import Camera
import homeassistant.util.dt as dt_util
def setup_platform(hass, config, add_devices, discovery_info=None): def setup_platform(hass, config, add_devices, discovery_info=None):
@ -24,12 +24,12 @@ class DemoCamera(Camera):
def camera_image(self): def camera_image(self):
""" Return a faked still image response. """ """ Return a faked still image response. """
now = dt_util.utcnow()
image_path = os.path.join(os.path.dirname(__file__), image_path = os.path.join(os.path.dirname(__file__),
'demo_{}.png'.format(randint(1, 5))) 'demo_{}.jpg'.format(now.second % 4))
with open(image_path, 'rb') as file: with open(image_path, 'rb') as file:
output = file.read() return file.read()
return output
@property @property
def name(self): def name(self):

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.5 KiB

View File

@ -54,8 +54,7 @@ def setup(hass, config):
def _handle_get_root(handler, path_match, data): def _handle_get_root(handler, path_match, data):
""" Renders the debug interface. """ """ Renders the frontend. """
handler.send_response(HTTP_OK) handler.send_response(HTTP_OK)
handler.send_header('Content-type', 'text/html; charset=utf-8') handler.send_header('Content-type', 'text/html; charset=utf-8')
handler.end_headers() handler.end_headers()
@ -66,7 +65,7 @@ def _handle_get_root(handler, path_match, data):
app_url = "frontend-{}.html".format(version.VERSION) app_url = "frontend-{}.html".format(version.VERSION)
# auto login if no password was set, else check api_password param # auto login if no password was set, else check api_password param
auth = ('no_password_set' if handler.server.no_password_set auth = ('no_password_set' if handler.server.api_password is None
else data.get('api_password', '')) else data.get('api_password', ''))
with open(INDEX_PATH) as template_file: with open(INDEX_PATH) as template_file:

View File

@ -4,16 +4,13 @@
<meta charset="utf-8"> <meta charset="utf-8">
<title>Home Assistant</title> <title>Home Assistant</title>
<link rel='manifest' href='/static/manifest.json' /> <link rel='manifest' href='/static/manifest.json'>
<link rel='shortcut icon' href='/static/favicon.ico' /> <link rel='icon' href='/static/favicon.ico'>
<link rel='icon' type='image/png'
href='/static/favicon-192x192.png' sizes='192x192'>
<link rel='apple-touch-icon' sizes='180x180' <link rel='apple-touch-icon' sizes='180x180'
href='/static/favicon-apple-180x180.png'> href='/static/favicon-apple-180x180.png'>
<meta name='apple-mobile-web-app-capable' content='yes'> <meta name='apple-mobile-web-app-capable' content='yes'>
<meta name='mobile-web-app-capable' content='yes'> <meta name='mobile-web-app-capable' content='yes'>
<meta name='viewport' content='width=device-width, <meta name='viewport' content='width=device-width, user-scalable=no'>
user-scalable=no' />
<meta name='theme-color' content='#03a9f4'> <meta name='theme-color' content='#03a9f4'>
<style> <style>
#init { #init {
@ -26,24 +23,17 @@
justify-content: center; justify-content: center;
align-items: center; align-items: center;
text-align: center; text-align: center;
font-family: 'Roboto', 'Noto', sans-serif;
position: fixed; position: fixed;
top: 0; top: 0;
left: 0; left: 0;
right: 0; right: 0;
bottom: 0; bottom: 0;
} margin-bottom: 123px;
#init div {
line-height: 34px;
margin-bottom: 89px;
} }
</style> </style>
</head> </head>
<body fullbleed> <body fullbleed>
<div id='init'> <div id='init'><img src='/static/favicon-192x192.png' height='192'></div>
<img src='/static/splash.png' height='230' />
<div>Initializing</div>
</div>
<script src='/static/webcomponents-lite.min.js'></script> <script src='/static/webcomponents-lite.min.js'></script>
<link rel='import' href='/static/{{ app_url }}' /> <link rel='import' href='/static/{{ app_url }}' />
<home-assistant auth='{{ auth }}' icons='{{ icons }}'></home-assistant> <home-assistant auth='{{ auth }}' icons='{{ icons }}'></home-assistant>

View File

@ -1,2 +1,2 @@
""" DO NOT MODIFY. Auto-generated by build_frontend script """ """ DO NOT MODIFY. Auto-generated by build_frontend script """
VERSION = "c90d40a0240cc1feec791ee820d928b3" VERSION = "36df87bb6c219a2ee59adf416e3abdfa"

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

File diff suppressed because one or more lines are too long

@ -1 +1 @@
Subproject commit 62e494bd04509e8d9b73354b0e17d3381955e0c8 Subproject commit 33124030f6d119ad3a58cb520062f2aa58022c6d

View File

@ -3,12 +3,17 @@
"short_name": "Assistant", "short_name": "Assistant",
"start_url": "/", "start_url": "/",
"display": "standalone", "display": "standalone",
"theme_color": "#03A9F4",
"icons": [ "icons": [
{ {
"src": "\/static\/favicon-192x192.png", "src": "/static/favicon-192x192.png",
"sizes": "192x192", "sizes": "192x192",
"type": "image\/png", "type": "image/png",
"density": "4.0" },
{
"src": "/static/favicon-384x384.png",
"sizes": "384x384",
"type": "image/png",
} }
] ]
} }

Binary file not shown.

Before

Width:  |  Height:  |  Size: 51 KiB

File diff suppressed because one or more lines are too long

View File

@ -12,10 +12,7 @@ import logging
import time import time
import gzip import gzip
import os import os
import random
import string
from datetime import timedelta from datetime import timedelta
from homeassistant.util import Throttle
from http.server import SimpleHTTPRequestHandler, HTTPServer from http.server import SimpleHTTPRequestHandler, HTTPServer
from http import cookies from http import cookies
from socketserver import ThreadingMixIn from socketserver import ThreadingMixIn
@ -44,40 +41,30 @@ CONF_SESSIONS_ENABLED = "sessions_enabled"
DATA_API_PASSWORD = 'api_password' DATA_API_PASSWORD = 'api_password'
# Throttling time in seconds for expired sessions check # Throttling time in seconds for expired sessions check
MIN_SEC_SESSION_CLEARING = timedelta(seconds=20) SESSION_CLEAR_INTERVAL = timedelta(seconds=20)
SESSION_TIMEOUT_SECONDS = 1800 SESSION_TIMEOUT_SECONDS = 1800
SESSION_KEY = 'sessionId' SESSION_KEY = 'sessionId'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def setup(hass, config=None): def setup(hass, config):
""" Sets up the HTTP API and debug interface. """ """ Sets up the HTTP API and debug interface. """
if config is None or DOMAIN not in config: conf = config[DOMAIN]
config = {DOMAIN: {}}
api_password = util.convert(config[DOMAIN].get(CONF_API_PASSWORD), str) api_password = util.convert(conf.get(CONF_API_PASSWORD), str)
no_password_set = api_password is None
if no_password_set:
api_password = util.get_random_string()
# If no server host is given, accept all incoming requests # If no server host is given, accept all incoming requests
server_host = config[DOMAIN].get(CONF_SERVER_HOST, '0.0.0.0') server_host = conf.get(CONF_SERVER_HOST, '0.0.0.0')
server_port = conf.get(CONF_SERVER_PORT, SERVER_PORT)
server_port = config[DOMAIN].get(CONF_SERVER_PORT, SERVER_PORT) development = str(conf.get(CONF_DEVELOPMENT, "")) == "1"
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, True)
try: try:
server = HomeAssistantHTTPServer( server = HomeAssistantHTTPServer(
(server_host, server_port), RequestHandler, hass, api_password, (server_host, server_port), RequestHandler, hass, api_password,
development, no_password_set, sessions_enabled) development)
except OSError: except OSError:
# Happens if address already in use # If address already in use
_LOGGER.exception("Error setting up HTTP server") _LOGGER.exception("Error setting up HTTP server")
return False return False
@ -102,17 +89,15 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def __init__(self, server_address, request_handler_class, def __init__(self, server_address, request_handler_class,
hass, api_password, development, no_password_set, hass, api_password, development):
sessions_enabled):
super().__init__(server_address, request_handler_class) super().__init__(server_address, request_handler_class)
self.server_address = server_address self.server_address = server_address
self.hass = hass self.hass = hass
self.api_password = api_password self.api_password = api_password
self.development = development self.development = development
self.no_password_set = no_password_set
self.paths = [] self.paths = []
self.sessions = SessionStore(sessions_enabled) self.sessions = SessionStore()
# We will lazy init this one if needed # We will lazy init this one if needed
self.event_forwarder = None self.event_forwarder = None
@ -161,12 +146,13 @@ class RequestHandler(SimpleHTTPRequestHandler):
def __init__(self, req, client_addr, server): def __init__(self, req, client_addr, server):
""" Contructor, call the base constructor and set up session """ """ Contructor, call the base constructor and set up session """
self._session = None # Track if this was an authenticated request
self.authenticated = False
SimpleHTTPRequestHandler.__init__(self, req, client_addr, server) SimpleHTTPRequestHandler.__init__(self, req, client_addr, server)
def log_message(self, fmt, *arguments): def log_message(self, fmt, *arguments):
""" Redirect built-in log to HA logging """ """ Redirect built-in log to HA logging """
if self.server.no_password_set: if self.server.api_password is None:
_LOGGER.info(fmt, *arguments) _LOGGER.info(fmt, *arguments)
else: else:
_LOGGER.info( _LOGGER.info(
@ -201,18 +187,17 @@ class RequestHandler(SimpleHTTPRequestHandler):
"Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY) "Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY)
return return
self._session = self.get_session() if self.server.api_password is None:
if self.server.no_password_set: self.authenticated = True
api_password = self.server.api_password elif HTTP_HEADER_HA_AUTH in self.headers:
else:
api_password = self.headers.get(HTTP_HEADER_HA_AUTH) api_password = self.headers.get(HTTP_HEADER_HA_AUTH)
if not api_password and DATA_API_PASSWORD in data: if not api_password and DATA_API_PASSWORD in data:
api_password = data[DATA_API_PASSWORD] api_password = data[DATA_API_PASSWORD]
if not api_password and self._session is not None: self.authenticated = api_password == self.server.api_password
api_password = self._session.cookie_values.get( else:
CONF_API_PASSWORD) self.authenticated = self.verify_session()
if '_METHOD' in data: if '_METHOD' in data:
method = data.pop('_METHOD') method = data.pop('_METHOD')
@ -245,18 +230,13 @@ class RequestHandler(SimpleHTTPRequestHandler):
# Did we find a handler for the incoming request? # Did we find a handler for the incoming request?
if handle_request_method: if handle_request_method:
# For some calls we need a valid password # For some calls we need a valid password
if require_auth and api_password != self.server.api_password: if require_auth and not self.authenticated:
self.write_json_message( self.write_json_message(
"API password missing or incorrect.", HTTP_UNAUTHORIZED) "API password missing or incorrect.", HTTP_UNAUTHORIZED)
return
else: handle_request_method(self, path_match, data)
if self._session is None and require_auth:
self._session = self.server.sessions.create(
api_password)
handle_request_method(self, path_match, data)
elif path_matched_but_not_method: elif path_matched_but_not_method:
self.send_response(HTTP_METHOD_NOT_ALLOWED) self.send_response(HTTP_METHOD_NOT_ALLOWED)
@ -307,18 +287,19 @@ class RequestHandler(SimpleHTTPRequestHandler):
json.dumps(data, indent=4, sort_keys=True, json.dumps(data, indent=4, sort_keys=True,
cls=rem.JSONEncoder).encode("UTF-8")) cls=rem.JSONEncoder).encode("UTF-8"))
def write_file(self, path): def write_file(self, path, cache_headers=True):
""" Returns a file to the user. """ """ Returns a file to the user. """
try: try:
with open(path, 'rb') as inp: with open(path, 'rb') as inp:
self.write_file_pointer(self.guess_type(path), inp) self.write_file_pointer(self.guess_type(path), inp,
cache_headers)
except IOError: except IOError:
self.send_response(HTTP_NOT_FOUND) self.send_response(HTTP_NOT_FOUND)
self.end_headers() self.end_headers()
_LOGGER.exception("Unable to serve %s", path) _LOGGER.exception("Unable to serve %s", path)
def write_file_pointer(self, content_type, inp): def write_file_pointer(self, content_type, inp, cache_headers=True):
""" """
Helper function to write a file pointer to the user. Helper function to write a file pointer to the user.
Does not do error handling. Does not do error handling.
@ -328,7 +309,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
self.send_response(HTTP_OK) self.send_response(HTTP_OK)
self.send_header(HTTP_HEADER_CONTENT_TYPE, content_type) self.send_header(HTTP_HEADER_CONTENT_TYPE, content_type)
self.set_cache_header() if cache_headers:
self.set_cache_header()
self.set_session_cookie_header() self.set_session_cookie_header()
if do_gzip: if do_gzip:
@ -355,75 +337,81 @@ class RequestHandler(SimpleHTTPRequestHandler):
def set_cache_header(self): def set_cache_header(self):
""" Add cache headers if not in development """ """ Add cache headers if not in development """
if not self.server.development: if self.server.development:
# 1 year in seconds return
cache_time = 365 * 86400
self.send_header( # 1 year in seconds
HTTP_HEADER_CACHE_CONTROL, cache_time = 365 * 86400
"public, max-age={}".format(cache_time))
self.send_header( self.send_header(
HTTP_HEADER_EXPIRES, HTTP_HEADER_CACHE_CONTROL,
self.date_time_string(time.time()+cache_time)) "public, max-age={}".format(cache_time))
self.send_header(
HTTP_HEADER_EXPIRES,
self.date_time_string(time.time()+cache_time))
def set_session_cookie_header(self): def set_session_cookie_header(self):
""" Add the header for the session cookie """ """ Add the header for the session cookie and return session id. """
if self.server.sessions.enabled and self._session is not None: if not self.authenticated:
existing_sess_id = self.get_current_session_id() return
if existing_sess_id != self._session.session_id: session_id = self.get_cookie_session_id()
self.send_header(
'Set-Cookie',
SESSION_KEY+'='+self._session.session_id)
def get_session(self):
""" Get the requested session object from cookie value """
if self.server.sessions.enabled is not True:
return None
session_id = self.get_current_session_id()
if session_id is not None: if session_id is not None:
session = self.server.sessions.get(session_id) self.server.sessions.extend_validation(session_id)
if session is not None: return
session.reset_expiry()
return session
return None self.send_header(
'Set-Cookie',
'{}={}'.format(SESSION_KEY, self.server.sessions.create())
)
def get_current_session_id(self): return session_id
def verify_session(self):
""" Verify that we are in a valid session. """
return self.get_cookie_session_id() is not None
def get_cookie_session_id(self):
""" """
Extracts the current session id from the Extracts the current session id from the
cookie or returns None if not set cookie or returns None if not set or invalid
""" """
if 'Cookie' not in self.headers:
return None
cookie = cookies.SimpleCookie() cookie = cookies.SimpleCookie()
try:
cookie.load(self.headers["Cookie"])
except cookies.CookieError:
return None
if self.headers.get('Cookie', None) is not None: morsel = cookie.get(SESSION_KEY)
cookie.load(self.headers.get("Cookie"))
if cookie.get(SESSION_KEY, False): if morsel is None:
return cookie[SESSION_KEY].value return None
session_id = cookie[SESSION_KEY].value
if self.server.sessions.is_valid(session_id):
return session_id
return None return None
def destroy_session(self):
""" Destroys session. """
session_id = self.get_cookie_session_id()
class ServerSession: if session_id is None:
""" A very simple session class """ return
def __init__(self, session_id):
""" Set up the expiry time on creation """
self._expiry = 0
self.reset_expiry()
self.cookie_values = {}
self.session_id = session_id
def reset_expiry(self): self.send_header('Set-Cookie', '')
""" Resets the expiry based on current time """ self.server.sessions.destroy(session_id)
self._expiry = date_util.utcnow() + timedelta(
seconds=SESSION_TIMEOUT_SECONDS)
@property
def is_expired(self): def session_valid_time():
""" Return true if the session is expired based on the expiry time """ """ Time till when a session will be valid. """
return self._expiry < date_util.utcnow() return date_util.utcnow() + timedelta(seconds=SESSION_TIMEOUT_SECONDS)
class SessionStore(object): class SessionStore(object):
@ -431,47 +419,42 @@ class SessionStore(object):
def __init__(self, enabled=True): def __init__(self, enabled=True):
""" Set up the session store """ """ Set up the session store """
self._sessions = {} self._sessions = {}
self.enabled = enabled self.lock = threading.RLock()
self.session_lock = threading.RLock()
@Throttle(MIN_SEC_SESSION_CLEARING) @util.Throttle(SESSION_CLEAR_INTERVAL)
def remove_expired(self): def _remove_expired(self):
""" Remove any expired sessions. """ """ Remove any expired sessions. """
if self.session_lock.acquire(False): now = date_util.utcnow()
try: for key in [key for key, valid_time in self._sessions.items()
keys = [] if valid_time < now]:
for key in self._sessions.keys(): self._sessions.pop(key)
keys.append(key)
for key in keys: def is_valid(self, key):
if self._sessions[key].is_expired: """ Return True if a valid session is given. """
del self._sessions[key] with self.lock:
_LOGGER.info("Cleared expired session %s", key) self._remove_expired()
finally:
self.session_lock.release()
def add(self, key, session): return (key in self._sessions and
""" Add a new session to the list of tracked sessions """ self._sessions[key] > date_util.utcnow())
self.remove_expired()
with self.session_lock:
self._sessions[key] = session
def get(self, key): def extend_validation(self, key):
""" get a session by key """ """ Extend a session validation time. """
self.remove_expired() with self.lock:
session = self._sessions.get(key, None) self._sessions[key] = session_valid_time()
if session is not None and session.is_expired:
return None
return session
def create(self, api_password): def destroy(self, key):
""" Creates a new session and adds it to the sessions """ """ Destroy a session by key. """
if self.enabled is not True: with self.lock:
return None self._sessions.pop(key, None)
chars = string.ascii_letters + string.digits def create(self):
session_id = ''.join([random.choice(chars) for i in range(20)]) """ Creates a new session. """
session = ServerSession(session_id) with self.lock:
session.cookie_values[CONF_API_PASSWORD] = api_password session_id = util.get_random_string(20)
self.add(session_id, session)
return session while session_id in self._sessions:
session_id = util.get_random_string(20)
self._sessions[session_id] = session_valid_time()
return session_id

View File

@ -100,7 +100,7 @@ class PushBulletNotificationService(BaseNotificationService):
# This also seems works to send to all devices in own account # This also seems works to send to all devices in own account
if ttype == 'email': if ttype == 'email':
self.pushbullet.push_note(title, message, email=tname) self.pushbullet.push_note(title, message, email=tname)
_LOGGER.info('Sent notification to self') _LOGGER.info('Sent notification to email %s', tname)
continue continue
# Refresh if name not found. While awaiting periodic refresh # Refresh if name not found. While awaiting periodic refresh
@ -108,18 +108,21 @@ class PushBulletNotificationService(BaseNotificationService):
if ttype not in self.pbtargets: if ttype not in self.pbtargets:
_LOGGER.error('Invalid target syntax: %s', target) _LOGGER.error('Invalid target syntax: %s', target)
continue continue
if tname.lower() not in self.pbtargets[ttype] and not refreshed:
tname = tname.lower()
if tname not in self.pbtargets[ttype] and not refreshed:
self.refresh() self.refresh()
refreshed = True refreshed = True
# Attempt push_note on a dict value. Keys are types & target # Attempt push_note on a dict value. Keys are types & target
# name. Dict pbtargets has all *actual* targets. # name. Dict pbtargets has all *actual* targets.
try: try:
self.pbtargets[ttype][tname.lower()].push_note(title, message) self.pbtargets[ttype][tname].push_note(title, message)
_LOGGER.info('Sent notification to %s/%s', ttype, tname)
except KeyError: except KeyError:
_LOGGER.error('No such target: %s/%s', ttype, tname) _LOGGER.error('No such target: %s/%s', ttype, tname)
continue continue
except self.pushbullet.errors.PushError: except self.pushbullet.errors.PushError:
_LOGGER.error('Notify failed to: %s/%s', ttype, tname) _LOGGER.error('Notify failed to: %s/%s', ttype, tname)
continue continue
_LOGGER.info('Sent notification to %s/%s', ttype, tname)

View File

@ -164,6 +164,7 @@ URL_API_EVENT_FORWARD = "/api/event_forwarding"
URL_API_COMPONENTS = "/api/components" URL_API_COMPONENTS = "/api/components"
URL_API_BOOTSTRAP = "/api/bootstrap" URL_API_BOOTSTRAP = "/api/bootstrap"
URL_API_ERROR_LOG = "/api/error_log" URL_API_ERROR_LOG = "/api/error_log"
URL_API_LOG_OUT = "/api/log_out"
HTTP_OK = 200 HTTP_OK = 200
HTTP_CREATED = 201 HTTP_CREATED = 201

View File

@ -4,6 +4,8 @@ homeassistant.helpers.entity_component
Provides helpers for components that manage entities. Provides helpers for components that manage entities.
""" """
from threading import Lock
from homeassistant.bootstrap import prepare_setup_platform from homeassistant.bootstrap import prepare_setup_platform
from homeassistant.helpers import ( from homeassistant.helpers import (
generate_entity_id, config_per_platform, extract_entity_ids) generate_entity_id, config_per_platform, extract_entity_ids)
@ -37,6 +39,7 @@ class EntityComponent(object):
self.is_polling = False self.is_polling = False
self.config = None self.config = None
self.lock = Lock()
def setup(self, config): def setup(self, config):
""" """
@ -61,8 +64,11 @@ class EntityComponent(object):
Takes in a list of new entities. For each entity will see if it already Takes in a list of new entities. For each entity will see if it already
exists. If not, will add it, set it up and push the first state. exists. If not, will add it, set it up and push the first state.
""" """
for entity in new_entities: with self.lock:
if entity is not None and entity not in self.entities.values(): for entity in new_entities:
if entity is None or entity in self.entities.values():
continue
entity.hass = self.hass entity.hass = self.hass
if getattr(entity, 'entity_id', None) is None: if getattr(entity, 'entity_id', None) is None:
@ -74,23 +80,33 @@ class EntityComponent(object):
entity.update_ha_state() entity.update_ha_state()
if self.group is None and self.group_name is not None: if self.group is None and self.group_name is not None:
self.group = group.Group(self.hass, self.group_name, self.group = group.Group(self.hass, self.group_name,
user_defined=False) user_defined=False)
if self.group is not None: if self.group is not None:
self.group.update_tracked_entity_ids(self.entities.keys()) self.group.update_tracked_entity_ids(self.entities.keys())
self._start_polling() if self.is_polling or \
not any(entity.should_poll for entity
in self.entities.values()):
return
self.is_polling = True
track_utc_time_change(
self.hass, self._update_entity_states,
second=range(0, 60, self.scan_interval))
def extract_from_service(self, service): def extract_from_service(self, service):
""" """
Takes a service and extracts all known entities. Takes a service and extracts all known entities.
Will return all if no entity IDs given in service. Will return all if no entity IDs given in service.
""" """
if ATTR_ENTITY_ID not in service.data: with self.lock:
return self.entities.values() if ATTR_ENTITY_ID not in service.data:
else: return list(self.entities.values())
return [self.entities[entity_id] for entity_id return [self.entities[entity_id] for entity_id
in extract_entity_ids(self.hass, service) in extract_entity_ids(self.hass, service)
if entity_id in self.entities] if entity_id in self.entities]
@ -99,9 +115,10 @@ class EntityComponent(object):
""" Update the states of all the entities. """ """ Update the states of all the entities. """
self.logger.info("Updating %s entities", self.domain) self.logger.info("Updating %s entities", self.domain)
for entity in self.entities.values(): with self.lock:
if entity.should_poll: for entity in self.entities.values():
entity.update_ha_state(True) if entity.should_poll:
entity.update_ha_state(True)
def _entity_discovered(self, service, info): def _entity_discovered(self, service, info):
""" Called when a entity is discovered. """ """ Called when a entity is discovered. """
@ -110,18 +127,6 @@ class EntityComponent(object):
self._setup_platform(self.discovery_platforms[service], {}, info) self._setup_platform(self.discovery_platforms[service], {}, info)
def _start_polling(self):
""" Start polling entities if necessary. """
if self.is_polling or \
not any(entity.should_poll for entity in self.entities.values()):
return
self.is_polling = True
track_utc_time_change(
self.hass, self._update_entity_states,
second=range(0, 60, self.scan_interval))
def _setup_platform(self, platform_type, platform_config, def _setup_platform(self, platform_type, platform_config,
discovery_info=None): discovery_info=None):
""" Tries to setup a platform for this component. """ """ Tries to setup a platform for this component. """

View File

@ -8,14 +8,13 @@ Tests Home Assistant HTTP component does what it should do.
import unittest import unittest
import json import json
from unittest.mock import patch from unittest.mock import patch
import tempfile
import requests import requests
from homeassistant import bootstrap, const
import homeassistant.core as ha import homeassistant.core as ha
import homeassistant.bootstrap as bootstrap
import homeassistant.remote as remote
import homeassistant.components.http as http import homeassistant.components.http as http
from homeassistant.const import HTTP_HEADER_HA_AUTH
API_PASSWORD = "test1234" API_PASSWORD = "test1234"
@ -26,7 +25,7 @@ SERVER_PORT = 8120
HTTP_BASE_URL = "http://127.0.0.1:{}".format(SERVER_PORT) HTTP_BASE_URL = "http://127.0.0.1:{}".format(SERVER_PORT)
HA_HEADERS = {HTTP_HEADER_HA_AUTH: API_PASSWORD} HA_HEADERS = {const.HTTP_HEADER_HA_AUTH: API_PASSWORD}
hass = None hass = None
@ -68,20 +67,20 @@ class TestAPI(unittest.TestCase):
# TODO move back to http component and test with use_auth. # TODO move back to http component and test with use_auth.
def test_access_denied_without_password(self): def test_access_denied_without_password(self):
req = requests.get( req = requests.get(
_url(remote.URL_API_STATES_ENTITY.format("test"))) _url(const.URL_API_STATES_ENTITY.format("test")))
self.assertEqual(401, req.status_code) self.assertEqual(401, req.status_code)
def test_access_denied_with_wrong_password(self): def test_access_denied_with_wrong_password(self):
req = requests.get( req = requests.get(
_url(remote.URL_API_STATES_ENTITY.format("test")), _url(const.URL_API_STATES_ENTITY.format("test")),
headers={HTTP_HEADER_HA_AUTH: 'wrongpassword'}) headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'})
self.assertEqual(401, req.status_code) self.assertEqual(401, req.status_code)
def test_api_list_state_entities(self): def test_api_list_state_entities(self):
""" Test if the debug interface allows us to list state entities. """ """ Test if the debug interface allows us to list state entities. """
req = requests.get(_url(remote.URL_API_STATES), req = requests.get(_url(const.URL_API_STATES),
headers=HA_HEADERS) headers=HA_HEADERS)
remote_data = [ha.State.from_dict(item) for item in req.json()] remote_data = [ha.State.from_dict(item) for item in req.json()]
@ -91,7 +90,7 @@ class TestAPI(unittest.TestCase):
def test_api_get_state(self): def test_api_get_state(self):
""" Test if the debug interface allows us to get a state. """ """ Test if the debug interface allows us to get a state. """
req = requests.get( req = requests.get(
_url(remote.URL_API_STATES_ENTITY.format("test.test")), _url(const.URL_API_STATES_ENTITY.format("test.test")),
headers=HA_HEADERS) headers=HA_HEADERS)
data = ha.State.from_dict(req.json()) data = ha.State.from_dict(req.json())
@ -105,7 +104,7 @@ class TestAPI(unittest.TestCase):
def test_api_get_non_existing_state(self): def test_api_get_non_existing_state(self):
""" Test if the debug interface allows us to get a state. """ """ Test if the debug interface allows us to get a state. """
req = requests.get( req = requests.get(
_url(remote.URL_API_STATES_ENTITY.format("does_not_exist")), _url(const.URL_API_STATES_ENTITY.format("does_not_exist")),
headers=HA_HEADERS) headers=HA_HEADERS)
self.assertEqual(404, req.status_code) self.assertEqual(404, req.status_code)
@ -115,7 +114,7 @@ class TestAPI(unittest.TestCase):
hass.states.set("test.test", "not_to_be_set") hass.states.set("test.test", "not_to_be_set")
requests.post(_url(remote.URL_API_STATES_ENTITY.format("test.test")), requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
data=json.dumps({"state": "debug_state_change2"}), data=json.dumps({"state": "debug_state_change2"}),
headers=HA_HEADERS) headers=HA_HEADERS)
@ -130,7 +129,7 @@ class TestAPI(unittest.TestCase):
new_state = "debug_state_change" new_state = "debug_state_change"
req = requests.post( req = requests.post(
_url(remote.URL_API_STATES_ENTITY.format( _url(const.URL_API_STATES_ENTITY.format(
"test_entity.that_does_not_exist")), "test_entity.that_does_not_exist")),
data=json.dumps({'state': new_state}), data=json.dumps({'state': new_state}),
headers=HA_HEADERS) headers=HA_HEADERS)
@ -146,7 +145,7 @@ class TestAPI(unittest.TestCase):
""" Test if API sends appropriate error if we omit state. """ """ Test if API sends appropriate error if we omit state. """
req = requests.post( req = requests.post(
_url(remote.URL_API_STATES_ENTITY.format( _url(const.URL_API_STATES_ENTITY.format(
"test_entity.that_does_not_exist")), "test_entity.that_does_not_exist")),
data=json.dumps({}), data=json.dumps({}),
headers=HA_HEADERS) headers=HA_HEADERS)
@ -165,7 +164,7 @@ class TestAPI(unittest.TestCase):
hass.bus.listen_once("test.event_no_data", listener) hass.bus.listen_once("test.event_no_data", listener)
requests.post( requests.post(
_url(remote.URL_API_EVENTS_EVENT.format("test.event_no_data")), _url(const.URL_API_EVENTS_EVENT.format("test.event_no_data")),
headers=HA_HEADERS) headers=HA_HEADERS)
hass.pool.block_till_done() hass.pool.block_till_done()
@ -186,7 +185,7 @@ class TestAPI(unittest.TestCase):
hass.bus.listen_once("test_event_with_data", listener) hass.bus.listen_once("test_event_with_data", listener)
requests.post( requests.post(
_url(remote.URL_API_EVENTS_EVENT.format("test_event_with_data")), _url(const.URL_API_EVENTS_EVENT.format("test_event_with_data")),
data=json.dumps({"test": 1}), data=json.dumps({"test": 1}),
headers=HA_HEADERS) headers=HA_HEADERS)
@ -206,7 +205,7 @@ class TestAPI(unittest.TestCase):
hass.bus.listen_once("test_event_bad_data", listener) hass.bus.listen_once("test_event_bad_data", listener)
req = requests.post( req = requests.post(
_url(remote.URL_API_EVENTS_EVENT.format("test_event_bad_data")), _url(const.URL_API_EVENTS_EVENT.format("test_event_bad_data")),
data=json.dumps('not an object'), data=json.dumps('not an object'),
headers=HA_HEADERS) headers=HA_HEADERS)
@ -217,7 +216,7 @@ class TestAPI(unittest.TestCase):
# Try now with valid but unusable JSON # Try now with valid but unusable JSON
req = requests.post( req = requests.post(
_url(remote.URL_API_EVENTS_EVENT.format("test_event_bad_data")), _url(const.URL_API_EVENTS_EVENT.format("test_event_bad_data")),
data=json.dumps([1, 2, 3]), data=json.dumps([1, 2, 3]),
headers=HA_HEADERS) headers=HA_HEADERS)
@ -226,9 +225,31 @@ class TestAPI(unittest.TestCase):
self.assertEqual(422, req.status_code) self.assertEqual(422, req.status_code)
self.assertEqual(0, len(test_value)) self.assertEqual(0, len(test_value))
def test_api_get_config(self):
req = requests.get(_url(const.URL_API_CONFIG),
headers=HA_HEADERS)
self.assertEqual(hass.config.as_dict(), req.json())
def test_api_get_components(self):
req = requests.get(_url(const.URL_API_COMPONENTS),
headers=HA_HEADERS)
self.assertEqual(hass.config.components, req.json())
def test_api_get_error_log(self):
test_content = 'Test String'
with tempfile.NamedTemporaryFile() as log:
log.write(test_content.encode('utf-8'))
log.flush()
with patch.object(hass.config, 'path', return_value=log.name):
req = requests.get(_url(const.URL_API_ERROR_LOG),
headers=HA_HEADERS)
self.assertEqual(test_content, req.text)
self.assertIsNone(req.headers.get('expires'))
def test_api_get_event_listeners(self): def test_api_get_event_listeners(self):
""" Test if we can get the list of events being listened for. """ """ Test if we can get the list of events being listened for. """
req = requests.get(_url(remote.URL_API_EVENTS), req = requests.get(_url(const.URL_API_EVENTS),
headers=HA_HEADERS) headers=HA_HEADERS)
local = hass.bus.listeners local = hass.bus.listeners
@ -241,7 +262,7 @@ class TestAPI(unittest.TestCase):
def test_api_get_services(self): def test_api_get_services(self):
""" Test if we can get a dict describing current services. """ """ Test if we can get a dict describing current services. """
req = requests.get(_url(remote.URL_API_SERVICES), req = requests.get(_url(const.URL_API_SERVICES),
headers=HA_HEADERS) headers=HA_HEADERS)
local_services = hass.services.services local_services = hass.services.services
@ -262,7 +283,7 @@ class TestAPI(unittest.TestCase):
hass.services.register("test_domain", "test_service", listener) hass.services.register("test_domain", "test_service", listener)
requests.post( requests.post(
_url(remote.URL_API_SERVICES_SERVICE.format( _url(const.URL_API_SERVICES_SERVICE.format(
"test_domain", "test_service")), "test_domain", "test_service")),
headers=HA_HEADERS) headers=HA_HEADERS)
@ -283,7 +304,7 @@ class TestAPI(unittest.TestCase):
hass.services.register("test_domain", "test_service", listener) hass.services.register("test_domain", "test_service", listener)
requests.post( requests.post(
_url(remote.URL_API_SERVICES_SERVICE.format( _url(const.URL_API_SERVICES_SERVICE.format(
"test_domain", "test_service")), "test_domain", "test_service")),
data=json.dumps({"test": 1}), data=json.dumps({"test": 1}),
headers=HA_HEADERS) headers=HA_HEADERS)
@ -296,24 +317,24 @@ class TestAPI(unittest.TestCase):
""" Test setting up event forwarding. """ """ Test setting up event forwarding. """
req = requests.post( req = requests.post(
_url(remote.URL_API_EVENT_FORWARD), _url(const.URL_API_EVENT_FORWARD),
headers=HA_HEADERS) headers=HA_HEADERS)
self.assertEqual(400, req.status_code) self.assertEqual(400, req.status_code)
req = requests.post( req = requests.post(
_url(remote.URL_API_EVENT_FORWARD), _url(const.URL_API_EVENT_FORWARD),
data=json.dumps({'host': '127.0.0.1'}), data=json.dumps({'host': '127.0.0.1'}),
headers=HA_HEADERS) headers=HA_HEADERS)
self.assertEqual(400, req.status_code) self.assertEqual(400, req.status_code)
req = requests.post( req = requests.post(
_url(remote.URL_API_EVENT_FORWARD), _url(const.URL_API_EVENT_FORWARD),
data=json.dumps({'api_password': 'bla-di-bla'}), data=json.dumps({'api_password': 'bla-di-bla'}),
headers=HA_HEADERS) headers=HA_HEADERS)
self.assertEqual(400, req.status_code) self.assertEqual(400, req.status_code)
req = requests.post( req = requests.post(
_url(remote.URL_API_EVENT_FORWARD), _url(const.URL_API_EVENT_FORWARD),
data=json.dumps({ data=json.dumps({
'api_password': 'bla-di-bla', 'api_password': 'bla-di-bla',
'host': '127.0.0.1', 'host': '127.0.0.1',
@ -323,7 +344,7 @@ class TestAPI(unittest.TestCase):
self.assertEqual(422, req.status_code) self.assertEqual(422, req.status_code)
req = requests.post( req = requests.post(
_url(remote.URL_API_EVENT_FORWARD), _url(const.URL_API_EVENT_FORWARD),
data=json.dumps({ data=json.dumps({
'api_password': 'bla-di-bla', 'api_password': 'bla-di-bla',
'host': '127.0.0.1', 'host': '127.0.0.1',
@ -334,7 +355,7 @@ class TestAPI(unittest.TestCase):
# Setup a real one # Setup a real one
req = requests.post( req = requests.post(
_url(remote.URL_API_EVENT_FORWARD), _url(const.URL_API_EVENT_FORWARD),
data=json.dumps({ data=json.dumps({
'api_password': API_PASSWORD, 'api_password': API_PASSWORD,
'host': '127.0.0.1', 'host': '127.0.0.1',
@ -345,13 +366,13 @@ class TestAPI(unittest.TestCase):
# Delete it again.. # Delete it again..
req = requests.delete( req = requests.delete(
_url(remote.URL_API_EVENT_FORWARD), _url(const.URL_API_EVENT_FORWARD),
data=json.dumps({}), data=json.dumps({}),
headers=HA_HEADERS) headers=HA_HEADERS)
self.assertEqual(400, req.status_code) self.assertEqual(400, req.status_code)
req = requests.delete( req = requests.delete(
_url(remote.URL_API_EVENT_FORWARD), _url(const.URL_API_EVENT_FORWARD),
data=json.dumps({ data=json.dumps({
'host': '127.0.0.1', 'host': '127.0.0.1',
'port': 'abcd' 'port': 'abcd'
@ -360,7 +381,7 @@ class TestAPI(unittest.TestCase):
self.assertEqual(422, req.status_code) self.assertEqual(422, req.status_code)
req = requests.delete( req = requests.delete(
_url(remote.URL_API_EVENT_FORWARD), _url(const.URL_API_EVENT_FORWARD),
data=json.dumps({ data=json.dumps({
'host': '127.0.0.1', 'host': '127.0.0.1',
'port': SERVER_PORT 'port': SERVER_PORT