diff --git a/homeassistant/components/http.py b/homeassistant/components/http.py index e9ffea0d592..47609c748be 100644 --- a/homeassistant/components/http.py +++ b/homeassistant/components/http.py @@ -77,7 +77,12 @@ import logging import time import gzip import os +import random +import string +from datetime import timedelta +from homeassistant.util import Throttle from http.server import SimpleHTTPRequestHandler, HTTPServer +from http import cookies from socketserver import ThreadingMixIn from urllib.parse import urlparse, parse_qs @@ -90,6 +95,7 @@ from homeassistant.const import ( HTTP_NOT_FOUND, HTTP_METHOD_NOT_ALLOWED, HTTP_UNPROCESSABLE_ENTITY) import homeassistant.remote as rem import homeassistant.util as util +import homeassistant.util.dt as date_util import homeassistant.bootstrap as bootstrap DOMAIN = "http" @@ -99,9 +105,16 @@ CONF_API_PASSWORD = "api_password" CONF_SERVER_HOST = "server_host" CONF_SERVER_PORT = "server_port" CONF_DEVELOPMENT = "development" +CONF_SESSIONS_ENABLED = "sessions_enabled" DATA_API_PASSWORD = 'api_password' +# Throttling time in seconds for expired sessions check +MIN_SEC_SESSION_CLEARING = timedelta(seconds=20) +SESSION_TIMEOUT_SECONDS = 1800 +SESSION_LOCK = threading.RLock() +SESSION_KEY = 'sessionId' + _LOGGER = logging.getLogger(__name__) @@ -125,9 +138,11 @@ def setup(hass, config=None): development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1" + sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, False) + server = HomeAssistantHTTPServer( (server_host, server_port), RequestHandler, hass, api_password, - development, no_password_set) + development, no_password_set, sessions_enabled) hass.bus.listen_once( ha.EVENT_HOMEASSISTANT_START, @@ -139,7 +154,7 @@ def setup(hass, config=None): return True - +# pylint: disable=too-many-instance-attributes class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer): """ Handle HTTP requests in a threaded fashion. """ # pylint: disable=too-few-public-methods @@ -149,7 +164,8 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer): # pylint: disable=too-many-arguments def __init__(self, server_address, request_handler_class, - hass, api_password, development, no_password_set): + hass, api_password, development, no_password_set, + sessions_enabled): super().__init__(server_address, request_handler_class) self.server_address = server_address @@ -158,6 +174,8 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer): self.development = development self.no_password_set = no_password_set self.paths = [] + self.sessions_enabled = sessions_enabled + self._sessions = {} # We will lazy init this one if needed self.event_forwarder = None @@ -185,6 +203,51 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer): """ Regitsters a path wit the server. """ self.paths.append((method, url, callback, require_auth)) + @Throttle(MIN_SEC_SESSION_CLEARING) + def remove_expired_sessions(self): + """ Reemove any expired sessions. """ + if SESSION_LOCK.acquire(False): + try: + keys = [] + for key in self._sessions.keys(): + keys.append(key) + + for key in keys: + if self._sessions[key].is_expired: + del self._sessions[key] + _LOGGER.info("Cleared expired session %s", key) + finally: + SESSION_LOCK.release() + + def add_session(self, key, session): + """ Add a new session to the list of tracked sessions """ + self.remove_expired_sessions() + try: + SESSION_LOCK.acquire() + self._sessions[key] = session + finally: + SESSION_LOCK.release() + + def get_session(self, key): + """ get a session by key """ + self.remove_expired_sessions() + session = self._sessions.get(key, None) + if session is not None and session.is_expired: + return None + return session + + def create_session(self, api_password): + """ Creates a new session and adds it to the sessions """ + if self.sessions_enabled is not True: + return None + + chars = string.ascii_letters + string.digits + session_id = ''.join([random.choice(chars) for i in range(20)]) + session = ServerSession(session_id) + session.cookie_values['api_password'] = api_password + self.add_session(session_id, session) + return session + # pylint: disable=too-many-public-methods,too-many-locals class RequestHandler(SimpleHTTPRequestHandler): @@ -197,6 +260,11 @@ class RequestHandler(SimpleHTTPRequestHandler): server_version = "HomeAssistant/1.0" + def __init__(self, req, client_addr, server): + """ Contructor, call the base constructor and set up session """ + SimpleHTTPRequestHandler.__init__(self, req, client_addr, server) + self._session = None + def _handle_request(self, method): # pylint: disable=too-many-branches """ Does some common checks and calls appropriate method. """ url = urlparse(self.path) @@ -225,6 +293,7 @@ class RequestHandler(SimpleHTTPRequestHandler): "Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY) return + self._session = self.get_session() if self.server.no_password_set: api_password = self.server.api_password else: @@ -233,6 +302,9 @@ class RequestHandler(SimpleHTTPRequestHandler): if not api_password and DATA_API_PASSWORD in data: api_password = data[DATA_API_PASSWORD] + if not api_password and self._session is not None: + api_password = self._session.cookie_values.get('api_password') + if '_METHOD' in data: method = data.pop('_METHOD') @@ -271,6 +343,9 @@ class RequestHandler(SimpleHTTPRequestHandler): "API password missing or incorrect.", HTTP_UNAUTHORIZED) else: + if self._session is None and require_auth: + self._session = self.server.create_session(api_password) + handle_request_method(self, path_match, data) elif path_matched_but_not_method: @@ -313,6 +388,8 @@ class RequestHandler(SimpleHTTPRequestHandler): if location: self.send_header('Location', location) + self.set_session_cookie_header() + self.end_headers() if data is not None: @@ -342,6 +419,7 @@ class RequestHandler(SimpleHTTPRequestHandler): self.send_header(HTTP_HEADER_CONTENT_TYPE, content_type) self.set_cache_header() + self.set_session_cookie_header() if do_gzip: gzip_data = gzip.compress(inp.read()) @@ -377,3 +455,57 @@ class RequestHandler(SimpleHTTPRequestHandler): self.send_header( HTTP_HEADER_EXPIRES, self.date_time_string(time.time()+cache_time)) + + def set_session_cookie_header(self): + """ Add the header for the session cookie """ + if self.server.sessions_enabled and self._session is not None: + cookie = cookies.SimpleCookie() + existing_sess_id = None + + if self.headers.get('Cookie', None) is not None: + cookie.load(self.headers.get('Cookie')) + if cookie.get(SESSION_KEY, False): + existing_sess_id = cookie[SESSION_KEY].value + + if existing_sess_id != self._session.session_id: + self.send_header( + 'Set-Cookie', + SESSION_KEY+'='+self._session.session_id) + + def get_session(self): + """ Get the requested session object from cookie value """ + if self.server.sessions_enabled is not True: + return None + + cookie = cookies.SimpleCookie() + + if self.headers.get('Cookie', None) is not None: + cookie.load(self.headers.get("Cookie")) + + if cookie.get(SESSION_KEY, False): + session = self.server.get_session(cookie[SESSION_KEY].value) + if session is not None: + session.reset_expiry() + return session + else: + return None + + +class ServerSession: + """ A very simple session class """ + def __init__(self, session_id): + """ Set up the expiry time on creation """ + self._expiry = 0 + self.reset_expiry() + self.cookie_values = {} + self.session_id = session_id + + def reset_expiry(self): + """ Resets the expiry based on current time """ + self._expiry = date_util.utcnow() + timedelta( + seconds=SESSION_TIMEOUT_SECONDS) + + @property + def is_expired(self): + """ Return true if the session is expired based on the expiry time """ + return self._expiry < date_util.utcnow()