mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Addd basic http sessions to the http component
This commit is contained in:
parent
8f51741c65
commit
721dc6dae4
@ -77,7 +77,12 @@ import logging
|
|||||||
import time
|
import time
|
||||||
import gzip
|
import gzip
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
from datetime import timedelta
|
||||||
|
from homeassistant.util import Throttle
|
||||||
from http.server import SimpleHTTPRequestHandler, HTTPServer
|
from http.server import SimpleHTTPRequestHandler, HTTPServer
|
||||||
|
from http import cookies
|
||||||
from socketserver import ThreadingMixIn
|
from socketserver import ThreadingMixIn
|
||||||
from urllib.parse import urlparse, parse_qs
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
|
||||||
@ -90,6 +95,7 @@ from homeassistant.const import (
|
|||||||
HTTP_NOT_FOUND, HTTP_METHOD_NOT_ALLOWED, HTTP_UNPROCESSABLE_ENTITY)
|
HTTP_NOT_FOUND, HTTP_METHOD_NOT_ALLOWED, HTTP_UNPROCESSABLE_ENTITY)
|
||||||
import homeassistant.remote as rem
|
import homeassistant.remote as rem
|
||||||
import homeassistant.util as util
|
import homeassistant.util as util
|
||||||
|
import homeassistant.util.dt as date_util
|
||||||
import homeassistant.bootstrap as bootstrap
|
import homeassistant.bootstrap as bootstrap
|
||||||
|
|
||||||
DOMAIN = "http"
|
DOMAIN = "http"
|
||||||
@ -99,9 +105,16 @@ CONF_API_PASSWORD = "api_password"
|
|||||||
CONF_SERVER_HOST = "server_host"
|
CONF_SERVER_HOST = "server_host"
|
||||||
CONF_SERVER_PORT = "server_port"
|
CONF_SERVER_PORT = "server_port"
|
||||||
CONF_DEVELOPMENT = "development"
|
CONF_DEVELOPMENT = "development"
|
||||||
|
CONF_SESSIONS_ENABLED = "sessions_enabled"
|
||||||
|
|
||||||
DATA_API_PASSWORD = 'api_password'
|
DATA_API_PASSWORD = 'api_password'
|
||||||
|
|
||||||
|
# Throttling time in seconds for expired sessions check
|
||||||
|
MIN_SEC_SESSION_CLEARING = timedelta(seconds=20)
|
||||||
|
SESSION_TIMEOUT_SECONDS = 1800
|
||||||
|
SESSION_LOCK = threading.RLock()
|
||||||
|
SESSION_KEY = 'sessionId'
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -125,9 +138,11 @@ def setup(hass, config=None):
|
|||||||
|
|
||||||
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
|
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
|
||||||
|
|
||||||
|
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, False)
|
||||||
|
|
||||||
server = HomeAssistantHTTPServer(
|
server = HomeAssistantHTTPServer(
|
||||||
(server_host, server_port), RequestHandler, hass, api_password,
|
(server_host, server_port), RequestHandler, hass, api_password,
|
||||||
development, no_password_set)
|
development, no_password_set, sessions_enabled)
|
||||||
|
|
||||||
hass.bus.listen_once(
|
hass.bus.listen_once(
|
||||||
ha.EVENT_HOMEASSISTANT_START,
|
ha.EVENT_HOMEASSISTANT_START,
|
||||||
@ -139,7 +154,7 @@ def setup(hass, config=None):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# pylint: disable=too-many-instance-attributes
|
||||||
class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
||||||
""" Handle HTTP requests in a threaded fashion. """
|
""" Handle HTTP requests in a threaded fashion. """
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
@ -149,7 +164,8 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
|||||||
|
|
||||||
# pylint: disable=too-many-arguments
|
# pylint: disable=too-many-arguments
|
||||||
def __init__(self, server_address, request_handler_class,
|
def __init__(self, server_address, request_handler_class,
|
||||||
hass, api_password, development, no_password_set):
|
hass, api_password, development, no_password_set,
|
||||||
|
sessions_enabled):
|
||||||
super().__init__(server_address, request_handler_class)
|
super().__init__(server_address, request_handler_class)
|
||||||
|
|
||||||
self.server_address = server_address
|
self.server_address = server_address
|
||||||
@ -158,6 +174,8 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
|||||||
self.development = development
|
self.development = development
|
||||||
self.no_password_set = no_password_set
|
self.no_password_set = no_password_set
|
||||||
self.paths = []
|
self.paths = []
|
||||||
|
self.sessions_enabled = sessions_enabled
|
||||||
|
self._sessions = {}
|
||||||
|
|
||||||
# We will lazy init this one if needed
|
# We will lazy init this one if needed
|
||||||
self.event_forwarder = None
|
self.event_forwarder = None
|
||||||
@ -185,6 +203,51 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
|||||||
""" Regitsters a path wit the server. """
|
""" Regitsters a path wit the server. """
|
||||||
self.paths.append((method, url, callback, require_auth))
|
self.paths.append((method, url, callback, require_auth))
|
||||||
|
|
||||||
|
@Throttle(MIN_SEC_SESSION_CLEARING)
|
||||||
|
def remove_expired_sessions(self):
|
||||||
|
""" Reemove any expired sessions. """
|
||||||
|
if SESSION_LOCK.acquire(False):
|
||||||
|
try:
|
||||||
|
keys = []
|
||||||
|
for key in self._sessions.keys():
|
||||||
|
keys.append(key)
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if self._sessions[key].is_expired:
|
||||||
|
del self._sessions[key]
|
||||||
|
_LOGGER.info("Cleared expired session %s", key)
|
||||||
|
finally:
|
||||||
|
SESSION_LOCK.release()
|
||||||
|
|
||||||
|
def add_session(self, key, session):
|
||||||
|
""" Add a new session to the list of tracked sessions """
|
||||||
|
self.remove_expired_sessions()
|
||||||
|
try:
|
||||||
|
SESSION_LOCK.acquire()
|
||||||
|
self._sessions[key] = session
|
||||||
|
finally:
|
||||||
|
SESSION_LOCK.release()
|
||||||
|
|
||||||
|
def get_session(self, key):
|
||||||
|
""" get a session by key """
|
||||||
|
self.remove_expired_sessions()
|
||||||
|
session = self._sessions.get(key, None)
|
||||||
|
if session is not None and session.is_expired:
|
||||||
|
return None
|
||||||
|
return session
|
||||||
|
|
||||||
|
def create_session(self, api_password):
|
||||||
|
""" Creates a new session and adds it to the sessions """
|
||||||
|
if self.sessions_enabled is not True:
|
||||||
|
return None
|
||||||
|
|
||||||
|
chars = string.ascii_letters + string.digits
|
||||||
|
session_id = ''.join([random.choice(chars) for i in range(20)])
|
||||||
|
session = ServerSession(session_id)
|
||||||
|
session.cookie_values['api_password'] = api_password
|
||||||
|
self.add_session(session_id, session)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-public-methods,too-many-locals
|
# pylint: disable=too-many-public-methods,too-many-locals
|
||||||
class RequestHandler(SimpleHTTPRequestHandler):
|
class RequestHandler(SimpleHTTPRequestHandler):
|
||||||
@ -197,6 +260,11 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
|
|
||||||
server_version = "HomeAssistant/1.0"
|
server_version = "HomeAssistant/1.0"
|
||||||
|
|
||||||
|
def __init__(self, req, client_addr, server):
|
||||||
|
""" Contructor, call the base constructor and set up session """
|
||||||
|
SimpleHTTPRequestHandler.__init__(self, req, client_addr, server)
|
||||||
|
self._session = None
|
||||||
|
|
||||||
def _handle_request(self, method): # pylint: disable=too-many-branches
|
def _handle_request(self, method): # pylint: disable=too-many-branches
|
||||||
""" Does some common checks and calls appropriate method. """
|
""" Does some common checks and calls appropriate method. """
|
||||||
url = urlparse(self.path)
|
url = urlparse(self.path)
|
||||||
@ -225,6 +293,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
"Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY)
|
"Error parsing JSON", HTTP_UNPROCESSABLE_ENTITY)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self._session = self.get_session()
|
||||||
if self.server.no_password_set:
|
if self.server.no_password_set:
|
||||||
api_password = self.server.api_password
|
api_password = self.server.api_password
|
||||||
else:
|
else:
|
||||||
@ -233,6 +302,9 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
if not api_password and DATA_API_PASSWORD in data:
|
if not api_password and DATA_API_PASSWORD in data:
|
||||||
api_password = data[DATA_API_PASSWORD]
|
api_password = data[DATA_API_PASSWORD]
|
||||||
|
|
||||||
|
if not api_password and self._session is not None:
|
||||||
|
api_password = self._session.cookie_values.get('api_password')
|
||||||
|
|
||||||
if '_METHOD' in data:
|
if '_METHOD' in data:
|
||||||
method = data.pop('_METHOD')
|
method = data.pop('_METHOD')
|
||||||
|
|
||||||
@ -271,6 +343,9 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
"API password missing or incorrect.", HTTP_UNAUTHORIZED)
|
"API password missing or incorrect.", HTTP_UNAUTHORIZED)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
if self._session is None and require_auth:
|
||||||
|
self._session = self.server.create_session(api_password)
|
||||||
|
|
||||||
handle_request_method(self, path_match, data)
|
handle_request_method(self, path_match, data)
|
||||||
|
|
||||||
elif path_matched_but_not_method:
|
elif path_matched_but_not_method:
|
||||||
@ -313,6 +388,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
if location:
|
if location:
|
||||||
self.send_header('Location', location)
|
self.send_header('Location', location)
|
||||||
|
|
||||||
|
self.set_session_cookie_header()
|
||||||
|
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
@ -342,6 +419,7 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
self.send_header(HTTP_HEADER_CONTENT_TYPE, content_type)
|
self.send_header(HTTP_HEADER_CONTENT_TYPE, content_type)
|
||||||
|
|
||||||
self.set_cache_header()
|
self.set_cache_header()
|
||||||
|
self.set_session_cookie_header()
|
||||||
|
|
||||||
if do_gzip:
|
if do_gzip:
|
||||||
gzip_data = gzip.compress(inp.read())
|
gzip_data = gzip.compress(inp.read())
|
||||||
@ -377,3 +455,57 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
self.send_header(
|
self.send_header(
|
||||||
HTTP_HEADER_EXPIRES,
|
HTTP_HEADER_EXPIRES,
|
||||||
self.date_time_string(time.time()+cache_time))
|
self.date_time_string(time.time()+cache_time))
|
||||||
|
|
||||||
|
def set_session_cookie_header(self):
|
||||||
|
""" Add the header for the session cookie """
|
||||||
|
if self.server.sessions_enabled and self._session is not None:
|
||||||
|
cookie = cookies.SimpleCookie()
|
||||||
|
existing_sess_id = None
|
||||||
|
|
||||||
|
if self.headers.get('Cookie', None) is not None:
|
||||||
|
cookie.load(self.headers.get('Cookie'))
|
||||||
|
if cookie.get(SESSION_KEY, False):
|
||||||
|
existing_sess_id = cookie[SESSION_KEY].value
|
||||||
|
|
||||||
|
if existing_sess_id != self._session.session_id:
|
||||||
|
self.send_header(
|
||||||
|
'Set-Cookie',
|
||||||
|
SESSION_KEY+'='+self._session.session_id)
|
||||||
|
|
||||||
|
def get_session(self):
|
||||||
|
""" Get the requested session object from cookie value """
|
||||||
|
if self.server.sessions_enabled is not True:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cookie = cookies.SimpleCookie()
|
||||||
|
|
||||||
|
if self.headers.get('Cookie', None) is not None:
|
||||||
|
cookie.load(self.headers.get("Cookie"))
|
||||||
|
|
||||||
|
if cookie.get(SESSION_KEY, False):
|
||||||
|
session = self.server.get_session(cookie[SESSION_KEY].value)
|
||||||
|
if session is not None:
|
||||||
|
session.reset_expiry()
|
||||||
|
return session
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class ServerSession:
|
||||||
|
""" A very simple session class """
|
||||||
|
def __init__(self, session_id):
|
||||||
|
""" Set up the expiry time on creation """
|
||||||
|
self._expiry = 0
|
||||||
|
self.reset_expiry()
|
||||||
|
self.cookie_values = {}
|
||||||
|
self.session_id = session_id
|
||||||
|
|
||||||
|
def reset_expiry(self):
|
||||||
|
""" Resets the expiry based on current time """
|
||||||
|
self._expiry = date_util.utcnow() + timedelta(
|
||||||
|
seconds=SESSION_TIMEOUT_SECONDS)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_expired(self):
|
||||||
|
""" Return true if the session is expired based on the expiry time """
|
||||||
|
return self._expiry < date_util.utcnow()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user