From 027b891052889325cd97587fc3fadff3c06cc303 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 14 Dec 2015 23:20:43 -0800 Subject: [PATCH] Add tests for API.stream --- homeassistant/components/api.py | 17 ++++++--- homeassistant/components/http.py | 18 +++++---- tests/components/test_api.py | 64 +++++++++++++++++++++++++++++++- 3 files changed, 84 insertions(+), 15 deletions(-) diff --git a/homeassistant/components/api.py b/homeassistant/components/api.py index 6d2f9e52a7a..18b743d06ca 100644 --- a/homeassistant/components/api.py +++ b/homeassistant/components/api.py @@ -125,7 +125,6 @@ def _handle_get_api_stream(handler, path_match, data): try: wfile.write(msg.encode("UTF-8")) wfile.flush() - handler.server.sessions.extend_validation(session_id) except (IOError, ValueError): # IOError: socket errors # ValueError: raised when 'I/O operation on closed file' @@ -135,14 +134,14 @@ def _handle_get_api_stream(handler, path_match, data): """ Forwards events to the open request. """ nonlocal gracefully_closed - if block.is_set() or event.event_type == EVENT_TIME_CHANGED or \ - restrict and event.event_type not in restrict: + if block.is_set() or event.event_type == EVENT_TIME_CHANGED: return elif event.event_type == EVENT_HOMEASSISTANT_STOP: gracefully_closed = True block.set() return + handler.server.sessions.extend_validation(session_id) write_message(json.dumps(event, cls=rem.JSONEncoder)) handler.send_response(HTTP_OK) @@ -150,7 +149,11 @@ def _handle_get_api_stream(handler, path_match, data): session_id = handler.set_session_cookie_header() handler.end_headers() - hass.bus.listen(MATCH_ALL, forward_events) + if restrict: + for event in restrict: + hass.bus.listen(event, forward_events) + else: + hass.bus.listen(MATCH_ALL, forward_events) while True: write_message(STREAM_PING_PAYLOAD) @@ -164,7 +167,11 @@ def _handle_get_api_stream(handler, path_match, data): _LOGGER.info("Found broken event stream to %s, cleaning up", handler.client_address[0]) - hass.bus.remove_listener(MATCH_ALL, forward_events) + if restrict: + for event in restrict: + hass.bus.remove_listener(event, forward_events) + else: + hass.bus.remove_listener(MATCH_ALL, forward_events) def _handle_get_api_config(handler, path_match, data): diff --git a/homeassistant/components/http.py b/homeassistant/components/http.py index 81e26aeae5a..7a4e87de5a8 100644 --- a/homeassistant/components/http.py +++ b/homeassistant/components/http.py @@ -359,13 +359,13 @@ class RequestHandler(SimpleHTTPRequestHandler): def set_session_cookie_header(self): """ Add the header for the session cookie and return session id. """ if not self.authenticated: - return + return None session_id = self.get_cookie_session_id() if session_id is not None: self.server.sessions.extend_validation(session_id) - return + return session_id self.send_header( 'Set-Cookie', @@ -422,10 +422,10 @@ def session_valid_time(): class SessionStore(object): """ Responsible for storing and retrieving http sessions """ - def __init__(self, enabled=True): + def __init__(self): """ Set up the session store """ self._sessions = {} - self.lock = threading.RLock() + self._lock = threading.RLock() @util.Throttle(SESSION_CLEAR_INTERVAL) def _remove_expired(self): @@ -437,7 +437,7 @@ class SessionStore(object): def is_valid(self, key): """ Return True if a valid session is given. """ - with self.lock: + with self._lock: self._remove_expired() return (key in self._sessions and @@ -445,17 +445,19 @@ class SessionStore(object): def extend_validation(self, key): """ Extend a session validation time. """ - with self.lock: + with self._lock: + if key not in self._sessions: + return self._sessions[key] = session_valid_time() def destroy(self, key): """ Destroy a session by key. """ - with self.lock: + with self._lock: self._sessions.pop(key, None) def create(self): """ Creates a new session. """ - with self.lock: + with self._lock: session_id = util.get_random_string(20) while session_id in self._sessions: diff --git a/tests/components/test_api.py b/tests/components/test_api.py index cf530c1f301..c8d8fa50ae1 100644 --- a/tests/components/test_api.py +++ b/tests/components/test_api.py @@ -5,10 +5,11 @@ tests.test_component_http Tests Home Assistant HTTP component does what it should do. """ # pylint: disable=protected-access,too-many-public-methods -import unittest +from contextlib import closing import json -from unittest.mock import patch import tempfile +import unittest +from unittest.mock import patch import requests @@ -415,3 +416,62 @@ class TestAPI(unittest.TestCase): }), headers=HA_HEADERS) self.assertEqual(200, req.status_code) + + def test_stream(self): + listen_count = self._listen_count() + with closing(requests.get(_url(const.URL_API_STREAM), + stream=True, headers=HA_HEADERS)) as req: + + self.assertEqual(listen_count + 1, self._listen_count()) + data = self._stream_next_event(req) + + self.assertEqual('ping', data) + + hass.bus.fire('test_event') + hass.pool.block_till_done() + + data = self._stream_next_event(req) + + self.assertEqual('test_event', data['event_type']) + + def test_stream_with_restricted(self): + listen_count = self._listen_count() + with closing(requests.get(_url(const.URL_API_STREAM), + data=json.dumps({ + 'restrict': 'test_event1,test_event3'}), + stream=True, headers=HA_HEADERS)) as req: + + self.assertEqual(listen_count + 2, self._listen_count()) + + data = self._stream_next_event(req) + + self.assertEqual('ping', data) + + hass.bus.fire('test_event1') + hass.pool.block_till_done() + hass.bus.fire('test_event2') + hass.pool.block_till_done() + hass.bus.fire('test_event3') + hass.pool.block_till_done() + + data = self._stream_next_event(req) + self.assertEqual('test_event1', data['event_type']) + data = self._stream_next_event(req) + self.assertEqual('test_event3', data['event_type']) + + def _stream_next_event(self, stream): + data = b'' + last_new_line = False + for dat in stream.iter_content(1): + if dat == b'\n' and last_new_line: + break + data += dat + last_new_line = dat == b'\n' + + conv = data.decode('utf-8').strip()[6:] + + return conv if conv == 'ping' else json.loads(conv) + + def _listen_count(self): + """ Return number of event listeners. """ + return sum(hass.bus.listeners.values())