Get EventStream working

This commit is contained in:
Paulus Schoutsen 2016-05-15 23:54:14 -07:00
parent fe794d7fd8
commit 794ff20987
4 changed files with 130 additions and 126 deletions

View File

@ -73,102 +73,94 @@ class APIEventStream(HomeAssistantView):
def get(self, request): def get(self, request):
"""Provide a streaming interface for the event bus.""" """Provide a streaming interface for the event bus."""
from eventlet.queue import Empty
import eventlet import eventlet
from eventlet import queue as eventlet_queue import homeassistant.util.eventlet as eventlet_util
import queue as thread_queue
from threading import Event
from time import time
to_write = thread_queue.Queue() cur_hub = eventlet.hubs.get_hub()
# to_write = eventlet.Queue() request.environ['eventlet.minimum_write_chunk_size'] = 0
to_write = eventlet.Queue()
stop_obj = object() stop_obj = object()
hass = self.hass attached_ping = None
connection_closed = Event()
restrict = request.args.get('restrict') restrict = request.args.get('restrict')
if restrict: if restrict:
restrict = restrict.split(',') restrict = restrict.split(',')
restrict = False def thread_ping(now):
"""Called from time thread to add ping to queue."""
_LOGGER.debug('STREAM %s PING', id(stop_obj))
eventlet_util.spawn(cur_hub, to_write.put, STREAM_PING_PAYLOAD)
def ping(now): def thread_forward_events(event):
"""Add a ping message to queue."""
print(id(stop_obj), 'ping')
to_write.put(STREAM_PING_PAYLOAD)
def forward_events(event):
"""Forward events to the open request.""" """Forward events to the open request."""
print(id(stop_obj), 'forwarding', event)
if event.event_type == EVENT_TIME_CHANGED: if event.event_type == EVENT_TIME_CHANGED:
pass return
elif event.event_type == EVENT_HOMEASSISTANT_STOP:
to_write.put(stop_obj) _LOGGER.debug('STREAM %s FORWARDING %s', id(stop_obj), event)
if event.event_type == EVENT_HOMEASSISTANT_STOP:
data = stop_obj
else: else:
to_write.put(json.dumps(event, cls=rem.JSONEncoder)) data = json.dumps(event, cls=rem.JSONEncoder)
eventlet_util.spawn(cur_hub, to_write.put, data)
def cleanup():
"""Clean up HA listeners."""
_LOGGER.debug("STREAM %s CLEANING UP", id(stop_obj))
self.hass.bus.remove_listener(EVENT_TIME_CHANGED, attached_ping)
if restrict:
for event in restrict:
self.hass.bus.remove_listener(event, thread_forward_events)
else:
self.hass.bus.remove_listener(MATCH_ALL, thread_forward_events)
def stream(): def stream():
"""Stream events to response.""" """Stream events to response."""
nonlocal attached_ping
if restrict: if restrict:
for event_type in restrict: for event_type in restrict:
hass.bus.listen(event_type, forward_events) self.hass.bus.listen(event_type, thread_forward_events)
else: else:
hass.bus.listen(MATCH_ALL, forward_events) self.hass.bus.listen(MATCH_ALL, thread_forward_events)
attached_ping = track_utc_time_change( attached_ping = track_utc_time_change(
hass, ping, second=(0, 30)) self.hass, thread_ping, second=range(0, 60, 3)) #(0, 30))
print(id(stop_obj), 'attached goodness') _LOGGER.debug('STREAM %s ATTACHED', id(stop_obj))
while not connection_closed.is_set(): while True:
try: try:
print(id(stop_obj), "Try getting obj") # Somehow our queue.get takes too long to
payload = to_write.get(False) # be notified of arrival of object. Probably
# because of our spawning on hub in other thread
# hack. Because current goal is to get this out,
# We just timeout every second because it will
# return right away if qsize() > 0.
# So yes, we're basically polling :(
# socket.io anyone?
payload = to_write.get(timeout=1)
if payload is stop_obj: if payload is stop_obj:
break break
msg = "data: {}\n\n".format(payload) msg = "data: {}\n\n".format(payload)
print(id(stop_obj), msg) _LOGGER.debug('STREAM %s WRITING %s', id(stop_obj),
msg.strip())
yield msg.encode("UTF-8") yield msg.encode("UTF-8")
except eventlet_queue.Empty: except Empty:
print(id(stop_obj), "queue empty, sleep 0.5")
eventlet.sleep(.5)
except GeneratorExit:
pass pass
except GeneratorExit:
_LOGGER.debug('STREAM %s RESPONSE CLOSED', id(stop_obj))
break
print(id(stop_obj), "cleaning up") cleanup()
hass.bus.remove_listener(EVENT_TIME_CHANGED, attached_ping) return self.Response(stream(), mimetype='text/event-stream')
if restrict:
for event in restrict:
hass.bus.remove_listener(event, forward_events)
else:
hass.bus.remove_listener(MATCH_ALL, forward_events)
resp = self.Response(stream(), mimetype='text/event-stream')
def closing():
print()
print()
print()
print()
print()
print()
print()
print()
print(id(stop_obj), "CLOSING RESPONSE")
print()
print()
print()
print()
print()
print()
print()
connection_closed.set()
resp.call_on_close(closing)
return resp
class APIConfigView(HomeAssistantView): class APIConfigView(HomeAssistantView):

View File

@ -213,9 +213,6 @@ class HomeAssistantWSGI(object):
"""Register a folder to serve as a static path.""" """Register a folder to serve as a static path."""
from static import Cling from static import Cling
if url_root in self.extra_apps:
_LOGGER.warning("Static path '%s' is being overwritten", path)
headers = [] headers = []
if not self.development: if not self.development:
@ -228,7 +225,14 @@ class HomeAssistantWSGI(object):
"public, max-age={}".format(cache_time) "public, max-age={}".format(cache_time)
}) })
self.extra_apps[url_root] = Cling(path, headers=headers) self.register_wsgi_app(url_root, Cling(path, headers=headers))
def register_wsgi_app(self, url_root, app):
"""Register a path to serve a WSGI app."""
if url_root in self.extra_apps:
_LOGGER.warning("Url root '%s' is being overwritten", url_root)
self.extra_apps[url_root] = app
def start(self): def start(self):
"""Start the wsgi server.""" """Start the wsgi server."""
@ -248,20 +252,21 @@ class HomeAssistantWSGI(object):
) )
from werkzeug.routing import RequestRedirect from werkzeug.routing import RequestRedirect
adapter = self.url_map.bind_to_environ(request.environ) with request:
try: adapter = self.url_map.bind_to_environ(request.environ)
endpoint, values = adapter.match() try:
return self.views[endpoint].handle_request(request, **values) endpoint, values = adapter.match()
except RequestRedirect as ex: return self.views[endpoint].handle_request(request, **values)
return ex except RequestRedirect as ex:
except BadRequest as ex: return ex
return self._handle_error(request, str(ex), 400) except BadRequest as ex:
except NotFound as ex: return self._handle_error(request, str(ex), 400)
return self._handle_error(request, str(ex), 404) except NotFound as ex:
except MethodNotAllowed as ex: return self._handle_error(request, str(ex), 404)
return self._handle_error(request, str(ex), 405) except MethodNotAllowed as ex:
except Unauthorized as ex: return self._handle_error(request, str(ex), 405)
return self._handle_error(request, str(ex), 401) except Unauthorized as ex:
return self._handle_error(request, str(ex), 401)
# TODO This long chain of except blocks is silly. _handle_error should # TODO This long chain of except blocks is silly. _handle_error should
# just take the exception as an argument and parse the status code # just take the exception as an argument and parse the status code
# itself # itself

View File

@ -0,0 +1,9 @@
"""Eventlet util methods."""
def spawn(hub, func, *args, **kwargs):
"""Spawns a function on specified hub."""
import eventlet
g = eventlet.greenthread.GreenThread(hub.greenlet)
hub.schedule_call_global(0, g.switch, func, args, kwargs)
return g

View File

@ -435,63 +435,61 @@ class TestAPI(unittest.TestCase):
headers=HA_HEADERS) headers=HA_HEADERS)
self.assertEqual(200, req.status_code) self.assertEqual(200, req.status_code)
# def test_stream(self): def test_stream(self):
# """Test the stream.""" """Test the stream."""
# listen_count = self._listen_count() listen_count = self._listen_count()
# with closing(requests.get(_url(const.URL_API_STREAM), with closing(requests.get(_url(const.URL_API_STREAM),
# stream=True, headers=HA_HEADERS)) as req: stream=True, headers=HA_HEADERS)) as req:
# self.assertEqual(listen_count + 1, self._listen_count()) self.assertEqual(listen_count + 2, self._listen_count())
# # eventlet.sleep(1) hass.bus.fire('test_event')
# print('firing event') hass.pool.block_till_done()
# hass.bus.fire('test_event') data = self._stream_next_event(req)
# hass.pool.block_till_done()
# data = self._stream_next_event(req) self.assertEqual('test_event', data['event_type'])
# self.assertEqual('test_event', data['event_type']) def test_stream_with_restricted(self):
"""Test the stream with restrictions."""
listen_count = self._listen_count()
url = _url('{}?restrict=test_event1,test_event3'.format(
const.URL_API_STREAM))
# def test_stream_with_restricted(self): with closing(requests.get(url, stream=True,
# """Test the stream with restrictions.""" headers=HA_HEADERS)) as req:
# listen_count = self._listen_count()
# with closing(requests.get(_url(const.URL_API_STREAM),
# data=json.dumps({
# 'restrict':
# 'test_event1,test_event3'}),
# stream=True, headers=HA_HEADERS)) as req:
# data = self._stream_next_event(req) self.assertEqual(listen_count + 3, self._listen_count())
# self.assertEqual('ping', data)
# self.assertEqual(listen_count + 2, self._listen_count()) hass.bus.fire('test_event1')
hass.pool.block_till_done()
hass.bus.fire('test_event2')
hass.pool.block_till_done()
hass.bus.fire('test_event3')
hass.pool.block_till_done()
# hass.bus.fire('test_event1') data = self._stream_next_event(req)
# hass.pool.block_till_done() self.assertEqual('test_event1', data['event_type'])
# hass.bus.fire('test_event2') data = self._stream_next_event(req)
# hass.pool.block_till_done() self.assertEqual('test_event3', data['event_type'])
# hass.bus.fire('test_event3')
# hass.pool.block_till_done()
# data = self._stream_next_event(req)
# self.assertEqual('test_event1', data['event_type'])
# data = self._stream_next_event(req)
# self.assertEqual('test_event3', data['event_type'])
def _stream_next_event(self, stream): def _stream_next_event(self, stream):
"""Test the stream for next event.""" """Read the stream for next event while ignoring ping."""
data = b'' while True:
last_new_line = False data = b''
for dat in stream.iter_content(1): last_new_line = False
if dat == b'\n' and last_new_line: for dat in stream.iter_content(1):
if dat == b'\n' and last_new_line:
break
data += dat
last_new_line = dat == b'\n'
conv = data.decode('utf-8').strip()[6:]
if conv != 'ping':
break break
data += dat
last_new_line = dat == b'\n'
conv = data.decode('utf-8').strip()[6:] return json.loads(conv)
return conv if conv == 'ping' else json.loads(conv)
def _listen_count(self): def _listen_count(self):
"""Return number of event listeners.""" """Return number of event listeners."""