Merge pull request #134 from jamespcole/http-sessions

Added basic http sessions for authentication
This commit is contained in:
Paulus Schoutsen 2015-05-20 07:52:54 -07:00
commit 0466ae5e07
2 changed files with 147 additions and 3 deletions

View File

@ -68,7 +68,7 @@ device_sun_light_trigger:
# A comma separated list of states that have to be tracked as a single group
# Grouped states should share the same type of states (ON/OFF or HOME/NOT_HOME)
group:
group:
living_room:
- light.Bowl
- light.Ceiling

View File

@ -77,7 +77,12 @@ import logging
import time
import gzip
import os
import random
import string
from datetime import timedelta
from homeassistant.util import Throttle
from http.server import SimpleHTTPRequestHandler, HTTPServer
from http import cookies
from socketserver import ThreadingMixIn
from urllib.parse import urlparse, parse_qs
@ -90,6 +95,7 @@ from homeassistant.const import (
HTTP_NOT_FOUND, HTTP_METHOD_NOT_ALLOWED, HTTP_UNPROCESSABLE_ENTITY)
import homeassistant.remote as rem
import homeassistant.util as util
import homeassistant.util.dt as date_util
import homeassistant.bootstrap as bootstrap
DOMAIN = "http"
@ -99,9 +105,15 @@ CONF_API_PASSWORD = "api_password"
CONF_SERVER_HOST = "server_host"
CONF_SERVER_PORT = "server_port"
CONF_DEVELOPMENT = "development"
CONF_SESSIONS_ENABLED = "sessions_enabled"
DATA_API_PASSWORD = 'api_password'
# Throttling time in seconds for expired sessions check
MIN_SEC_SESSION_CLEARING = timedelta(seconds=20)
SESSION_TIMEOUT_SECONDS = 1800
SESSION_KEY = 'sessionId'
_LOGGER = logging.getLogger(__name__)
@ -125,9 +137,11 @@ def setup(hass, config=None):
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, True)
server = HomeAssistantHTTPServer(
(server_host, server_port), RequestHandler, hass, api_password,
development, no_password_set)
development, no_password_set, sessions_enabled)
hass.bus.listen_once(
ha.EVENT_HOMEASSISTANT_START,
@ -140,6 +154,7 @@ def setup(hass, config=None):
return True
# pylint: disable=too-many-instance-attributes
class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
""" Handle HTTP requests in a threaded fashion. """
# pylint: disable=too-few-public-methods
@ -149,7 +164,8 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
# pylint: disable=too-many-arguments
def __init__(self, server_address, request_handler_class,
hass, api_password, development, no_password_set):
hass, api_password, development, no_password_set,
sessions_enabled):
super().__init__(server_address, request_handler_class)
self.server_address = server_address
@ -158,6 +174,7 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
self.development = development
self.no_password_set = no_password_set
self.paths = []
self.sessions = SessionStore(sessions_enabled)
# We will lazy init this one if needed
self.event_forwarder = None
@ -197,6 +214,11 @@ class RequestHandler(SimpleHTTPRequestHandler):
server_version = "HomeAssistant/1.0"
def __init__(self, req, client_addr, server):
""" Contructor, call the base constructor and set up session """
self._session = None
SimpleHTTPRequestHandler.__init__(self, req, client_addr, server)
def _handle_request(self, method): # pylint: disable=too-many-branches
""" Does some common checks and calls appropriate method. """
url = urlparse(self.path)
@ -225,6 +247,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
"Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY)
return
self._session = self.get_session()
if self.server.no_password_set:
api_password = self.server.api_password
else:
@ -233,6 +256,10 @@ class RequestHandler(SimpleHTTPRequestHandler):
if not api_password and DATA_API_PASSWORD in data:
api_password = data[DATA_API_PASSWORD]
if not api_password and self._session is not None:
api_password = self._session.cookie_values.get(
CONF_API_PASSWORD)
if '_METHOD' in data:
method = data.pop('_METHOD')
@ -271,6 +298,10 @@ class RequestHandler(SimpleHTTPRequestHandler):
"API password missing or incorrect.", HTTP_UNAUTHORIZED)
else:
if self._session is None and require_auth:
self._session = self.server.sessions.create(
api_password)
handle_request_method(self, path_match, data)
elif path_matched_but_not_method:
@ -313,6 +344,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
if location:
self.send_header('Location', location)
self.set_session_cookie_header()
self.end_headers()
if data is not None:
@ -342,6 +375,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
self.send_header(HTTP_HEADER_CONTENT_TYPE, content_type)
self.set_cache_header()
self.set_session_cookie_header()
if do_gzip:
gzip_data = gzip.compress(inp.read())
@ -377,3 +411,113 @@ class RequestHandler(SimpleHTTPRequestHandler):
self.send_header(
HTTP_HEADER_EXPIRES,
self.date_time_string(time.time()+cache_time))
def set_session_cookie_header(self):
""" Add the header for the session cookie """
if self.server.sessions.enabled and self._session is not None:
existing_sess_id = self.get_current_session_id()
if existing_sess_id != self._session.session_id:
self.send_header(
'Set-Cookie',
SESSION_KEY+'='+self._session.session_id)
def get_session(self):
""" Get the requested session object from cookie value """
if self.server.sessions.enabled is not True:
return None
session_id = self.get_current_session_id()
if session_id is not None:
session = self.server.sessions.get(session_id)
if session is not None:
session.reset_expiry()
return session
return None
def get_current_session_id(self):
"""
Extracts the current session id from the
cookie or returns None if not set
"""
cookie = cookies.SimpleCookie()
if self.headers.get('Cookie', None) is not None:
cookie.load(self.headers.get("Cookie"))
if cookie.get(SESSION_KEY, False):
return cookie[SESSION_KEY].value
return None
class ServerSession:
""" A very simple session class """
def __init__(self, session_id):
""" Set up the expiry time on creation """
self._expiry = 0
self.reset_expiry()
self.cookie_values = {}
self.session_id = session_id
def reset_expiry(self):
""" Resets the expiry based on current time """
self._expiry = date_util.utcnow() + timedelta(
seconds=SESSION_TIMEOUT_SECONDS)
@property
def is_expired(self):
""" Return true if the session is expired based on the expiry time """
return self._expiry < date_util.utcnow()
class SessionStore:
""" Responsible for storing and retrieving http sessions """
def __init__(self, enabled=True):
""" Set up the session store """
self._sessions = {}
self.enabled = enabled
self.session_lock = threading.RLock()
@Throttle(MIN_SEC_SESSION_CLEARING)
def remove_expired(self):
""" Remove any expired sessions. """
if self.session_lock.acquire(False):
try:
keys = []
for key in self._sessions.keys():
keys.append(key)
for key in keys:
if self._sessions[key].is_expired:
del self._sessions[key]
_LOGGER.info("Cleared expired session %s", key)
finally:
self.session_lock.release()
def add(self, key, session):
""" Add a new session to the list of tracked sessions """
self.remove_expired()
with self.session_lock:
self._sessions[key] = session
def get(self, key):
""" get a session by key """
self.remove_expired()
session = self._sessions.get(key, None)
if session is not None and session.is_expired:
return None
return session
def create(self, api_password):
""" Creates a new session and adds it to the sessions """
if self.enabled is not True:
return None
chars = string.ascii_letters + string.digits
session_id = ''.join([random.choice(chars) for i in range(20)])
session = ServerSession(session_id)
session.cookie_values[CONF_API_PASSWORD] = api_password
self.add(session_id, session)
return session