Add WS subscription command for MQTT (#21696)

* Add WS subscription command for MQTT

* Add test

* Add check for connected

* Rename event_listeners to subscriptions
This commit is contained in:
Paulus Schoutsen 2019-03-10 20:07:09 -07:00 committed by GitHub
parent fc85b3fc5f
commit 429bbc05dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 90 additions and 17 deletions

View File

@ -25,7 +25,7 @@ from homeassistant.const import (
CONF_PROTOCOL, CONF_USERNAME, CONF_VALUE_TEMPLATE, CONF_PROTOCOL, CONF_USERNAME, CONF_VALUE_TEMPLATE,
EVENT_HOMEASSISTANT_STOP) EVENT_HOMEASSISTANT_STOP)
from homeassistant.core import Event, ServiceCall, callback from homeassistant.core import Event, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers import config_validation as cv, template from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.typing import ( from homeassistant.helpers.typing import (
@ -35,6 +35,7 @@ from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util.async_ import ( from homeassistant.util.async_ import (
run_callback_threadsafe, run_coroutine_threadsafe) run_callback_threadsafe, run_coroutine_threadsafe)
from homeassistant.util.logging import catch_log_exception from homeassistant.util.logging import catch_log_exception
from homeassistant.components import websocket_api
# Loading the config flow file will register the flow # Loading the config flow file will register the flow
from . import config_flow # noqa pylint: disable=unused-import from . import config_flow # noqa pylint: disable=unused-import
@ -391,6 +392,8 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
# This needs a better solution. # This needs a better solution.
hass.data[DATA_MQTT_HASS_CONFIG] = config hass.data[DATA_MQTT_HASS_CONFIG] = config
websocket_api.async_register_command(hass, websocket_subscribe)
if conf is None: if conf is None:
# If we have a config entry, setup is done by that config entry. # If we have a config entry, setup is done by that config entry.
# If there is no config entry, this should fail. # If there is no config entry, this should fail.
@ -602,6 +605,7 @@ class MQTT:
self.keepalive = keepalive self.keepalive = keepalive
self.subscriptions = [] # type: List[Subscription] self.subscriptions = [] # type: List[Subscription]
self.birth_message = birth_message self.birth_message = birth_message
self.connected = False
self._mqttc = None # type: mqtt.Client self._mqttc = None # type: mqtt.Client
self._paho_lock = asyncio.Lock(loop=hass.loop) self._paho_lock = asyncio.Lock(loop=hass.loop)
@ -703,7 +707,10 @@ class MQTT:
if any(other.topic == topic for other in self.subscriptions): if any(other.topic == topic for other in self.subscriptions):
# Other subscriptions on topic remaining - don't unsubscribe. # Other subscriptions on topic remaining - don't unsubscribe.
return return
self.hass.async_create_task(self._async_unsubscribe(topic))
# Only unsubscribe if currently connected.
if self.connected:
self.hass.async_create_task(self._async_unsubscribe(topic))
return async_remove return async_remove
@ -743,6 +750,8 @@ class MQTT:
self._mqttc.disconnect() self._mqttc.disconnect()
return return
self.connected = True
# Group subscriptions to only re-subscribe once for each topic. # Group subscriptions to only re-subscribe once for each topic.
keyfunc = attrgetter('topic') keyfunc = attrgetter('topic')
for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc), for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc),
@ -782,6 +791,8 @@ class MQTT:
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None: def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
"""Disconnected callback.""" """Disconnected callback."""
self.connected = False
# When disconnected because of calling disconnect() # When disconnected because of calling disconnect()
if result_code == 0: if result_code == 0:
return return
@ -791,6 +802,7 @@ class MQTT:
while True: while True:
try: try:
if self._mqttc.reconnect() == 0: if self._mqttc.reconnect() == 0:
self.connected = True
_LOGGER.info("Successfully reconnected to the MQTT server") _LOGGER.info("Successfully reconnected to the MQTT server")
break break
except socket.error: except socket.error:
@ -1040,3 +1052,27 @@ class MqttEntityDeviceInfo(Entity):
info['via_hub'] = (DOMAIN, self._device_config[CONF_VIA_HUB]) info['via_hub'] = (DOMAIN, self._device_config[CONF_VIA_HUB])
return info return info
@websocket_api.async_response
@websocket_api.websocket_command({
vol.Required('type'): 'mqtt/subscribe',
vol.Required('topic'): valid_subscribe_topic,
})
async def websocket_subscribe(hass, connection, msg):
"""Subscribe to a MQTT topic."""
if not connection.user.is_admin:
raise Unauthorized
async def forward_messages(topic: str, payload: str, qos: int):
"""Forward events to websocket."""
connection.send_message(websocket_api.event_message(msg['id'], {
'topic': topic,
'payload': payload,
'qos': qos,
}))
connection.subscriptions[msg['id']] = await async_subscribe(
hass, msg['topic'], forward_messages)
connection.send_message(websocket_api.result_message(msg['id']))

View File

@ -14,6 +14,7 @@ ActiveConnection = connection.ActiveConnection
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
error_message = messages.error_message error_message = messages.error_message
result_message = messages.result_message result_message = messages.result_message
event_message = messages.event_message
async_response = decorators.async_response async_response = decorators.async_response
require_admin = decorators.require_admin require_admin = decorators.require_admin
ws_require_user = decorators.ws_require_user ws_require_user = decorators.ws_require_user

View File

@ -24,15 +24,6 @@ def async_register_commands(hass):
async_reg(handle_ping) async_reg(handle_ping)
def event_message(iden, event):
"""Return an event message."""
return {
'id': iden,
'type': 'event',
'event': event.as_dict(),
}
def pong_message(iden): def pong_message(iden):
"""Return a pong message.""" """Return a pong message."""
return { return {
@ -59,9 +50,11 @@ def handle_subscribe_events(hass, connection, msg):
if event.event_type == EVENT_TIME_CHANGED: if event.event_type == EVENT_TIME_CHANGED:
return return
connection.send_message(event_message(msg['id'], event)) connection.send_message(messages.event_message(
msg['id'], event.as_dict()
))
connection.event_listeners[msg['id']] = hass.bus.async_listen( connection.subscriptions[msg['id']] = hass.bus.async_listen(
msg['event_type'], forward_events) msg['event_type'], forward_events)
connection.send_message(messages.result_message(msg['id'])) connection.send_message(messages.result_message(msg['id']))
@ -79,8 +72,8 @@ def handle_unsubscribe_events(hass, connection, msg):
""" """
subscription = msg['subscription'] subscription = msg['subscription']
if subscription in connection.event_listeners: if subscription in connection.subscriptions:
connection.event_listeners.pop(subscription)() connection.subscriptions.pop(subscription)()
connection.send_message(messages.result_message(msg['id'])) connection.send_message(messages.result_message(msg['id']))
else: else:
connection.send_message(messages.error_message( connection.send_message(messages.error_message(

View File

@ -21,7 +21,7 @@ class ActiveConnection:
else: else:
self.refresh_token_id = None self.refresh_token_id = None
self.event_listeners = {} self.subscriptions = {}
self.last_id = 0 self.last_id = 0
def context(self, msg): def context(self, msg):
@ -82,7 +82,7 @@ class ActiveConnection:
@callback @callback
def async_close(self): def async_close(self):
"""Close down connection.""" """Close down connection."""
for unsub in self.event_listeners.values(): for unsub in self.subscriptions.values():
unsub() unsub()
@callback @callback

View File

@ -40,3 +40,12 @@ def error_message(iden, code, message):
'message': message, 'message': message,
}, },
} }
def event_message(iden, event):
"""Return an event message."""
return {
'id': iden,
'type': 'event',
'event': event,
}

View File

@ -767,3 +767,37 @@ async def test_message_callback_exception_gets_logged(hass, caplog):
assert \ assert \
"Exception in bad_handler when handling msg on 'test-topic':" \ "Exception in bad_handler when handling msg on 'test-topic':" \
" 'test'" in caplog.text " 'test'" in caplog.text
async def test_mqtt_ws_subscription(hass, hass_ws_client):
"""Test MQTT websocket subscription."""
await async_mock_mqtt_component(hass)
client = await hass_ws_client(hass)
await client.send_json({
'id': 5,
'type': 'mqtt/subscribe',
'topic': 'test-topic',
})
response = await client.receive_json()
assert response['success']
async_fire_mqtt_message(hass, 'test-topic', 'test1')
async_fire_mqtt_message(hass, 'test-topic', 'test2')
response = await client.receive_json()
assert response['event']['topic'] == 'test-topic'
assert response['event']['payload'] == 'test1'
response = await client.receive_json()
assert response['event']['topic'] == 'test-topic'
assert response['event']['payload'] == 'test2'
# Unsubscribe
await client.send_json({
'id': 8,
'type': 'unsubscribe_events',
'subscription': 5,
})
response = await client.receive_json()
assert response['success']