diff --git a/config/configuration.yaml.example b/config/configuration.yaml.example index 524078b7765..c99f760f21f 100644 --- a/config/configuration.yaml.example +++ b/config/configuration.yaml.example @@ -68,7 +68,7 @@ device_sun_light_trigger: # A comma separated list of states that have to be tracked as a single group # Grouped states should share the same type of states (ON/OFF or HOME/NOT_HOME) -group: +group: living_room: - light.Bowl - light.Ceiling diff --git a/homeassistant/components/http.py b/homeassistant/components/http.py index e9ffea0d592..c1c0899c9ef 100644 --- a/homeassistant/components/http.py +++ b/homeassistant/components/http.py @@ -77,7 +77,12 @@ import logging import time import gzip import os +import random +import string +from datetime import timedelta +from homeassistant.util import Throttle from http.server import SimpleHTTPRequestHandler, HTTPServer +from http import cookies from socketserver import ThreadingMixIn from urllib.parse import urlparse, parse_qs @@ -90,6 +95,7 @@ from homeassistant.const import ( HTTP_NOT_FOUND, HTTP_METHOD_NOT_ALLOWED, HTTP_UNPROCESSABLE_ENTITY) import homeassistant.remote as rem import homeassistant.util as util +import homeassistant.util.dt as date_util import homeassistant.bootstrap as bootstrap DOMAIN = "http" @@ -99,9 +105,15 @@ CONF_API_PASSWORD = "api_password" CONF_SERVER_HOST = "server_host" CONF_SERVER_PORT = "server_port" CONF_DEVELOPMENT = "development" +CONF_SESSIONS_ENABLED = "sessions_enabled" DATA_API_PASSWORD = 'api_password' +# Throttling time in seconds for expired sessions check +MIN_SEC_SESSION_CLEARING = timedelta(seconds=20) +SESSION_TIMEOUT_SECONDS = 1800 +SESSION_KEY = 'sessionId' + _LOGGER = logging.getLogger(__name__) @@ -125,9 +137,11 @@ def setup(hass, config=None): development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1" + sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, True) + server = HomeAssistantHTTPServer( (server_host, server_port), RequestHandler, hass, api_password, - development, no_password_set) + development, no_password_set, sessions_enabled) hass.bus.listen_once( ha.EVENT_HOMEASSISTANT_START, @@ -140,6 +154,7 @@ def setup(hass, config=None): return True +# pylint: disable=too-many-instance-attributes class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer): """ Handle HTTP requests in a threaded fashion. """ # pylint: disable=too-few-public-methods @@ -149,7 +164,8 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer): # pylint: disable=too-many-arguments def __init__(self, server_address, request_handler_class, - hass, api_password, development, no_password_set): + hass, api_password, development, no_password_set, + sessions_enabled): super().__init__(server_address, request_handler_class) self.server_address = server_address @@ -158,6 +174,7 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer): self.development = development self.no_password_set = no_password_set self.paths = [] + self.sessions = SessionStore(sessions_enabled) # We will lazy init this one if needed self.event_forwarder = None @@ -197,6 +214,11 @@ class RequestHandler(SimpleHTTPRequestHandler): server_version = "HomeAssistant/1.0" + def __init__(self, req, client_addr, server): + """ Contructor, call the base constructor and set up session """ + self._session = None + SimpleHTTPRequestHandler.__init__(self, req, client_addr, server) + def _handle_request(self, method): # pylint: disable=too-many-branches """ Does some common checks and calls appropriate method. """ url = urlparse(self.path) @@ -225,6 +247,7 @@ class RequestHandler(SimpleHTTPRequestHandler): "Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY) return + self._session = self.get_session() if self.server.no_password_set: api_password = self.server.api_password else: @@ -233,6 +256,10 @@ class RequestHandler(SimpleHTTPRequestHandler): if not api_password and DATA_API_PASSWORD in data: api_password = data[DATA_API_PASSWORD] + if not api_password and self._session is not None: + api_password = self._session.cookie_values.get( + CONF_API_PASSWORD) + if '_METHOD' in data: method = data.pop('_METHOD') @@ -271,6 +298,10 @@ class RequestHandler(SimpleHTTPRequestHandler): "API password missing or incorrect.", HTTP_UNAUTHORIZED) else: + if self._session is None and require_auth: + self._session = self.server.sessions.create( + api_password) + handle_request_method(self, path_match, data) elif path_matched_but_not_method: @@ -313,6 +344,8 @@ class RequestHandler(SimpleHTTPRequestHandler): if location: self.send_header('Location', location) + self.set_session_cookie_header() + self.end_headers() if data is not None: @@ -342,6 +375,7 @@ class RequestHandler(SimpleHTTPRequestHandler): self.send_header(HTTP_HEADER_CONTENT_TYPE, content_type) self.set_cache_header() + self.set_session_cookie_header() if do_gzip: gzip_data = gzip.compress(inp.read()) @@ -377,3 +411,113 @@ class RequestHandler(SimpleHTTPRequestHandler): self.send_header( HTTP_HEADER_EXPIRES, self.date_time_string(time.time()+cache_time)) + + def set_session_cookie_header(self): + """ Add the header for the session cookie """ + if self.server.sessions.enabled and self._session is not None: + existing_sess_id = self.get_current_session_id() + + if existing_sess_id != self._session.session_id: + self.send_header( + 'Set-Cookie', + SESSION_KEY+'='+self._session.session_id) + + def get_session(self): + """ Get the requested session object from cookie value """ + if self.server.sessions.enabled is not True: + return None + + session_id = self.get_current_session_id() + if session_id is not None: + session = self.server.sessions.get(session_id) + if session is not None: + session.reset_expiry() + return session + + return None + + def get_current_session_id(self): + """ + Extracts the current session id from the + cookie or returns None if not set + """ + cookie = cookies.SimpleCookie() + + if self.headers.get('Cookie', None) is not None: + cookie.load(self.headers.get("Cookie")) + + if cookie.get(SESSION_KEY, False): + return cookie[SESSION_KEY].value + + return None + + +class ServerSession: + """ A very simple session class """ + def __init__(self, session_id): + """ Set up the expiry time on creation """ + self._expiry = 0 + self.reset_expiry() + self.cookie_values = {} + self.session_id = session_id + + def reset_expiry(self): + """ Resets the expiry based on current time """ + self._expiry = date_util.utcnow() + timedelta( + seconds=SESSION_TIMEOUT_SECONDS) + + @property + def is_expired(self): + """ Return true if the session is expired based on the expiry time """ + return self._expiry < date_util.utcnow() + + +class SessionStore: + """ Responsible for storing and retrieving http sessions """ + def __init__(self, enabled=True): + """ Set up the session store """ + self._sessions = {} + self.enabled = enabled + self.session_lock = threading.RLock() + + @Throttle(MIN_SEC_SESSION_CLEARING) + def remove_expired(self): + """ Remove any expired sessions. """ + if self.session_lock.acquire(False): + try: + keys = [] + for key in self._sessions.keys(): + keys.append(key) + + for key in keys: + if self._sessions[key].is_expired: + del self._sessions[key] + _LOGGER.info("Cleared expired session %s", key) + finally: + self.session_lock.release() + + def add(self, key, session): + """ Add a new session to the list of tracked sessions """ + self.remove_expired() + with self.session_lock: + self._sessions[key] = session + + def get(self, key): + """ get a session by key """ + self.remove_expired() + session = self._sessions.get(key, None) + if session is not None and session.is_expired: + return None + return session + + def create(self, api_password): + """ Creates a new session and adds it to the sessions """ + if self.enabled is not True: + return None + + chars = string.ascii_letters + string.digits + session_id = ''.join([random.choice(chars) for i in range(20)]) + session = ServerSession(session_id) + session.cookie_values[CONF_API_PASSWORD] = api_password + self.add(session_id, session) + return session