Add tests for API.stream

This commit is contained in:
Paulus Schoutsen 2015-12-14 23:20:43 -08:00
parent 2e0042adb0
commit 027b891052
3 changed files with 84 additions and 15 deletions

View File

@ -125,7 +125,6 @@ def _handle_get_api_stream(handler, path_match, data):
try:
wfile.write(msg.encode("UTF-8"))
wfile.flush()
handler.server.sessions.extend_validation(session_id)
except (IOError, ValueError):
# IOError: socket errors
# ValueError: raised when 'I/O operation on closed file'
@ -135,14 +134,14 @@ def _handle_get_api_stream(handler, path_match, data):
""" Forwards events to the open request. """
nonlocal gracefully_closed
if block.is_set() or event.event_type == EVENT_TIME_CHANGED or \
restrict and event.event_type not in restrict:
if block.is_set() or event.event_type == EVENT_TIME_CHANGED:
return
elif event.event_type == EVENT_HOMEASSISTANT_STOP:
gracefully_closed = True
block.set()
return
handler.server.sessions.extend_validation(session_id)
write_message(json.dumps(event, cls=rem.JSONEncoder))
handler.send_response(HTTP_OK)
@ -150,7 +149,11 @@ def _handle_get_api_stream(handler, path_match, data):
session_id = handler.set_session_cookie_header()
handler.end_headers()
hass.bus.listen(MATCH_ALL, forward_events)
if restrict:
for event in restrict:
hass.bus.listen(event, forward_events)
else:
hass.bus.listen(MATCH_ALL, forward_events)
while True:
write_message(STREAM_PING_PAYLOAD)
@ -164,7 +167,11 @@ def _handle_get_api_stream(handler, path_match, data):
_LOGGER.info("Found broken event stream to %s, cleaning up",
handler.client_address[0])
hass.bus.remove_listener(MATCH_ALL, forward_events)
if restrict:
for event in restrict:
hass.bus.remove_listener(event, forward_events)
else:
hass.bus.remove_listener(MATCH_ALL, forward_events)
def _handle_get_api_config(handler, path_match, data):

View File

@ -359,13 +359,13 @@ class RequestHandler(SimpleHTTPRequestHandler):
def set_session_cookie_header(self):
""" Add the header for the session cookie and return session id. """
if not self.authenticated:
return
return None
session_id = self.get_cookie_session_id()
if session_id is not None:
self.server.sessions.extend_validation(session_id)
return
return session_id
self.send_header(
'Set-Cookie',
@ -422,10 +422,10 @@ def session_valid_time():
class SessionStore(object):
""" Responsible for storing and retrieving http sessions """
def __init__(self, enabled=True):
def __init__(self):
""" Set up the session store """
self._sessions = {}
self.lock = threading.RLock()
self._lock = threading.RLock()
@util.Throttle(SESSION_CLEAR_INTERVAL)
def _remove_expired(self):
@ -437,7 +437,7 @@ class SessionStore(object):
def is_valid(self, key):
""" Return True if a valid session is given. """
with self.lock:
with self._lock:
self._remove_expired()
return (key in self._sessions and
@ -445,17 +445,19 @@ class SessionStore(object):
def extend_validation(self, key):
""" Extend a session validation time. """
with self.lock:
with self._lock:
if key not in self._sessions:
return
self._sessions[key] = session_valid_time()
def destroy(self, key):
""" Destroy a session by key. """
with self.lock:
with self._lock:
self._sessions.pop(key, None)
def create(self):
""" Creates a new session. """
with self.lock:
with self._lock:
session_id = util.get_random_string(20)
while session_id in self._sessions:

View File

@ -5,10 +5,11 @@ tests.test_component_http
Tests Home Assistant HTTP component does what it should do.
"""
# pylint: disable=protected-access,too-many-public-methods
import unittest
from contextlib import closing
import json
from unittest.mock import patch
import tempfile
import unittest
from unittest.mock import patch
import requests
@ -415,3 +416,62 @@ class TestAPI(unittest.TestCase):
}),
headers=HA_HEADERS)
self.assertEqual(200, req.status_code)
def test_stream(self):
listen_count = self._listen_count()
with closing(requests.get(_url(const.URL_API_STREAM),
stream=True, headers=HA_HEADERS)) as req:
self.assertEqual(listen_count + 1, self._listen_count())
data = self._stream_next_event(req)
self.assertEqual('ping', data)
hass.bus.fire('test_event')
hass.pool.block_till_done()
data = self._stream_next_event(req)
self.assertEqual('test_event', data['event_type'])
def test_stream_with_restricted(self):
listen_count = self._listen_count()
with closing(requests.get(_url(const.URL_API_STREAM),
data=json.dumps({
'restrict': 'test_event1,test_event3'}),
stream=True, headers=HA_HEADERS)) as req:
self.assertEqual(listen_count + 2, self._listen_count())
data = self._stream_next_event(req)
self.assertEqual('ping', data)
hass.bus.fire('test_event1')
hass.pool.block_till_done()
hass.bus.fire('test_event2')
hass.pool.block_till_done()
hass.bus.fire('test_event3')
hass.pool.block_till_done()
data = self._stream_next_event(req)
self.assertEqual('test_event1', data['event_type'])
data = self._stream_next_event(req)
self.assertEqual('test_event3', data['event_type'])
def _stream_next_event(self, stream):
data = b''
last_new_line = False
for dat in stream.iter_content(1):
if dat == b'\n' and last_new_line:
break
data += dat
last_new_line = dat == b'\n'
conv = data.decode('utf-8').strip()[6:]
return conv if conv == 'ping' else json.loads(conv)
def _listen_count(self):
""" Return number of event listeners. """
return sum(hass.bus.listeners.values())