mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 20:27:08 +00:00
Add WS subscription command for MQTT (#21696)
* Add WS subscription command for MQTT * Add test * Add check for connected * Rename event_listeners to subscriptions
This commit is contained in:
parent
fc85b3fc5f
commit
429bbc05dc
@ -25,7 +25,7 @@ from homeassistant.const import (
|
|||||||
CONF_PROTOCOL, CONF_USERNAME, CONF_VALUE_TEMPLATE,
|
CONF_PROTOCOL, CONF_USERNAME, CONF_VALUE_TEMPLATE,
|
||||||
EVENT_HOMEASSISTANT_STOP)
|
EVENT_HOMEASSISTANT_STOP)
|
||||||
from homeassistant.core import Event, ServiceCall, callback
|
from homeassistant.core import Event, ServiceCall, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
||||||
from homeassistant.helpers import config_validation as cv, template
|
from homeassistant.helpers import config_validation as cv, template
|
||||||
from homeassistant.helpers.entity import Entity
|
from homeassistant.helpers.entity import Entity
|
||||||
from homeassistant.helpers.typing import (
|
from homeassistant.helpers.typing import (
|
||||||
@ -35,6 +35,7 @@ from homeassistant.setup import async_prepare_setup_platform
|
|||||||
from homeassistant.util.async_ import (
|
from homeassistant.util.async_ import (
|
||||||
run_callback_threadsafe, run_coroutine_threadsafe)
|
run_callback_threadsafe, run_coroutine_threadsafe)
|
||||||
from homeassistant.util.logging import catch_log_exception
|
from homeassistant.util.logging import catch_log_exception
|
||||||
|
from homeassistant.components import websocket_api
|
||||||
|
|
||||||
# Loading the config flow file will register the flow
|
# Loading the config flow file will register the flow
|
||||||
from . import config_flow # noqa pylint: disable=unused-import
|
from . import config_flow # noqa pylint: disable=unused-import
|
||||||
@ -391,6 +392,8 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
|
|||||||
# This needs a better solution.
|
# This needs a better solution.
|
||||||
hass.data[DATA_MQTT_HASS_CONFIG] = config
|
hass.data[DATA_MQTT_HASS_CONFIG] = config
|
||||||
|
|
||||||
|
websocket_api.async_register_command(hass, websocket_subscribe)
|
||||||
|
|
||||||
if conf is None:
|
if conf is None:
|
||||||
# If we have a config entry, setup is done by that config entry.
|
# If we have a config entry, setup is done by that config entry.
|
||||||
# If there is no config entry, this should fail.
|
# If there is no config entry, this should fail.
|
||||||
@ -602,6 +605,7 @@ class MQTT:
|
|||||||
self.keepalive = keepalive
|
self.keepalive = keepalive
|
||||||
self.subscriptions = [] # type: List[Subscription]
|
self.subscriptions = [] # type: List[Subscription]
|
||||||
self.birth_message = birth_message
|
self.birth_message = birth_message
|
||||||
|
self.connected = False
|
||||||
self._mqttc = None # type: mqtt.Client
|
self._mqttc = None # type: mqtt.Client
|
||||||
self._paho_lock = asyncio.Lock(loop=hass.loop)
|
self._paho_lock = asyncio.Lock(loop=hass.loop)
|
||||||
|
|
||||||
@ -703,7 +707,10 @@ class MQTT:
|
|||||||
if any(other.topic == topic for other in self.subscriptions):
|
if any(other.topic == topic for other in self.subscriptions):
|
||||||
# Other subscriptions on topic remaining - don't unsubscribe.
|
# Other subscriptions on topic remaining - don't unsubscribe.
|
||||||
return
|
return
|
||||||
self.hass.async_create_task(self._async_unsubscribe(topic))
|
|
||||||
|
# Only unsubscribe if currently connected.
|
||||||
|
if self.connected:
|
||||||
|
self.hass.async_create_task(self._async_unsubscribe(topic))
|
||||||
|
|
||||||
return async_remove
|
return async_remove
|
||||||
|
|
||||||
@ -743,6 +750,8 @@ class MQTT:
|
|||||||
self._mqttc.disconnect()
|
self._mqttc.disconnect()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self.connected = True
|
||||||
|
|
||||||
# Group subscriptions to only re-subscribe once for each topic.
|
# Group subscriptions to only re-subscribe once for each topic.
|
||||||
keyfunc = attrgetter('topic')
|
keyfunc = attrgetter('topic')
|
||||||
for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc),
|
for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc),
|
||||||
@ -782,6 +791,8 @@ class MQTT:
|
|||||||
|
|
||||||
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
|
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
|
||||||
"""Disconnected callback."""
|
"""Disconnected callback."""
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
# When disconnected because of calling disconnect()
|
# When disconnected because of calling disconnect()
|
||||||
if result_code == 0:
|
if result_code == 0:
|
||||||
return
|
return
|
||||||
@ -791,6 +802,7 @@ class MQTT:
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if self._mqttc.reconnect() == 0:
|
if self._mqttc.reconnect() == 0:
|
||||||
|
self.connected = True
|
||||||
_LOGGER.info("Successfully reconnected to the MQTT server")
|
_LOGGER.info("Successfully reconnected to the MQTT server")
|
||||||
break
|
break
|
||||||
except socket.error:
|
except socket.error:
|
||||||
@ -1040,3 +1052,27 @@ class MqttEntityDeviceInfo(Entity):
|
|||||||
info['via_hub'] = (DOMAIN, self._device_config[CONF_VIA_HUB])
|
info['via_hub'] = (DOMAIN, self._device_config[CONF_VIA_HUB])
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.async_response
|
||||||
|
@websocket_api.websocket_command({
|
||||||
|
vol.Required('type'): 'mqtt/subscribe',
|
||||||
|
vol.Required('topic'): valid_subscribe_topic,
|
||||||
|
})
|
||||||
|
async def websocket_subscribe(hass, connection, msg):
|
||||||
|
"""Subscribe to a MQTT topic."""
|
||||||
|
if not connection.user.is_admin:
|
||||||
|
raise Unauthorized
|
||||||
|
|
||||||
|
async def forward_messages(topic: str, payload: str, qos: int):
|
||||||
|
"""Forward events to websocket."""
|
||||||
|
connection.send_message(websocket_api.event_message(msg['id'], {
|
||||||
|
'topic': topic,
|
||||||
|
'payload': payload,
|
||||||
|
'qos': qos,
|
||||||
|
}))
|
||||||
|
|
||||||
|
connection.subscriptions[msg['id']] = await async_subscribe(
|
||||||
|
hass, msg['topic'], forward_messages)
|
||||||
|
|
||||||
|
connection.send_message(websocket_api.result_message(msg['id']))
|
||||||
|
@ -14,6 +14,7 @@ ActiveConnection = connection.ActiveConnection
|
|||||||
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
|
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
|
||||||
error_message = messages.error_message
|
error_message = messages.error_message
|
||||||
result_message = messages.result_message
|
result_message = messages.result_message
|
||||||
|
event_message = messages.event_message
|
||||||
async_response = decorators.async_response
|
async_response = decorators.async_response
|
||||||
require_admin = decorators.require_admin
|
require_admin = decorators.require_admin
|
||||||
ws_require_user = decorators.ws_require_user
|
ws_require_user = decorators.ws_require_user
|
||||||
|
@ -24,15 +24,6 @@ def async_register_commands(hass):
|
|||||||
async_reg(handle_ping)
|
async_reg(handle_ping)
|
||||||
|
|
||||||
|
|
||||||
def event_message(iden, event):
|
|
||||||
"""Return an event message."""
|
|
||||||
return {
|
|
||||||
'id': iden,
|
|
||||||
'type': 'event',
|
|
||||||
'event': event.as_dict(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def pong_message(iden):
|
def pong_message(iden):
|
||||||
"""Return a pong message."""
|
"""Return a pong message."""
|
||||||
return {
|
return {
|
||||||
@ -59,9 +50,11 @@ def handle_subscribe_events(hass, connection, msg):
|
|||||||
if event.event_type == EVENT_TIME_CHANGED:
|
if event.event_type == EVENT_TIME_CHANGED:
|
||||||
return
|
return
|
||||||
|
|
||||||
connection.send_message(event_message(msg['id'], event))
|
connection.send_message(messages.event_message(
|
||||||
|
msg['id'], event.as_dict()
|
||||||
|
))
|
||||||
|
|
||||||
connection.event_listeners[msg['id']] = hass.bus.async_listen(
|
connection.subscriptions[msg['id']] = hass.bus.async_listen(
|
||||||
msg['event_type'], forward_events)
|
msg['event_type'], forward_events)
|
||||||
|
|
||||||
connection.send_message(messages.result_message(msg['id']))
|
connection.send_message(messages.result_message(msg['id']))
|
||||||
@ -79,8 +72,8 @@ def handle_unsubscribe_events(hass, connection, msg):
|
|||||||
"""
|
"""
|
||||||
subscription = msg['subscription']
|
subscription = msg['subscription']
|
||||||
|
|
||||||
if subscription in connection.event_listeners:
|
if subscription in connection.subscriptions:
|
||||||
connection.event_listeners.pop(subscription)()
|
connection.subscriptions.pop(subscription)()
|
||||||
connection.send_message(messages.result_message(msg['id']))
|
connection.send_message(messages.result_message(msg['id']))
|
||||||
else:
|
else:
|
||||||
connection.send_message(messages.error_message(
|
connection.send_message(messages.error_message(
|
||||||
|
@ -21,7 +21,7 @@ class ActiveConnection:
|
|||||||
else:
|
else:
|
||||||
self.refresh_token_id = None
|
self.refresh_token_id = None
|
||||||
|
|
||||||
self.event_listeners = {}
|
self.subscriptions = {}
|
||||||
self.last_id = 0
|
self.last_id = 0
|
||||||
|
|
||||||
def context(self, msg):
|
def context(self, msg):
|
||||||
@ -82,7 +82,7 @@ class ActiveConnection:
|
|||||||
@callback
|
@callback
|
||||||
def async_close(self):
|
def async_close(self):
|
||||||
"""Close down connection."""
|
"""Close down connection."""
|
||||||
for unsub in self.event_listeners.values():
|
for unsub in self.subscriptions.values():
|
||||||
unsub()
|
unsub()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -40,3 +40,12 @@ def error_message(iden, code, message):
|
|||||||
'message': message,
|
'message': message,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def event_message(iden, event):
|
||||||
|
"""Return an event message."""
|
||||||
|
return {
|
||||||
|
'id': iden,
|
||||||
|
'type': 'event',
|
||||||
|
'event': event,
|
||||||
|
}
|
||||||
|
@ -767,3 +767,37 @@ async def test_message_callback_exception_gets_logged(hass, caplog):
|
|||||||
assert \
|
assert \
|
||||||
"Exception in bad_handler when handling msg on 'test-topic':" \
|
"Exception in bad_handler when handling msg on 'test-topic':" \
|
||||||
" 'test'" in caplog.text
|
" 'test'" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_mqtt_ws_subscription(hass, hass_ws_client):
|
||||||
|
"""Test MQTT websocket subscription."""
|
||||||
|
await async_mock_mqtt_component(hass)
|
||||||
|
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
await client.send_json({
|
||||||
|
'id': 5,
|
||||||
|
'type': 'mqtt/subscribe',
|
||||||
|
'topic': 'test-topic',
|
||||||
|
})
|
||||||
|
response = await client.receive_json()
|
||||||
|
assert response['success']
|
||||||
|
|
||||||
|
async_fire_mqtt_message(hass, 'test-topic', 'test1')
|
||||||
|
async_fire_mqtt_message(hass, 'test-topic', 'test2')
|
||||||
|
|
||||||
|
response = await client.receive_json()
|
||||||
|
assert response['event']['topic'] == 'test-topic'
|
||||||
|
assert response['event']['payload'] == 'test1'
|
||||||
|
|
||||||
|
response = await client.receive_json()
|
||||||
|
assert response['event']['topic'] == 'test-topic'
|
||||||
|
assert response['event']['payload'] == 'test2'
|
||||||
|
|
||||||
|
# Unsubscribe
|
||||||
|
await client.send_json({
|
||||||
|
'id': 8,
|
||||||
|
'type': 'unsubscribe_events',
|
||||||
|
'subscription': 5,
|
||||||
|
})
|
||||||
|
response = await client.receive_json()
|
||||||
|
assert response['success']
|
||||||
|
Loading…
x
Reference in New Issue
Block a user