mirror of
https://github.com/home-assistant/core.git
synced 2025-07-18 18:57:06 +00:00
Addd basic http sessions to the http component
This commit is contained in:
parent
8f51741c65
commit
721dc6dae4
@ -77,7 +77,12 @@ import logging
|
||||
import time
|
||||
import gzip
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
from datetime import timedelta
|
||||
from homeassistant.util import Throttle
|
||||
from http.server import SimpleHTTPRequestHandler, HTTPServer
|
||||
from http import cookies
|
||||
from socketserver import ThreadingMixIn
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
|
||||
@ -90,6 +95,7 @@ from homeassistant.const import (
|
||||
HTTP_NOT_FOUND, HTTP_METHOD_NOT_ALLOWED, HTTP_UNPROCESSABLE_ENTITY)
|
||||
import homeassistant.remote as rem
|
||||
import homeassistant.util as util
|
||||
import homeassistant.util.dt as date_util
|
||||
import homeassistant.bootstrap as bootstrap
|
||||
|
||||
DOMAIN = "http"
|
||||
@ -99,9 +105,16 @@ CONF_API_PASSWORD = "api_password"
|
||||
CONF_SERVER_HOST = "server_host"
|
||||
CONF_SERVER_PORT = "server_port"
|
||||
CONF_DEVELOPMENT = "development"
|
||||
CONF_SESSIONS_ENABLED = "sessions_enabled"
|
||||
|
||||
DATA_API_PASSWORD = 'api_password'
|
||||
|
||||
# Throttling time in seconds for expired sessions check
|
||||
MIN_SEC_SESSION_CLEARING = timedelta(seconds=20)
|
||||
SESSION_TIMEOUT_SECONDS = 1800
|
||||
SESSION_LOCK = threading.RLock()
|
||||
SESSION_KEY = 'sessionId'
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -125,9 +138,11 @@ def setup(hass, config=None):
|
||||
|
||||
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
|
||||
|
||||
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, False)
|
||||
|
||||
server = HomeAssistantHTTPServer(
|
||||
(server_host, server_port), RequestHandler, hass, api_password,
|
||||
development, no_password_set)
|
||||
development, no_password_set, sessions_enabled)
|
||||
|
||||
hass.bus.listen_once(
|
||||
ha.EVENT_HOMEASSISTANT_START,
|
||||
@ -139,7 +154,7 @@ def setup(hass, config=None):
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
||||
""" Handle HTTP requests in a threaded fashion. """
|
||||
# pylint: disable=too-few-public-methods
|
||||
@ -149,7 +164,8 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def __init__(self, server_address, request_handler_class,
|
||||
hass, api_password, development, no_password_set):
|
||||
hass, api_password, development, no_password_set,
|
||||
sessions_enabled):
|
||||
super().__init__(server_address, request_handler_class)
|
||||
|
||||
self.server_address = server_address
|
||||
@ -158,6 +174,8 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
||||
self.development = development
|
||||
self.no_password_set = no_password_set
|
||||
self.paths = []
|
||||
self.sessions_enabled = sessions_enabled
|
||||
self._sessions = {}
|
||||
|
||||
# We will lazy init this one if needed
|
||||
self.event_forwarder = None
|
||||
@ -185,6 +203,51 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
||||
""" Regitsters a path wit the server. """
|
||||
self.paths.append((method, url, callback, require_auth))
|
||||
|
||||
@Throttle(MIN_SEC_SESSION_CLEARING)
|
||||
def remove_expired_sessions(self):
|
||||
""" Reemove any expired sessions. """
|
||||
if SESSION_LOCK.acquire(False):
|
||||
try:
|
||||
keys = []
|
||||
for key in self._sessions.keys():
|
||||
keys.append(key)
|
||||
|
||||
for key in keys:
|
||||
if self._sessions[key].is_expired:
|
||||
del self._sessions[key]
|
||||
_LOGGER.info("Cleared expired session %s", key)
|
||||
finally:
|
||||
SESSION_LOCK.release()
|
||||
|
||||
def add_session(self, key, session):
|
||||
""" Add a new session to the list of tracked sessions """
|
||||
self.remove_expired_sessions()
|
||||
try:
|
||||
SESSION_LOCK.acquire()
|
||||
self._sessions[key] = session
|
||||
finally:
|
||||
SESSION_LOCK.release()
|
||||
|
||||
def get_session(self, key):
|
||||
""" get a session by key """
|
||||
self.remove_expired_sessions()
|
||||
session = self._sessions.get(key, None)
|
||||
if session is not None and session.is_expired:
|
||||
return None
|
||||
return session
|
||||
|
||||
def create_session(self, api_password):
|
||||
""" Creates a new session and adds it to the sessions """
|
||||
if self.sessions_enabled is not True:
|
||||
return None
|
||||
|
||||
chars = string.ascii_letters + string.digits
|
||||
session_id = ''.join([random.choice(chars) for i in range(20)])
|
||||
session = ServerSession(session_id)
|
||||
session.cookie_values['api_password'] = api_password
|
||||
self.add_session(session_id, session)
|
||||
return session
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods,too-many-locals
|
||||
class RequestHandler(SimpleHTTPRequestHandler):
|
||||
@ -197,6 +260,11 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||
|
||||
server_version = "HomeAssistant/1.0"
|
||||
|
||||
def __init__(self, req, client_addr, server):
|
||||
""" Contructor, call the base constructor and set up session """
|
||||
SimpleHTTPRequestHandler.__init__(self, req, client_addr, server)
|
||||
self._session = None
|
||||
|
||||
def _handle_request(self, method): # pylint: disable=too-many-branches
|
||||
""" Does some common checks and calls appropriate method. """
|
||||
url = urlparse(self.path)
|
||||
@ -225,6 +293,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||
"Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY)
|
||||
return
|
||||
|
||||
self._session = self.get_session()
|
||||
if self.server.no_password_set:
|
||||
api_password = self.server.api_password
|
||||
else:
|
||||
@ -233,6 +302,9 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||
if not api_password and DATA_API_PASSWORD in data:
|
||||
api_password = data[DATA_API_PASSWORD]
|
||||
|
||||
if not api_password and self._session is not None:
|
||||
api_password = self._session.cookie_values.get('api_password')
|
||||
|
||||
if '_METHOD' in data:
|
||||
method = data.pop('_METHOD')
|
||||
|
||||
@ -271,6 +343,9 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||
"API password missing or incorrect.", HTTP_UNAUTHORIZED)
|
||||
|
||||
else:
|
||||
if self._session is None and require_auth:
|
||||
self._session = self.server.create_session(api_password)
|
||||
|
||||
handle_request_method(self, path_match, data)
|
||||
|
||||
elif path_matched_but_not_method:
|
||||
@ -313,6 +388,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||
if location:
|
||||
self.send_header('Location', location)
|
||||
|
||||
self.set_session_cookie_header()
|
||||
|
||||
self.end_headers()
|
||||
|
||||
if data is not None:
|
||||
@ -342,6 +419,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||
self.send_header(HTTP_HEADER_CONTENT_TYPE, content_type)
|
||||
|
||||
self.set_cache_header()
|
||||
self.set_session_cookie_header()
|
||||
|
||||
if do_gzip:
|
||||
gzip_data = gzip.compress(inp.read())
|
||||
@ -377,3 +455,57 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||
self.send_header(
|
||||
HTTP_HEADER_EXPIRES,
|
||||
self.date_time_string(time.time()+cache_time))
|
||||
|
||||
def set_session_cookie_header(self):
|
||||
""" Add the header for the session cookie """
|
||||
if self.server.sessions_enabled and self._session is not None:
|
||||
cookie = cookies.SimpleCookie()
|
||||
existing_sess_id = None
|
||||
|
||||
if self.headers.get('Cookie', None) is not None:
|
||||
cookie.load(self.headers.get('Cookie'))
|
||||
if cookie.get(SESSION_KEY, False):
|
||||
existing_sess_id = cookie[SESSION_KEY].value
|
||||
|
||||
if existing_sess_id != self._session.session_id:
|
||||
self.send_header(
|
||||
'Set-Cookie',
|
||||
SESSION_KEY+'='+self._session.session_id)
|
||||
|
||||
def get_session(self):
|
||||
""" Get the requested session object from cookie value """
|
||||
if self.server.sessions_enabled is not True:
|
||||
return None
|
||||
|
||||
cookie = cookies.SimpleCookie()
|
||||
|
||||
if self.headers.get('Cookie', None) is not None:
|
||||
cookie.load(self.headers.get("Cookie"))
|
||||
|
||||
if cookie.get(SESSION_KEY, False):
|
||||
session = self.server.get_session(cookie[SESSION_KEY].value)
|
||||
if session is not None:
|
||||
session.reset_expiry()
|
||||
return session
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class ServerSession:
|
||||
""" A very simple session class """
|
||||
def __init__(self, session_id):
|
||||
""" Set up the expiry time on creation """
|
||||
self._expiry = 0
|
||||
self.reset_expiry()
|
||||
self.cookie_values = {}
|
||||
self.session_id = session_id
|
||||
|
||||
def reset_expiry(self):
|
||||
""" Resets the expiry based on current time """
|
||||
self._expiry = date_util.utcnow() + timedelta(
|
||||
seconds=SESSION_TIMEOUT_SECONDS)
|
||||
|
||||
@property
|
||||
def is_expired(self):
|
||||
""" Return true if the session is expired based on the expiry time """
|
||||
return self._expiry < date_util.utcnow()
|
||||
|
Loading…
x
Reference in New Issue
Block a user