mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Add tests for API.stream
This commit is contained in:
parent
2e0042adb0
commit
027b891052
@ -125,7 +125,6 @@ def _handle_get_api_stream(handler, path_match, data):
|
|||||||
try:
|
try:
|
||||||
wfile.write(msg.encode("UTF-8"))
|
wfile.write(msg.encode("UTF-8"))
|
||||||
wfile.flush()
|
wfile.flush()
|
||||||
handler.server.sessions.extend_validation(session_id)
|
|
||||||
except (IOError, ValueError):
|
except (IOError, ValueError):
|
||||||
# IOError: socket errors
|
# IOError: socket errors
|
||||||
# ValueError: raised when 'I/O operation on closed file'
|
# ValueError: raised when 'I/O operation on closed file'
|
||||||
@ -135,14 +134,14 @@ def _handle_get_api_stream(handler, path_match, data):
|
|||||||
""" Forwards events to the open request. """
|
""" Forwards events to the open request. """
|
||||||
nonlocal gracefully_closed
|
nonlocal gracefully_closed
|
||||||
|
|
||||||
if block.is_set() or event.event_type == EVENT_TIME_CHANGED or \
|
if block.is_set() or event.event_type == EVENT_TIME_CHANGED:
|
||||||
restrict and event.event_type not in restrict:
|
|
||||||
return
|
return
|
||||||
elif event.event_type == EVENT_HOMEASSISTANT_STOP:
|
elif event.event_type == EVENT_HOMEASSISTANT_STOP:
|
||||||
gracefully_closed = True
|
gracefully_closed = True
|
||||||
block.set()
|
block.set()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
handler.server.sessions.extend_validation(session_id)
|
||||||
write_message(json.dumps(event, cls=rem.JSONEncoder))
|
write_message(json.dumps(event, cls=rem.JSONEncoder))
|
||||||
|
|
||||||
handler.send_response(HTTP_OK)
|
handler.send_response(HTTP_OK)
|
||||||
@ -150,7 +149,11 @@ def _handle_get_api_stream(handler, path_match, data):
|
|||||||
session_id = handler.set_session_cookie_header()
|
session_id = handler.set_session_cookie_header()
|
||||||
handler.end_headers()
|
handler.end_headers()
|
||||||
|
|
||||||
hass.bus.listen(MATCH_ALL, forward_events)
|
if restrict:
|
||||||
|
for event in restrict:
|
||||||
|
hass.bus.listen(event, forward_events)
|
||||||
|
else:
|
||||||
|
hass.bus.listen(MATCH_ALL, forward_events)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
write_message(STREAM_PING_PAYLOAD)
|
write_message(STREAM_PING_PAYLOAD)
|
||||||
@ -164,7 +167,11 @@ def _handle_get_api_stream(handler, path_match, data):
|
|||||||
_LOGGER.info("Found broken event stream to %s, cleaning up",
|
_LOGGER.info("Found broken event stream to %s, cleaning up",
|
||||||
handler.client_address[0])
|
handler.client_address[0])
|
||||||
|
|
||||||
hass.bus.remove_listener(MATCH_ALL, forward_events)
|
if restrict:
|
||||||
|
for event in restrict:
|
||||||
|
hass.bus.remove_listener(event, forward_events)
|
||||||
|
else:
|
||||||
|
hass.bus.remove_listener(MATCH_ALL, forward_events)
|
||||||
|
|
||||||
|
|
||||||
def _handle_get_api_config(handler, path_match, data):
|
def _handle_get_api_config(handler, path_match, data):
|
||||||
|
@ -359,13 +359,13 @@ class RequestHandler(SimpleHTTPRequestHandler):
|
|||||||
def set_session_cookie_header(self):
|
def set_session_cookie_header(self):
|
||||||
""" Add the header for the session cookie and return session id. """
|
""" Add the header for the session cookie and return session id. """
|
||||||
if not self.authenticated:
|
if not self.authenticated:
|
||||||
return
|
return None
|
||||||
|
|
||||||
session_id = self.get_cookie_session_id()
|
session_id = self.get_cookie_session_id()
|
||||||
|
|
||||||
if session_id is not None:
|
if session_id is not None:
|
||||||
self.server.sessions.extend_validation(session_id)
|
self.server.sessions.extend_validation(session_id)
|
||||||
return
|
return session_id
|
||||||
|
|
||||||
self.send_header(
|
self.send_header(
|
||||||
'Set-Cookie',
|
'Set-Cookie',
|
||||||
@ -422,10 +422,10 @@ def session_valid_time():
|
|||||||
|
|
||||||
class SessionStore(object):
|
class SessionStore(object):
|
||||||
""" Responsible for storing and retrieving http sessions """
|
""" Responsible for storing and retrieving http sessions """
|
||||||
def __init__(self, enabled=True):
|
def __init__(self):
|
||||||
""" Set up the session store """
|
""" Set up the session store """
|
||||||
self._sessions = {}
|
self._sessions = {}
|
||||||
self.lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
@util.Throttle(SESSION_CLEAR_INTERVAL)
|
@util.Throttle(SESSION_CLEAR_INTERVAL)
|
||||||
def _remove_expired(self):
|
def _remove_expired(self):
|
||||||
@ -437,7 +437,7 @@ class SessionStore(object):
|
|||||||
|
|
||||||
def is_valid(self, key):
|
def is_valid(self, key):
|
||||||
""" Return True if a valid session is given. """
|
""" Return True if a valid session is given. """
|
||||||
with self.lock:
|
with self._lock:
|
||||||
self._remove_expired()
|
self._remove_expired()
|
||||||
|
|
||||||
return (key in self._sessions and
|
return (key in self._sessions and
|
||||||
@ -445,17 +445,19 @@ class SessionStore(object):
|
|||||||
|
|
||||||
def extend_validation(self, key):
|
def extend_validation(self, key):
|
||||||
""" Extend a session validation time. """
|
""" Extend a session validation time. """
|
||||||
with self.lock:
|
with self._lock:
|
||||||
|
if key not in self._sessions:
|
||||||
|
return
|
||||||
self._sessions[key] = session_valid_time()
|
self._sessions[key] = session_valid_time()
|
||||||
|
|
||||||
def destroy(self, key):
|
def destroy(self, key):
|
||||||
""" Destroy a session by key. """
|
""" Destroy a session by key. """
|
||||||
with self.lock:
|
with self._lock:
|
||||||
self._sessions.pop(key, None)
|
self._sessions.pop(key, None)
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
""" Creates a new session. """
|
""" Creates a new session. """
|
||||||
with self.lock:
|
with self._lock:
|
||||||
session_id = util.get_random_string(20)
|
session_id = util.get_random_string(20)
|
||||||
|
|
||||||
while session_id in self._sessions:
|
while session_id in self._sessions:
|
||||||
|
@ -5,10 +5,11 @@ tests.test_component_http
|
|||||||
Tests Home Assistant HTTP component does what it should do.
|
Tests Home Assistant HTTP component does what it should do.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=protected-access,too-many-public-methods
|
# pylint: disable=protected-access,too-many-public-methods
|
||||||
import unittest
|
from contextlib import closing
|
||||||
import json
|
import json
|
||||||
from unittest.mock import patch
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@ -415,3 +416,62 @@ class TestAPI(unittest.TestCase):
|
|||||||
}),
|
}),
|
||||||
headers=HA_HEADERS)
|
headers=HA_HEADERS)
|
||||||
self.assertEqual(200, req.status_code)
|
self.assertEqual(200, req.status_code)
|
||||||
|
|
||||||
|
def test_stream(self):
|
||||||
|
listen_count = self._listen_count()
|
||||||
|
with closing(requests.get(_url(const.URL_API_STREAM),
|
||||||
|
stream=True, headers=HA_HEADERS)) as req:
|
||||||
|
|
||||||
|
self.assertEqual(listen_count + 1, self._listen_count())
|
||||||
|
data = self._stream_next_event(req)
|
||||||
|
|
||||||
|
self.assertEqual('ping', data)
|
||||||
|
|
||||||
|
hass.bus.fire('test_event')
|
||||||
|
hass.pool.block_till_done()
|
||||||
|
|
||||||
|
data = self._stream_next_event(req)
|
||||||
|
|
||||||
|
self.assertEqual('test_event', data['event_type'])
|
||||||
|
|
||||||
|
def test_stream_with_restricted(self):
|
||||||
|
listen_count = self._listen_count()
|
||||||
|
with closing(requests.get(_url(const.URL_API_STREAM),
|
||||||
|
data=json.dumps({
|
||||||
|
'restrict': 'test_event1,test_event3'}),
|
||||||
|
stream=True, headers=HA_HEADERS)) as req:
|
||||||
|
|
||||||
|
self.assertEqual(listen_count + 2, self._listen_count())
|
||||||
|
|
||||||
|
data = self._stream_next_event(req)
|
||||||
|
|
||||||
|
self.assertEqual('ping', data)
|
||||||
|
|
||||||
|
hass.bus.fire('test_event1')
|
||||||
|
hass.pool.block_till_done()
|
||||||
|
hass.bus.fire('test_event2')
|
||||||
|
hass.pool.block_till_done()
|
||||||
|
hass.bus.fire('test_event3')
|
||||||
|
hass.pool.block_till_done()
|
||||||
|
|
||||||
|
data = self._stream_next_event(req)
|
||||||
|
self.assertEqual('test_event1', data['event_type'])
|
||||||
|
data = self._stream_next_event(req)
|
||||||
|
self.assertEqual('test_event3', data['event_type'])
|
||||||
|
|
||||||
|
def _stream_next_event(self, stream):
|
||||||
|
data = b''
|
||||||
|
last_new_line = False
|
||||||
|
for dat in stream.iter_content(1):
|
||||||
|
if dat == b'\n' and last_new_line:
|
||||||
|
break
|
||||||
|
data += dat
|
||||||
|
last_new_line = dat == b'\n'
|
||||||
|
|
||||||
|
conv = data.decode('utf-8').strip()[6:]
|
||||||
|
|
||||||
|
return conv if conv == 'ping' else json.loads(conv)
|
||||||
|
|
||||||
|
def _listen_count(self):
|
||||||
|
""" Return number of event listeners. """
|
||||||
|
return sum(hass.bus.listeners.values())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user