Add tests for API.stream

This commit is contained in:
Paulus Schoutsen 2015-12-14 23:20:43 -08:00
parent 2e0042adb0
commit 027b891052
3 changed files with 84 additions and 15 deletions

View File

@ -125,7 +125,6 @@ def _handle_get_api_stream(handler, path_match, data):
try: try:
wfile.write(msg.encode("UTF-8")) wfile.write(msg.encode("UTF-8"))
wfile.flush() wfile.flush()
handler.server.sessions.extend_validation(session_id)
except (IOError, ValueError): except (IOError, ValueError):
# IOError: socket errors # IOError: socket errors
# ValueError: raised when 'I/O operation on closed file' # ValueError: raised when 'I/O operation on closed file'
@ -135,14 +134,14 @@ def _handle_get_api_stream(handler, path_match, data):
""" Forwards events to the open request. """ """ Forwards events to the open request. """
nonlocal gracefully_closed nonlocal gracefully_closed
if block.is_set() or event.event_type == EVENT_TIME_CHANGED or \ if block.is_set() or event.event_type == EVENT_TIME_CHANGED:
restrict and event.event_type not in restrict:
return return
elif event.event_type == EVENT_HOMEASSISTANT_STOP: elif event.event_type == EVENT_HOMEASSISTANT_STOP:
gracefully_closed = True gracefully_closed = True
block.set() block.set()
return return
handler.server.sessions.extend_validation(session_id)
write_message(json.dumps(event, cls=rem.JSONEncoder)) write_message(json.dumps(event, cls=rem.JSONEncoder))
handler.send_response(HTTP_OK) handler.send_response(HTTP_OK)
@ -150,7 +149,11 @@ def _handle_get_api_stream(handler, path_match, data):
session_id = handler.set_session_cookie_header() session_id = handler.set_session_cookie_header()
handler.end_headers() handler.end_headers()
hass.bus.listen(MATCH_ALL, forward_events) if restrict:
for event in restrict:
hass.bus.listen(event, forward_events)
else:
hass.bus.listen(MATCH_ALL, forward_events)
while True: while True:
write_message(STREAM_PING_PAYLOAD) write_message(STREAM_PING_PAYLOAD)
@ -164,7 +167,11 @@ def _handle_get_api_stream(handler, path_match, data):
_LOGGER.info("Found broken event stream to %s, cleaning up", _LOGGER.info("Found broken event stream to %s, cleaning up",
handler.client_address[0]) handler.client_address[0])
hass.bus.remove_listener(MATCH_ALL, forward_events) if restrict:
for event in restrict:
hass.bus.remove_listener(event, forward_events)
else:
hass.bus.remove_listener(MATCH_ALL, forward_events)
def _handle_get_api_config(handler, path_match, data): def _handle_get_api_config(handler, path_match, data):

View File

@ -359,13 +359,13 @@ class RequestHandler(SimpleHTTPRequestHandler):
def set_session_cookie_header(self): def set_session_cookie_header(self):
""" Add the header for the session cookie and return session id. """ """ Add the header for the session cookie and return session id. """
if not self.authenticated: if not self.authenticated:
return return None
session_id = self.get_cookie_session_id() session_id = self.get_cookie_session_id()
if session_id is not None: if session_id is not None:
self.server.sessions.extend_validation(session_id) self.server.sessions.extend_validation(session_id)
return return session_id
self.send_header( self.send_header(
'Set-Cookie', 'Set-Cookie',
@ -422,10 +422,10 @@ def session_valid_time():
class SessionStore(object): class SessionStore(object):
""" Responsible for storing and retrieving http sessions """ """ Responsible for storing and retrieving http sessions """
def __init__(self, enabled=True): def __init__(self):
""" Set up the session store """ """ Set up the session store """
self._sessions = {} self._sessions = {}
self.lock = threading.RLock() self._lock = threading.RLock()
@util.Throttle(SESSION_CLEAR_INTERVAL) @util.Throttle(SESSION_CLEAR_INTERVAL)
def _remove_expired(self): def _remove_expired(self):
@ -437,7 +437,7 @@ class SessionStore(object):
def is_valid(self, key): def is_valid(self, key):
""" Return True if a valid session is given. """ """ Return True if a valid session is given. """
with self.lock: with self._lock:
self._remove_expired() self._remove_expired()
return (key in self._sessions and return (key in self._sessions and
@ -445,17 +445,19 @@ class SessionStore(object):
def extend_validation(self, key): def extend_validation(self, key):
""" Extend a session validation time. """ """ Extend a session validation time. """
with self.lock: with self._lock:
if key not in self._sessions:
return
self._sessions[key] = session_valid_time() self._sessions[key] = session_valid_time()
def destroy(self, key): def destroy(self, key):
""" Destroy a session by key. """ """ Destroy a session by key. """
with self.lock: with self._lock:
self._sessions.pop(key, None) self._sessions.pop(key, None)
def create(self): def create(self):
""" Creates a new session. """ """ Creates a new session. """
with self.lock: with self._lock:
session_id = util.get_random_string(20) session_id = util.get_random_string(20)
while session_id in self._sessions: while session_id in self._sessions:

View File

@ -5,10 +5,11 @@ tests.test_component_http
Tests Home Assistant HTTP component does what it should do. Tests Home Assistant HTTP component does what it should do.
""" """
# pylint: disable=protected-access,too-many-public-methods # pylint: disable=protected-access,too-many-public-methods
import unittest from contextlib import closing
import json import json
from unittest.mock import patch
import tempfile import tempfile
import unittest
from unittest.mock import patch
import requests import requests
@ -415,3 +416,62 @@ class TestAPI(unittest.TestCase):
}), }),
headers=HA_HEADERS) headers=HA_HEADERS)
self.assertEqual(200, req.status_code) self.assertEqual(200, req.status_code)
def test_stream(self):
listen_count = self._listen_count()
with closing(requests.get(_url(const.URL_API_STREAM),
stream=True, headers=HA_HEADERS)) as req:
self.assertEqual(listen_count + 1, self._listen_count())
data = self._stream_next_event(req)
self.assertEqual('ping', data)
hass.bus.fire('test_event')
hass.pool.block_till_done()
data = self._stream_next_event(req)
self.assertEqual('test_event', data['event_type'])
def test_stream_with_restricted(self):
listen_count = self._listen_count()
with closing(requests.get(_url(const.URL_API_STREAM),
data=json.dumps({
'restrict': 'test_event1,test_event3'}),
stream=True, headers=HA_HEADERS)) as req:
self.assertEqual(listen_count + 2, self._listen_count())
data = self._stream_next_event(req)
self.assertEqual('ping', data)
hass.bus.fire('test_event1')
hass.pool.block_till_done()
hass.bus.fire('test_event2')
hass.pool.block_till_done()
hass.bus.fire('test_event3')
hass.pool.block_till_done()
data = self._stream_next_event(req)
self.assertEqual('test_event1', data['event_type'])
data = self._stream_next_event(req)
self.assertEqual('test_event3', data['event_type'])
def _stream_next_event(self, stream):
data = b''
last_new_line = False
for dat in stream.iter_content(1):
if dat == b'\n' and last_new_line:
break
data += dat
last_new_line = dat == b'\n'
conv = data.decode('utf-8').strip()[6:]
return conv if conv == 'ping' else json.loads(conv)
def _listen_count(self):
""" Return number of event listeners. """
return sum(hass.bus.listeners.values())