diff --git a/homeassistant/components/zwave_js/api.py b/homeassistant/components/zwave_js/api.py index be4386e529e..3fd443e5643 100644 --- a/homeassistant/components/zwave_js/api.py +++ b/homeassistant/components/zwave_js/api.py @@ -10,10 +10,11 @@ from aiohttp import hdrs, web, web_exceptions import voluptuous as vol from zwave_js_server import dump from zwave_js_server.client import Client -from zwave_js_server.const import LogLevel +from zwave_js_server.const import CommandClass, LogLevel from zwave_js_server.exceptions import InvalidNewValue, NotFoundError, SetValueFailed from zwave_js_server.model.log_config import LogConfig from zwave_js_server.model.log_message import LogMessage +from zwave_js_server.model.node import Node from zwave_js_server.util.node import async_set_config_parameter from homeassistant.components import websocket_api @@ -44,6 +45,7 @@ from .helpers import async_enable_statistics, update_data_collection_preference ID = "id" ENTRY_ID = "entry_id" NODE_ID = "node_id" +COMMAND_CLASS_ID = "command_class_id" TYPE = "type" PROPERTY = "property" PROPERTY_KEY = "property_key" @@ -76,12 +78,41 @@ def async_get_entry(orig_func: Callable) -> Callable: """Provide user specific data and store to function.""" entry_id = msg[ENTRY_ID] entry = hass.config_entries.async_get_entry(entry_id) + if entry is None: + connection.send_error( + msg[ID], ERR_NOT_FOUND, f"Config entry {entry_id} not found" + ) + return client = hass.data[DOMAIN][entry_id][DATA_CLIENT] await orig_func(hass, connection, msg, entry, client) return async_get_entry_func +def async_get_node(orig_func: Callable) -> Callable: + """Decorate async function to get node.""" + + @async_get_entry + @wraps(orig_func) + async def async_get_node_func( + hass: HomeAssistant, + connection: ActiveConnection, + msg: dict, + entry: ConfigEntry, + client: Client, + ) -> None: + """Provide user specific data and store to function.""" + node_id = msg[NODE_ID] + node = client.driver.controller.nodes.get(node_id) + + if node is None: + connection.send_error(msg[ID], ERR_NOT_FOUND, f"Node {node_id} not found") + return + await orig_func(hass, connection, msg, node) + + return async_get_node_func + + @callback def async_register_api(hass: HomeAssistant) -> None: """Register all of our api endpoints.""" @@ -92,6 +123,8 @@ def async_register_api(hass: HomeAssistant) -> None: websocket_api.async_register_command(hass, websocket_remove_node) websocket_api.async_register_command(hass, websocket_stop_exclusion) websocket_api.async_register_command(hass, websocket_refresh_node_info) + websocket_api.async_register_command(hass, websocket_refresh_node_values) + websocket_api.async_register_command(hass, websocket_refresh_node_cc_values) websocket_api.async_register_command(hass, websocket_subscribe_logs) websocket_api.async_register_command(hass, websocket_update_log_config) websocket_api.async_register_command(hass, websocket_get_log_config) @@ -359,22 +392,14 @@ async def websocket_remove_node( vol.Required(NODE_ID): int, }, ) -@async_get_entry +@async_get_node async def websocket_refresh_node_info( hass: HomeAssistant, connection: ActiveConnection, msg: dict, - entry: ConfigEntry, - client: Client, + node: Node, ) -> None: """Re-interview a node.""" - node_id = msg[NODE_ID] - controller = client.driver.controller - node = controller.nodes.get(node_id) - - if node is None: - connection.send_error(msg[ID], ERR_NOT_FOUND, f"Node {node_id} not found") - return @callback def async_cleanup() -> None: @@ -408,6 +433,59 @@ async def websocket_refresh_node_info( connection.send_result(msg[ID], result) +@websocket_api.require_admin # type: ignore +@websocket_api.async_response +@websocket_api.websocket_command( + { + vol.Required(TYPE): "zwave_js/refresh_node_values", + vol.Required(ENTRY_ID): str, + vol.Required(NODE_ID): int, + }, +) +@async_get_node +async def websocket_refresh_node_values( + hass: HomeAssistant, + connection: ActiveConnection, + msg: dict, + node: Node, +) -> None: + """Refresh node values.""" + await node.async_refresh_values() + connection.send_result(msg[ID]) + + +@websocket_api.require_admin # type: ignore +@websocket_api.async_response +@websocket_api.websocket_command( + { + vol.Required(TYPE): "zwave_js/refresh_node_cc_values", + vol.Required(ENTRY_ID): str, + vol.Required(NODE_ID): int, + vol.Required(COMMAND_CLASS_ID): int, + }, +) +@async_get_node +async def websocket_refresh_node_cc_values( + hass: HomeAssistant, + connection: ActiveConnection, + msg: dict, + node: Node, +) -> None: + """Refresh node values for a particular CommandClass.""" + command_class_id = msg[COMMAND_CLASS_ID] + + try: + command_class = CommandClass(command_class_id) + except ValueError: + connection.send_error( + msg[ID], ERR_NOT_FOUND, f"Command class {command_class_id} not found" + ) + return + + await node.async_refresh_cc_values(command_class) + connection.send_result(msg[ID]) + + @websocket_api.require_admin # type:ignore @websocket_api.async_response @websocket_api.websocket_command( @@ -420,20 +498,18 @@ async def websocket_refresh_node_info( vol.Required(VALUE): int, } ) -@async_get_entry +@async_get_node async def websocket_set_config_parameter( hass: HomeAssistant, connection: ActiveConnection, msg: dict, - entry: ConfigEntry, - client: Client, + node: Node, ) -> None: """Set a config parameter value for a Z-Wave node.""" - node_id = msg[NODE_ID] property_ = msg[PROPERTY] property_key = msg.get(PROPERTY_KEY) value = msg[VALUE] - node = client.driver.controller.nodes[node_id] + try: zwave_value, cmd_status = await async_set_config_parameter( node, value, property_, property_key=property_key diff --git a/tests/components/zwave_js/test_api.py b/tests/components/zwave_js/test_api.py index 3fb57a366aa..0a88a8e02ff 100644 --- a/tests/components/zwave_js/test_api.py +++ b/tests/components/zwave_js/test_api.py @@ -8,6 +8,7 @@ from zwave_js_server.exceptions import InvalidNewValue, NotFoundError, SetValueF from homeassistant.components.websocket_api.const import ERR_NOT_FOUND from homeassistant.components.zwave_js.api import ( + COMMAND_CLASS_ID, CONFIG, ENABLED, ENTRY_ID, @@ -294,17 +295,101 @@ async def test_refresh_node_info( client.async_send_command_no_wait.reset_mock() + +async def test_refresh_node_values( + hass, client, integration, hass_ws_client, multisensor_6 +): + """Test that the refresh_node_values WS API call works.""" + entry = integration + ws_client = await hass_ws_client(hass) + + client.async_send_command_no_wait.return_value = None + await ws_client.send_json( + { + ID: 1, + TYPE: "zwave_js/refresh_node_values", + ENTRY_ID: entry.entry_id, + NODE_ID: 52, + } + ) + msg = await ws_client.receive_json() + assert msg["success"] + + assert len(client.async_send_command_no_wait.call_args_list) == 1 + args = client.async_send_command_no_wait.call_args[0][0] + assert args["command"] == "node.refresh_values" + assert args["nodeId"] == 52 + + client.async_send_command_no_wait.reset_mock() + + # Test getting non-existent node fails await ws_client.send_json( { ID: 2, - TYPE: "zwave_js/refresh_node_info", + TYPE: "zwave_js/refresh_node_values", ENTRY_ID: entry.entry_id, - NODE_ID: 999, + NODE_ID: 99999, } ) msg = await ws_client.receive_json() assert not msg["success"] - assert msg["error"]["code"] == "not_found" + assert msg["error"]["code"] == ERR_NOT_FOUND + + # Test getting non-existent entry fails + await ws_client.send_json( + { + ID: 3, + TYPE: "zwave_js/refresh_node_values", + ENTRY_ID: "fake_entry_id", + NODE_ID: 52, + } + ) + msg = await ws_client.receive_json() + assert not msg["success"] + assert msg["error"]["code"] == ERR_NOT_FOUND + + +async def test_refresh_node_cc_values( + hass, client, integration, hass_ws_client, multisensor_6 +): + """Test that the refresh_node_cc_values WS API call works.""" + entry = integration + ws_client = await hass_ws_client(hass) + + client.async_send_command_no_wait.return_value = None + await ws_client.send_json( + { + ID: 1, + TYPE: "zwave_js/refresh_node_cc_values", + ENTRY_ID: entry.entry_id, + NODE_ID: 52, + COMMAND_CLASS_ID: 112, + } + ) + msg = await ws_client.receive_json() + assert msg["success"] + + assert len(client.async_send_command_no_wait.call_args_list) == 1 + args = client.async_send_command_no_wait.call_args[0][0] + assert args["command"] == "node.refresh_cc_values" + assert args["nodeId"] == 52 + assert args["commandClass"] == 112 + + client.async_send_command_no_wait.reset_mock() + + # Test using invalid CC ID fails + await ws_client.send_json( + { + ID: 2, + TYPE: "zwave_js/refresh_node_cc_values", + ENTRY_ID: entry.entry_id, + NODE_ID: 52, + COMMAND_CLASS_ID: 9999, + } + ) + msg = await ws_client.receive_json() + assert not msg["success"] + assert msg["error"]["code"] == ERR_NOT_FOUND async def test_set_config_parameter(