mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 04:37:06 +00:00
Refactored session handling into a separate class
This commit is contained in:
parent
8431fd822f
commit
80f0c42844
@ -19,7 +19,6 @@ http:
|
|||||||
api_password: mypass
|
api_password: mypass
|
||||||
# Set to 1 to enable development mode
|
# Set to 1 to enable development mode
|
||||||
# development: 1
|
# development: 1
|
||||||
# sessions_enabled: True
|
|
||||||
|
|
||||||
light:
|
light:
|
||||||
# platform: hue
|
# platform: hue
|
||||||
|
@ -112,7 +112,6 @@ DATA_API_PASSWORD = 'api_password'
|
|||||||
# Throttling time in seconds for expired sessions check
|
# Throttling time in seconds for expired sessions check
|
||||||
MIN_SEC_SESSION_CLEARING = timedelta(seconds=20)
|
MIN_SEC_SESSION_CLEARING = timedelta(seconds=20)
|
||||||
SESSION_TIMEOUT_SECONDS = 1800
|
SESSION_TIMEOUT_SECONDS = 1800
|
||||||
SESSION_LOCK = threading.RLock()
|
|
||||||
SESSION_KEY = 'sessionId'
|
SESSION_KEY = 'sessionId'
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@ -138,7 +137,7 @@ def setup(hass, config=None):
|
|||||||
|
|
||||||
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
|
development = str(config[DOMAIN].get(CONF_DEVELOPMENT, "")) == "1"
|
||||||
|
|
||||||
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, False)
|
sessions_enabled = config[DOMAIN].get(CONF_SESSIONS_ENABLED, True)
|
||||||
|
|
||||||
server = HomeAssistantHTTPServer(
|
server = HomeAssistantHTTPServer(
|
||||||
(server_host, server_port), RequestHandler, hass, api_password,
|
(server_host, server_port), RequestHandler, hass, api_password,
|
||||||
@ -175,8 +174,7 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
|||||||
self.development = development
|
self.development = development
|
||||||
self.no_password_set = no_password_set
|
self.no_password_set = no_password_set
|
||||||
self.paths = []
|
self.paths = []
|
||||||
self.sessions_enabled = sessions_enabled
|
self.sessions = SessionStore(sessions_enabled)
|
||||||
self._sessions = {}
|
|
||||||
|
|
||||||
# We will lazy init this one if needed
|
# We will lazy init this one if needed
|
||||||
self.event_forwarder = None
|
self.event_forwarder = None
|
||||||
@ -204,51 +202,6 @@ class HomeAssistantHTTPServer(ThreadingMixIn, HTTPServer):
|
|||||||
""" Regitsters a path wit the server. """
|
""" Regitsters a path wit the server. """
|
||||||
self.paths.append((method, url, callback, require_auth))
|
self.paths.append((method, url, callback, require_auth))
|
||||||
|
|
||||||
@Throttle(MIN_SEC_SESSION_CLEARING)
|
|
||||||
def remove_expired_sessions(self):
|
|
||||||
""" Reemove any expired sessions. """
|
|
||||||
if SESSION_LOCK.acquire(False):
|
|
||||||
try:
|
|
||||||
keys = []
|
|
||||||
for key in self._sessions.keys():
|
|
||||||
keys.append(key)
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
if self._sessions[key].is_expired:
|
|
||||||
del self._sessions[key]
|
|
||||||
_LOGGER.info("Cleared expired session %s", key)
|
|
||||||
finally:
|
|
||||||
SESSION_LOCK.release()
|
|
||||||
|
|
||||||
def add_session(self, key, session):
|
|
||||||
""" Add a new session to the list of tracked sessions """
|
|
||||||
self.remove_expired_sessions()
|
|
||||||
try:
|
|
||||||
SESSION_LOCK.acquire()
|
|
||||||
self._sessions[key] = session
|
|
||||||
finally:
|
|
||||||
SESSION_LOCK.release()
|
|
||||||
|
|
||||||
def get_session(self, key):
|
|
||||||
""" get a session by key """
|
|
||||||
self.remove_expired_sessions()
|
|
||||||
session = self._sessions.get(key, None)
|
|
||||||
if session is not None and session.is_expired:
|
|
||||||
return None
|
|
||||||
return session
|
|
||||||
|
|
||||||
def create_session(self, api_password):
|
|
||||||
""" Creates a new session and adds it to the sessions """
|
|
||||||
if self.sessions_enabled is not True:
|
|
||||||
return None
|
|
||||||
|
|
||||||
chars = string.ascii_letters + string.digits
|
|
||||||
session_id = ''.join([random.choice(chars) for i in range(20)])
|
|
||||||
session = ServerSession(session_id)
|
|
||||||
session.cookie_values['api_password'] = api_password
|
|
||||||
self.add_session(session_id, session)
|
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-public-methods,too-many-locals
|
# pylint: disable=too-many-public-methods,too-many-locals
|
||||||
class RequestHandler(SimpleHTTPRequestHandler):
|
class RequestHandler(SimpleHTTPRequestHandler):
|
||||||
@ -304,7 +257,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
api_password = data[DATA_API_PASSWORD]
|
api_password = data[DATA_API_PASSWORD]
|
||||||
|
|
||||||
if not api_password and self._session is not None:
|
if not api_password and self._session is not None:
|
||||||
api_password = self._session.cookie_values.get('api_password')
|
api_password = self._session.cookie_values.get(
|
||||||
|
CONF_API_PASSWORD)
|
||||||
|
|
||||||
if '_METHOD' in data:
|
if '_METHOD' in data:
|
||||||
method = data.pop('_METHOD')
|
method = data.pop('_METHOD')
|
||||||
@ -345,7 +299,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
if self._session is None and require_auth:
|
if self._session is None and require_auth:
|
||||||
self._session = self.server.create_session(api_password)
|
self._session = self.server.sessions.create_session(
|
||||||
|
api_password)
|
||||||
|
|
||||||
handle_request_method(self, path_match, data)
|
handle_request_method(self, path_match, data)
|
||||||
|
|
||||||
@ -459,14 +414,8 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
|
|
||||||
def set_session_cookie_header(self):
|
def set_session_cookie_header(self):
|
||||||
""" Add the header for the session cookie """
|
""" Add the header for the session cookie """
|
||||||
if self.server.sessions_enabled and self._session is not None:
|
if self.server.sessions.enabled and self._session is not None:
|
||||||
cookie = cookies.SimpleCookie()
|
existing_sess_id = self.get_current_session_id()
|
||||||
existing_sess_id = None
|
|
||||||
|
|
||||||
if self.headers.get('Cookie', None) is not None:
|
|
||||||
cookie.load(self.headers.get('Cookie'))
|
|
||||||
if cookie.get(SESSION_KEY, False):
|
|
||||||
existing_sess_id = cookie[SESSION_KEY].value
|
|
||||||
|
|
||||||
if existing_sess_id != self._session.session_id:
|
if existing_sess_id != self._session.session_id:
|
||||||
self.send_header(
|
self.send_header(
|
||||||
@ -475,20 +424,31 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
|
|
||||||
def get_session(self):
|
def get_session(self):
|
||||||
""" Get the requested session object from cookie value """
|
""" Get the requested session object from cookie value """
|
||||||
if self.server.sessions_enabled is not True:
|
if self.server.sessions.enabled is not True:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
session_id = self.get_current_session_id()
|
||||||
|
if session_id is not None:
|
||||||
|
session = self.server.sessions.get_session(session_id)
|
||||||
|
if session is not None:
|
||||||
|
session.reset_expiry()
|
||||||
|
return session
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_current_session_id(self):
|
||||||
|
"""
|
||||||
|
Extracts the current session id from the
|
||||||
|
cookie or returns None if not set
|
||||||
|
"""
|
||||||
cookie = cookies.SimpleCookie()
|
cookie = cookies.SimpleCookie()
|
||||||
|
|
||||||
if self.headers.get('Cookie', None) is not None:
|
if self.headers.get('Cookie', None) is not None:
|
||||||
cookie.load(self.headers.get("Cookie"))
|
cookie.load(self.headers.get("Cookie"))
|
||||||
|
|
||||||
if cookie.get(SESSION_KEY, False):
|
if cookie.get(SESSION_KEY, False):
|
||||||
session = self.server.get_session(cookie[SESSION_KEY].value)
|
return cookie[SESSION_KEY].value
|
||||||
if session is not None:
|
|
||||||
session.reset_expiry()
|
|
||||||
return session
|
|
||||||
else:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -510,3 +470,54 @@ class ServerSession:
|
|||||||
def is_expired(self):
|
def is_expired(self):
|
||||||
""" Return true if the session is expired based on the expiry time """
|
""" Return true if the session is expired based on the expiry time """
|
||||||
return self._expiry < date_util.utcnow()
|
return self._expiry < date_util.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
class SessionStore:
|
||||||
|
""" Responsible for storing and retrieving http sessions """
|
||||||
|
def __init__(self, enabled=True):
|
||||||
|
""" Set up the session store """
|
||||||
|
self._sessions = {}
|
||||||
|
self.enabled = enabled
|
||||||
|
self.session_lock = threading.RLock()
|
||||||
|
|
||||||
|
@Throttle(MIN_SEC_SESSION_CLEARING)
|
||||||
|
def remove_expired_sessions(self):
|
||||||
|
""" Remove any expired sessions. """
|
||||||
|
if self.session_lock.acquire(False):
|
||||||
|
try:
|
||||||
|
keys = []
|
||||||
|
for key in self._sessions.keys():
|
||||||
|
keys.append(key)
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if self._sessions[key].is_expired:
|
||||||
|
del self._sessions[key]
|
||||||
|
_LOGGER.info("Cleared expired session %s", key)
|
||||||
|
finally:
|
||||||
|
self.session_lock.release()
|
||||||
|
|
||||||
|
def add_session(self, key, session):
|
||||||
|
""" Add a new session to the list of tracked sessions """
|
||||||
|
self.remove_expired_sessions()
|
||||||
|
with self.session_lock:
|
||||||
|
self._sessions[key] = session
|
||||||
|
|
||||||
|
def get_session(self, key):
|
||||||
|
""" get a session by key """
|
||||||
|
self.remove_expired_sessions()
|
||||||
|
session = self._sessions.get(key, None)
|
||||||
|
if session is not None and session.is_expired:
|
||||||
|
return None
|
||||||
|
return session
|
||||||
|
|
||||||
|
def create_session(self, api_password):
|
||||||
|
""" Creates a new session and adds it to the sessions """
|
||||||
|
if self.enabled is not True:
|
||||||
|
return None
|
||||||
|
|
||||||
|
chars = string.ascii_letters + string.digits
|
||||||
|
session_id = ''.join([random.choice(chars) for i in range(20)])
|
||||||
|
session = ServerSession(session_id)
|
||||||
|
session.cookie_values[CONF_API_PASSWORD] = api_password
|
||||||
|
self.add_session(session_id, session)
|
||||||
|
return session
|
||||||
|
Loading…
x
Reference in New Issue
Block a user