diff --git a/homeassistant/components/ozw/__init__.py b/homeassistant/components/ozw/__init__.py index ae79850d96f..f57d737bf36 100644 --- a/homeassistant/components/ozw/__init__.py +++ b/homeassistant/components/ozw/__init__.py @@ -26,7 +26,14 @@ from homeassistant.helpers.device_registry import async_get_registry as get_dev_ from homeassistant.helpers.dispatcher import async_dispatcher_send from . import const -from .const import DATA_UNSUBSCRIBE, DOMAIN, PLATFORMS, TOPIC_OPENZWAVE +from .const import ( + DATA_UNSUBSCRIBE, + DOMAIN, + MANAGER, + OPTIONS, + PLATFORMS, + TOPIC_OPENZWAVE, +) from .discovery import DISCOVERY_SCHEMAS, check_node_schema, check_value_schema from .entity import ( ZWaveDeviceEntityValues, @@ -35,7 +42,7 @@ from .entity import ( create_value_id, ) from .services import ZWaveServices -from .websocket_api import ZWaveWebsocketApi +from .websocket_api import async_register_api _LOGGER = logging.getLogger(__name__) @@ -68,6 +75,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): options = OZWOptions(send_message=send_message, topic_prefix=f"{TOPIC_OPENZWAVE}/") manager = OZWManager(options) + hass.data[DOMAIN][MANAGER] = manager + hass.data[DOMAIN][OPTIONS] = options + @callback def async_node_added(node): # Caution: This is also called on (re)start. @@ -209,8 +219,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): services.async_register() # Register WebSocket API - ws_api = ZWaveWebsocketApi(hass, manager) - ws_api.async_register_api() + async_register_api(hass) @callback def async_receive_message(msg): diff --git a/homeassistant/components/ozw/const.py b/homeassistant/components/ozw/const.py index 93aa8da4b79..91809298382 100644 --- a/homeassistant/components/ozw/const.py +++ b/homeassistant/components/ozw/const.py @@ -20,6 +20,8 @@ PLATFORMS = [ SENSOR_DOMAIN, SWITCH_DOMAIN, ] +MANAGER = "manager" +OPTIONS = "options" # MQTT Topics TOPIC_OPENZWAVE = "OpenZWave" diff --git a/homeassistant/components/ozw/manifest.json b/homeassistant/components/ozw/manifest.json index d2cf4772bb1..191411c36ee 100644 --- a/homeassistant/components/ozw/manifest.json +++ b/homeassistant/components/ozw/manifest.json @@ -14,4 +14,4 @@ "@marcelveldt", "@MartinHjelmare" ] -} +} \ No newline at end of file diff --git a/homeassistant/components/ozw/websocket_api.py b/homeassistant/components/ozw/websocket_api.py index e7c8b047f84..1b62c892f93 100644 --- a/homeassistant/components/ozw/websocket_api.py +++ b/homeassistant/components/ozw/websocket_api.py @@ -2,11 +2,14 @@ import logging +from openzwavemqtt.const import EVENT_NODE_ADDED, EVENT_NODE_CHANGED import voluptuous as vol from homeassistant.components import websocket_api from homeassistant.core import callback +from .const import DOMAIN, MANAGER, OPTIONS + _LOGGER = logging.getLogger(__name__) TYPE = "type" @@ -15,101 +18,159 @@ OZW_INSTANCE = "ozw_instance" NODE_ID = "node_id" -class ZWaveWebsocketApi: - """Class that holds our websocket api commands.""" +@callback +def async_register_api(hass): + """Register all of our api endpoints.""" + websocket_api.async_register_command(hass, websocket_network_status) + websocket_api.async_register_command(hass, websocket_node_metadata) + websocket_api.async_register_command(hass, websocket_node_status) + websocket_api.async_register_command(hass, websocket_node_statistics) + websocket_api.async_register_command(hass, websocket_refresh_node_info) - def __init__(self, hass, manager): - """Initialize with both hass and ozwmanager objects.""" - self._hass = hass - self._manager = manager + +@websocket_api.websocket_command( + { + vol.Required(TYPE): "ozw/network_status", + vol.Optional(OZW_INSTANCE, default=1): vol.Coerce(int), + } +) +def websocket_network_status(hass, connection, msg): + """Get Z-Wave network status.""" + + manager = hass.data[DOMAIN][MANAGER] + connection.send_result( + msg[ID], + { + "state": manager.get_instance(msg[OZW_INSTANCE]).get_status().status, + OZW_INSTANCE: msg[OZW_INSTANCE], + }, + ) + + +@websocket_api.websocket_command( + { + vol.Required(TYPE): "ozw/node_status", + vol.Required(NODE_ID): vol.Coerce(int), + vol.Optional(OZW_INSTANCE, default=1): vol.Coerce(int), + } +) +def websocket_node_status(hass, connection, msg): + """Get the status for a Z-Wave node.""" + manager = hass.data[DOMAIN][MANAGER] + node = manager.get_instance(msg[OZW_INSTANCE]).get_node(msg[NODE_ID]) + connection.send_result( + msg[ID], + { + "node_query_stage": node.node_query_stage, + "node_id": node.node_id, + "is_zwave_plus": node.is_zwave_plus, + "is_awake": node.is_awake, + "is_failed": node.is_failed, + "node_baud_rate": node.node_baud_rate, + "is_beaming": node.is_beaming, + "is_flirs": node.is_flirs, + "is_routing": node.is_routing, + "is_securityv1": node.is_securityv1, + "node_basic_string": node.node_basic_string, + "node_generic_string": node.node_generic_string, + "node_specific_string": node.node_specific_string, + "neighbors": node.neighbors, + OZW_INSTANCE: msg[OZW_INSTANCE], + }, + ) + + +@websocket_api.websocket_command( + { + vol.Required(TYPE): "ozw/node_metadata", + vol.Required(NODE_ID): vol.Coerce(int), + vol.Optional(OZW_INSTANCE, default=1): vol.Coerce(int), + } +) +def websocket_node_metadata(hass, connection, msg): + """Get the metadata for a Z-Wave node.""" + manager = hass.data[DOMAIN][MANAGER] + node = manager.get_instance(msg[OZW_INSTANCE]).get_node(msg[NODE_ID]) + connection.send_result( + msg[ID], + { + "metadata": node.meta_data, + NODE_ID: node.node_id, + OZW_INSTANCE: msg[OZW_INSTANCE], + }, + ) + + +@websocket_api.websocket_command( + { + vol.Required(TYPE): "ozw/node_statistics", + vol.Required(NODE_ID): vol.Coerce(int), + vol.Optional(OZW_INSTANCE, default=1): vol.Coerce(int), + } +) +def websocket_node_statistics(hass, connection, msg): + """Get the statistics for a Z-Wave node.""" + manager = hass.data[DOMAIN][MANAGER] + stats = ( + manager.get_instance(msg[OZW_INSTANCE]).get_node(msg[NODE_ID]).get_statistics() + ) + connection.send_result( + msg[ID], + { + "node_id": msg[NODE_ID], + "send_count": stats.send_count, + "sent_failed": stats.sent_failed, + "retries": stats.retries, + "last_request_rtt": stats.last_request_rtt, + "last_response_rtt": stats.last_response_rtt, + "average_request_rtt": stats.average_request_rtt, + "average_response_rtt": stats.average_response_rtt, + "received_packets": stats.received_packets, + "received_dup_packets": stats.received_dup_packets, + "received_unsolicited": stats.received_unsolicited, + OZW_INSTANCE: msg[OZW_INSTANCE], + }, + ) + + +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required(TYPE): "ozw/refresh_node_info", + vol.Optional(OZW_INSTANCE, default=1): vol.Coerce(int), + vol.Required(NODE_ID): vol.Coerce(int), + } +) +def websocket_refresh_node_info(hass, connection, msg): + """Tell OpenZWave to re-interview a node.""" + + manager = hass.data[DOMAIN][MANAGER] + options = hass.data[DOMAIN][OPTIONS] @callback - def async_register_api(self): - """Register all of our api endpoints.""" - websocket_api.async_register_command(self._hass, self.websocket_network_status) - websocket_api.async_register_command(self._hass, self.websocket_node_status) - websocket_api.async_register_command(self._hass, self.websocket_node_statistics) + def forward_node(node): + """Forward node events to websocket.""" + if node.node_id != msg[NODE_ID]: + return - @websocket_api.websocket_command( - { - vol.Required(TYPE): "ozw/network_status", - vol.Optional(OZW_INSTANCE, default=1): vol.Coerce(int), + forward_data = { + "type": "node_updated", + "node_query_stage": node.node_query_stage, } - ) - def websocket_network_status(self, hass, connection, msg): - """Get Z-Wave network status.""" + connection.send_message(websocket_api.event_message(msg["id"], forward_data)) - connection.send_result( - msg[ID], - { - "state": self._manager.get_instance(msg[OZW_INSTANCE]) - .get_status() - .status, - OZW_INSTANCE: msg[OZW_INSTANCE], - }, - ) + @callback + def async_cleanup() -> None: + """Remove signal listeners.""" + for unsub in unsubs: + unsub() - @websocket_api.websocket_command( - { - vol.Required(TYPE): "ozw/node_status", - vol.Required(NODE_ID): vol.Coerce(int), - vol.Optional(OZW_INSTANCE, default=1): vol.Coerce(int), - } - ) - def websocket_node_status(self, hass, connection, msg): - """Get the status for a Z-Wave node.""" + connection.subscriptions[msg["id"]] = async_cleanup + unsubs = [ + options.listen(EVENT_NODE_CHANGED, forward_node), + options.listen(EVENT_NODE_ADDED, forward_node), + ] - node = self._manager.get_instance(msg[OZW_INSTANCE]).get_node(msg[NODE_ID]) - connection.send_result( - msg[ID], - { - "node_query_stage": node.node_query_stage, - "node_id": node.node_id, - "is_zwave_plus": node.is_zwave_plus, - "is_awake": node.is_awake, - "is_failed": node.is_failed, - "node_baud_rate": node.node_baud_rate, - "is_beaming": node.is_beaming, - "is_flirs": node.is_flirs, - "is_routing": node.is_routing, - "is_securityv1": node.is_securityv1, - "node_basic_string": node.node_basic_string, - "node_generic_string": node.node_generic_string, - "node_specific_string": node.node_specific_string, - "neighbors": node.neighbors, - OZW_INSTANCE: msg[OZW_INSTANCE], - }, - ) - - @websocket_api.websocket_command( - { - vol.Required(TYPE): "ozw/node_statistics", - vol.Required(NODE_ID): vol.Coerce(int), - vol.Optional(OZW_INSTANCE, default=1): vol.Coerce(int), - } - ) - def websocket_node_statistics(self, hass, connection, msg): - """Get the statistics for a Z-Wave node.""" - - stats = ( - self._manager.get_instance(msg[OZW_INSTANCE]) - .get_node(msg[NODE_ID]) - .get_statistics() - ) - connection.send_result( - msg[ID], - { - "node_id": msg[NODE_ID], - "send_count": stats.send_count, - "sent_failed": stats.sent_failed, - "retries": stats.retries, - "last_request_rtt": stats.last_request_rtt, - "last_response_rtt": stats.last_response_rtt, - "average_request_rtt": stats.average_request_rtt, - "average_response_rtt": stats.average_response_rtt, - "received_packets": stats.received_packets, - "received_dup_packets": stats.received_dup_packets, - "received_unsolicited": stats.received_unsolicited, - OZW_INSTANCE: msg[OZW_INSTANCE], - }, - ) + instance = manager.get_instance(msg[OZW_INSTANCE]) + instance.refresh_node(msg[NODE_ID]) + connection.send_result(msg["id"]) diff --git a/tests/components/ozw/test_websocket_api.py b/tests/components/ozw/test_websocket_api.py index 13ba6f2152c..bee3a828c5a 100644 --- a/tests/components/ozw/test_websocket_api.py +++ b/tests/components/ozw/test_websocket_api.py @@ -2,7 +2,9 @@ from homeassistant.components.ozw.websocket_api import ID, NODE_ID, OZW_INSTANCE, TYPE -from .common import setup_ozw +from .common import MQTTMessage, setup_ozw + +from tests.async_mock import patch async def test_websocket_api(hass, generic_data, hass_ws_client): @@ -56,3 +58,75 @@ async def test_websocket_api(hass, generic_data, hass_ws_client): assert result["received_packets"] == 3594 assert result["received_dup_packets"] == 12 assert result["received_unsolicited"] == 3546 + + # Test node metadata + await client.send_json({ID: 8, TYPE: "ozw/node_metadata", NODE_ID: 39}) + msg = await client.receive_json() + result = msg["result"] + assert result["metadata"]["ProductPic"] == "images/aeotec/zwa002.png" + + +async def test_refresh_node(hass, generic_data, sent_messages, hass_ws_client): + """Test the ozw refresh node api.""" + receive_message = await setup_ozw(hass, fixture=generic_data) + client = await hass_ws_client(hass) + + # Send the refresh_node_info command + await client.send_json({ID: 9, TYPE: "ozw/refresh_node_info", NODE_ID: 39}) + msg = await client.receive_json() + + assert len(sent_messages) == 1 + assert msg["success"] + + # Receive a mock status update from OZW + message = MQTTMessage( + topic="OpenZWave/1/node/39/", + payload={"NodeID": 39, "NodeQueryStage": "initializing"}, + ) + message.encode() + receive_message(message) + + # Verify we got expected data on the websocket + msg = await client.receive_json() + result = msg["event"] + assert result["type"] == "node_updated" + assert result["node_query_stage"] == "initializing" + + # Send another mock status update from OZW + message = MQTTMessage( + topic="OpenZWave/1/node/39/", + payload={"NodeID": 39, "NodeQueryStage": "versions"}, + ) + message.encode() + receive_message(message) + + # Send a mock status update for a different node + message = MQTTMessage( + topic="OpenZWave/1/node/35/", + payload={"NodeID": 35, "NodeQueryStage": "fake_shouldnt_be_received"}, + ) + message.encode() + receive_message(message) + + # Verify we received the message for node 39 but not for node 35 + msg = await client.receive_json() + result = msg["event"] + assert result["type"] == "node_updated" + assert result["node_query_stage"] == "versions" + + +async def test_refresh_node_unsubscribe(hass, generic_data, hass_ws_client): + """Test unsubscribing the ozw refresh node api.""" + await setup_ozw(hass, fixture=generic_data) + client = await hass_ws_client(hass) + + with patch("openzwavemqtt.OZWOptions.listen") as mock_listen: + # Send the refresh_node_info command + await client.send_json({ID: 9, TYPE: "ozw/refresh_node_info", NODE_ID: 39}) + await client.receive_json() + + # Send the unsubscribe command + await client.send_json({ID: 10, TYPE: "unsubscribe_events", "subscription": 9}) + await client.receive_json() + + assert mock_listen.return_value.called