mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Add tests for API.stream
This commit is contained in:
parent
2e0042adb0
commit
027b891052
@ -125,7 +125,6 @@ def _handle_get_api_stream(handler, path_match, data):
|
||||
try:
|
||||
wfile.write(msg.encode("UTF-8"))
|
||||
wfile.flush()
|
||||
handler.server.sessions.extend_validation(session_id)
|
||||
except (IOError, ValueError):
|
||||
# IOError: socket errors
|
||||
# ValueError: raised when 'I/O operation on closed file'
|
||||
@ -135,14 +134,14 @@ def _handle_get_api_stream(handler, path_match, data):
|
||||
""" Forwards events to the open request. """
|
||||
nonlocal gracefully_closed
|
||||
|
||||
if block.is_set() or event.event_type == EVENT_TIME_CHANGED or \
|
||||
restrict and event.event_type not in restrict:
|
||||
if block.is_set() or event.event_type == EVENT_TIME_CHANGED:
|
||||
return
|
||||
elif event.event_type == EVENT_HOMEASSISTANT_STOP:
|
||||
gracefully_closed = True
|
||||
block.set()
|
||||
return
|
||||
|
||||
handler.server.sessions.extend_validation(session_id)
|
||||
write_message(json.dumps(event, cls=rem.JSONEncoder))
|
||||
|
||||
handler.send_response(HTTP_OK)
|
||||
@ -150,7 +149,11 @@ def _handle_get_api_stream(handler, path_match, data):
|
||||
session_id = handler.set_session_cookie_header()
|
||||
handler.end_headers()
|
||||
|
||||
hass.bus.listen(MATCH_ALL, forward_events)
|
||||
if restrict:
|
||||
for event in restrict:
|
||||
hass.bus.listen(event, forward_events)
|
||||
else:
|
||||
hass.bus.listen(MATCH_ALL, forward_events)
|
||||
|
||||
while True:
|
||||
write_message(STREAM_PING_PAYLOAD)
|
||||
@ -164,7 +167,11 @@ def _handle_get_api_stream(handler, path_match, data):
|
||||
_LOGGER.info("Found broken event stream to %s, cleaning up",
|
||||
handler.client_address[0])
|
||||
|
||||
hass.bus.remove_listener(MATCH_ALL, forward_events)
|
||||
if restrict:
|
||||
for event in restrict:
|
||||
hass.bus.remove_listener(event, forward_events)
|
||||
else:
|
||||
hass.bus.remove_listener(MATCH_ALL, forward_events)
|
||||
|
||||
|
||||
def _handle_get_api_config(handler, path_match, data):
|
||||
|
@ -359,13 +359,13 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
||||
def set_session_cookie_header(self):
|
||||
""" Add the header for the session cookie and return session id. """
|
||||
if not self.authenticated:
|
||||
return
|
||||
return None
|
||||
|
||||
session_id = self.get_cookie_session_id()
|
||||
|
||||
if session_id is not None:
|
||||
self.server.sessions.extend_validation(session_id)
|
||||
return
|
||||
return session_id
|
||||
|
||||
self.send_header(
|
||||
'Set-Cookie',
|
||||
@ -422,10 +422,10 @@ def session_valid_time():
|
||||
|
||||
class SessionStore(object):
|
||||
""" Responsible for storing and retrieving http sessions """
|
||||
def __init__(self, enabled=True):
|
||||
def __init__(self):
|
||||
""" Set up the session store """
|
||||
self._sessions = {}
|
||||
self.lock = threading.RLock()
|
||||
self._lock = threading.RLock()
|
||||
|
||||
@util.Throttle(SESSION_CLEAR_INTERVAL)
|
||||
def _remove_expired(self):
|
||||
@ -437,7 +437,7 @@ class SessionStore(object):
|
||||
|
||||
def is_valid(self, key):
|
||||
""" Return True if a valid session is given. """
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
self._remove_expired()
|
||||
|
||||
return (key in self._sessions and
|
||||
@ -445,17 +445,19 @@ class SessionStore(object):
|
||||
|
||||
def extend_validation(self, key):
|
||||
""" Extend a session validation time. """
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
if key not in self._sessions:
|
||||
return
|
||||
self._sessions[key] = session_valid_time()
|
||||
|
||||
def destroy(self, key):
|
||||
""" Destroy a session by key. """
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
self._sessions.pop(key, None)
|
||||
|
||||
def create(self):
|
||||
""" Creates a new session. """
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
session_id = util.get_random_string(20)
|
||||
|
||||
while session_id in self._sessions:
|
||||
|
@ -5,10 +5,11 @@ tests.test_component_http
|
||||
Tests Home Assistant HTTP component does what it should do.
|
||||
"""
|
||||
# pylint: disable=protected-access,too-many-public-methods
|
||||
import unittest
|
||||
from contextlib import closing
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import requests
|
||||
|
||||
@ -415,3 +416,62 @@ class TestAPI(unittest.TestCase):
|
||||
}),
|
||||
headers=HA_HEADERS)
|
||||
self.assertEqual(200, req.status_code)
|
||||
|
||||
def test_stream(self):
|
||||
listen_count = self._listen_count()
|
||||
with closing(requests.get(_url(const.URL_API_STREAM),
|
||||
stream=True, headers=HA_HEADERS)) as req:
|
||||
|
||||
self.assertEqual(listen_count + 1, self._listen_count())
|
||||
data = self._stream_next_event(req)
|
||||
|
||||
self.assertEqual('ping', data)
|
||||
|
||||
hass.bus.fire('test_event')
|
||||
hass.pool.block_till_done()
|
||||
|
||||
data = self._stream_next_event(req)
|
||||
|
||||
self.assertEqual('test_event', data['event_type'])
|
||||
|
||||
def test_stream_with_restricted(self):
|
||||
listen_count = self._listen_count()
|
||||
with closing(requests.get(_url(const.URL_API_STREAM),
|
||||
data=json.dumps({
|
||||
'restrict': 'test_event1,test_event3'}),
|
||||
stream=True, headers=HA_HEADERS)) as req:
|
||||
|
||||
self.assertEqual(listen_count + 2, self._listen_count())
|
||||
|
||||
data = self._stream_next_event(req)
|
||||
|
||||
self.assertEqual('ping', data)
|
||||
|
||||
hass.bus.fire('test_event1')
|
||||
hass.pool.block_till_done()
|
||||
hass.bus.fire('test_event2')
|
||||
hass.pool.block_till_done()
|
||||
hass.bus.fire('test_event3')
|
||||
hass.pool.block_till_done()
|
||||
|
||||
data = self._stream_next_event(req)
|
||||
self.assertEqual('test_event1', data['event_type'])
|
||||
data = self._stream_next_event(req)
|
||||
self.assertEqual('test_event3', data['event_type'])
|
||||
|
||||
def _stream_next_event(self, stream):
|
||||
data = b''
|
||||
last_new_line = False
|
||||
for dat in stream.iter_content(1):
|
||||
if dat == b'\n' and last_new_line:
|
||||
break
|
||||
data += dat
|
||||
last_new_line = dat == b'\n'
|
||||
|
||||
conv = data.decode('utf-8').strip()[6:]
|
||||
|
||||
return conv if conv == 'ping' else json.loads(conv)
|
||||
|
||||
def _listen_count(self):
|
||||
""" Return number of event listeners. """
|
||||
return sum(hass.bus.listeners.values())
|
||||
|
Loading…
x
Reference in New Issue
Block a user