Addd basic http sessions to the http component

This commit is contained in:
jamespcole 2015-05-18 23:54:32 +10:00
parent 8f51741c65
commit 721dc6dae4

View File

@ -77,7 +77,12 @@ import logging
import time
import gzip
import os
import random
import string
from datetime import timedelta
from homeassistant.util import Throttle
from http.server import SimpleHTTPRequestHandler, HTTPServer
from http import cookies
from socketserver import ThreadingMixIn
from urllib.parse import urlparse, parse_qs
@ -90,6 +95,7 @@ from homeassistant.const import (
HTTP_NOT_FOUND, HTTP_METHOD_NOT_ALLOWED, HTTP_UNPROCESSABLE_ENTITY)
import homeassistant.remote as rem
import homeassistant.util as util
import homeassistant.util.dt as date_util
import homeassistant.bootstrap as bootstrap
DOMAIN = "http"
@ -99,9 +105,16 @@ CONF_API_PASSWORD = "api_password"
CONF_SERVER_HOST = "server_host"
CONF_SERVER_PORT = "server_port"
CONF_DEVELOPMENT = "development"
CONF_SESSIONS_ENABLED = "sessions_enabled"
DATA_API_PASSWORD = 'api_password'
# Throttling time in seconds for expired sessions check
MIN_SEC_SESSION_CLEARING = timedelta(seconds=20)
SESSION_TIMEOUT_SECONDS = 1800
SESSION_LOCK = threading.RLock()
SESSION_KEY = 'sessionId'
_LOGGER = logging.getLogger(__name__)
@ -125,9 +138,11 @@ def setup(hass, config=None):
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, False)
server = HomeAssistantHTTPServer(
(server_host, server_port), RequestHandler, hass, api_password,
development, no_password_set)
development, no_password_set, sessions_enabled)
hass.bus.listen_once(
ha.EVENT_HOMEASSISTANT_START,
@ -139,7 +154,7 @@ def setup(hass, config=None):
return True
# pylint: disable=too-many-instance-attributes
class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
""" Handle HTTP requests in a threaded fashion. """
# pylint: disable=too-few-public-methods
@ -149,7 +164,8 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
# pylint: disable=too-many-arguments
def __init__(self, server_address, request_handler_class,
hass, api_password, development, no_password_set):
hass, api_password, development, no_password_set,
sessions_enabled):
super().__init__(server_address, request_handler_class)
self.server_address = server_address
@ -158,6 +174,8 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
self.development = development
self.no_password_set = no_password_set
self.paths = []
self.sessions_enabled = sessions_enabled
self._sessions = {}
# We will lazy init this one if needed
self.event_forwarder = None
@ -185,6 +203,51 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
""" Regitsters a path wit the server. """
self.paths.append((method, url, callback, require_auth))
@Throttle(MIN_SEC_SESSION_CLEARING)
def remove_expired_sessions(self):
""" Reemove any expired sessions. """
if SESSION_LOCK.acquire(False):
try:
keys = []
for key in self._sessions.keys():
keys.append(key)
for key in keys:
if self._sessions[key].is_expired:
del self._sessions[key]
_LOGGER.info("Cleared expired session %s", key)
finally:
SESSION_LOCK.release()
def add_session(self, key, session):
""" Add a new session to the list of tracked sessions """
self.remove_expired_sessions()
try:
SESSION_LOCK.acquire()
self._sessions[key] = session
finally:
SESSION_LOCK.release()
def get_session(self, key):
""" get a session by key """
self.remove_expired_sessions()
session = self._sessions.get(key, None)
if session is not None and session.is_expired:
return None
return session
def create_session(self, api_password):
""" Creates a new session and adds it to the sessions """
if self.sessions_enabled is not True:
return None
chars = string.ascii_letters + string.digits
session_id = ''.join([random.choice(chars) for i in range(20)])
session = ServerSession(session_id)
session.cookie_values['api_password'] = api_password
self.add_session(session_id, session)
return session
# pylint: disable=too-many-public-methods,too-many-locals
class RequestHandler(SimpleHTTPRequestHandler):
@ -197,6 +260,11 @@ class RequestHandler(SimpleHTTPRequestHandler):
server_version = "HomeAssistant/1.0"
def __init__(self, req, client_addr, server):
""" Contructor, call the base constructor and set up session """
SimpleHTTPRequestHandler.__init__(self, req, client_addr, server)
self._session = None
def _handle_request(self, method): # pylint: disable=too-many-branches
""" Does some common checks and calls appropriate method. """
url = urlparse(self.path)
@ -225,6 +293,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
"Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY)
return
self._session = self.get_session()
if self.server.no_password_set:
api_password = self.server.api_password
else:
@ -233,6 +302,9 @@ class RequestHandler(SimpleHTTPRequestHandler):
if not api_password and DATA_API_PASSWORD in data:
api_password = data[DATA_API_PASSWORD]
if not api_password and self._session is not None:
api_password = self._session.cookie_values.get('api_password')
if '_METHOD' in data:
method = data.pop('_METHOD')
@ -271,6 +343,9 @@ class RequestHandler(SimpleHTTPRequestHandler):
"API password missing or incorrect.", HTTP_UNAUTHORIZED)
else:
if self._session is None and require_auth:
self._session = self.server.create_session(api_password)
handle_request_method(self, path_match, data)
elif path_matched_but_not_method:
@ -313,6 +388,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
if location:
self.send_header('Location', location)
self.set_session_cookie_header()
self.end_headers()
if data is not None:
@ -342,6 +419,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
self.send_header(HTTP_HEADER_CONTENT_TYPE, content_type)
self.set_cache_header()
self.set_session_cookie_header()
if do_gzip:
gzip_data = gzip.compress(inp.read())
@ -377,3 +455,57 @@ class RequestHandler(SimpleHTTPRequestHandler):
self.send_header(
HTTP_HEADER_EXPIRES,
self.date_time_string(time.time()+cache_time))
def set_session_cookie_header(self):
""" Add the header for the session cookie """
if self.server.sessions_enabled and self._session is not None:
cookie = cookies.SimpleCookie()
existing_sess_id = None
if self.headers.get('Cookie', None) is not None:
cookie.load(self.headers.get('Cookie'))
if cookie.get(SESSION_KEY, False):
existing_sess_id = cookie[SESSION_KEY].value
if existing_sess_id != self._session.session_id:
self.send_header(
'Set-Cookie',
SESSION_KEY+'='+self._session.session_id)
def get_session(self):
""" Get the requested session object from cookie value """
if self.server.sessions_enabled is not True:
return None
cookie = cookies.SimpleCookie()
if self.headers.get('Cookie', None) is not None:
cookie.load(self.headers.get("Cookie"))
if cookie.get(SESSION_KEY, False):
session = self.server.get_session(cookie[SESSION_KEY].value)
if session is not None:
session.reset_expiry()
return session
else:
return None
class ServerSession:
""" A very simple session class """
def __init__(self, session_id):
""" Set up the expiry time on creation """
self._expiry = 0
self.reset_expiry()
self.cookie_values = {}
self.session_id = session_id
def reset_expiry(self):
""" Resets the expiry based on current time """
self._expiry = date_util.utcnow() + timedelta(
seconds=SESSION_TIMEOUT_SECONDS)
@property
def is_expired(self):
""" Return true if the session is expired based on the expiry time """
return self._expiry < date_util.utcnow()