diff --git a/homeassistant/components/zwave_js/api.py b/homeassistant/components/zwave_js/api.py index 8a08da13bfc..2d0fee54a18 100644 --- a/homeassistant/components/zwave_js/api.py +++ b/homeassistant/components/zwave_js/api.py @@ -6,12 +6,22 @@ from functools import wraps import json from typing import Callable -from aiohttp import hdrs, web, web_exceptions +from aiohttp import hdrs, web, web_exceptions, web_request import voluptuous as vol from zwave_js_server import dump from zwave_js_server.client import Client from zwave_js_server.const import CommandClass, LogLevel -from zwave_js_server.exceptions import InvalidNewValue, NotFoundError, SetValueFailed +from zwave_js_server.exceptions import ( + BaseZwaveJSServerError, + InvalidNewValue, + NotFoundError, + SetValueFailed, +) +from zwave_js_server.firmware import begin_firmware_update +from zwave_js_server.model.firmware import ( + FirmwareUpdateFinished, + FirmwareUpdateProgress, +) from zwave_js_server.model.log_config import LogConfig from zwave_js_server.model.log_message import LogMessage from zwave_js_server.model.node import Node @@ -28,6 +38,7 @@ from homeassistant.components.websocket_api.const import ( from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.const import CONF_URL from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import Unauthorized from homeassistant.helpers import config_validation as cv from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.device_registry import DeviceEntry @@ -147,7 +158,12 @@ def async_register_api(hass: HomeAssistant) -> None: hass, websocket_update_data_collection_preference ) websocket_api.async_register_command(hass, websocket_data_collection_status) + websocket_api.async_register_command(hass, websocket_abort_firmware_update) + websocket_api.async_register_command( + hass, websocket_subscribe_firmware_update_status + ) hass.http.register_view(DumpView()) + hass.http.register_view(FirmwareUploadView()) @websocket_api.require_admin @@ -1024,3 +1040,131 @@ class DumpView(HomeAssistantView): hdrs.CONTENT_DISPOSITION: 'attachment; filename="zwave_js_dump.json"', }, ) + + +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required(TYPE): "zwave_js/abort_firmware_update", + vol.Required(ENTRY_ID): str, + vol.Required(NODE_ID): int, + } +) +@websocket_api.async_response +@async_get_node +async def websocket_abort_firmware_update( + hass: HomeAssistant, + connection: ActiveConnection, + msg: dict, + node: Node, +) -> None: + """Abort a firmware update.""" + await node.async_abort_firmware_update() + connection.send_result(msg[ID]) + + +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required(TYPE): "zwave_js/subscribe_firmware_update_status", + vol.Required(ENTRY_ID): str, + vol.Required(NODE_ID): int, + } +) +@websocket_api.async_response +@async_get_node +async def websocket_subscribe_firmware_update_status( + hass: HomeAssistant, + connection: ActiveConnection, + msg: dict, + node: Node, +) -> None: + """Subsribe to the status of a firmware update.""" + + @callback + def async_cleanup() -> None: + """Remove signal listeners.""" + for unsub in unsubs: + unsub() + + @callback + def forward_progress(event: dict) -> None: + progress: FirmwareUpdateProgress = event["firmware_update_progress"] + connection.send_message( + websocket_api.event_message( + msg[ID], + { + "event": event["event"], + "sent_fragments": progress.sent_fragments, + "total_fragments": progress.total_fragments, + }, + ) + ) + + @callback + def forward_finished(event: dict) -> None: + finished: FirmwareUpdateFinished = event["firmware_update_finished"] + connection.send_message( + websocket_api.event_message( + msg[ID], + { + "event": event["event"], + "status": finished.status, + "wait_time": finished.wait_time, + }, + ) + ) + + unsubs = [ + node.on("firmware update progress", forward_progress), + node.on("firmware update finished", forward_finished), + ] + connection.subscriptions[msg["id"]] = async_cleanup + + connection.send_result(msg[ID]) + + +class FirmwareUploadView(HomeAssistantView): + """View to upload firmware.""" + + url = r"/api/zwave_js/firmware/upload/{config_entry_id}/{node_id:\d+}" + name = "api:zwave_js:firmware:upload" + + async def post( + self, request: web.Request, config_entry_id: str, node_id: str + ) -> web.Response: + """Handle upload.""" + if not request["hass_user"].is_admin: + raise Unauthorized() + hass = request.app["hass"] + if config_entry_id not in hass.data[DOMAIN]: + raise web_exceptions.HTTPBadRequest + + entry = hass.config_entries.async_get_entry(config_entry_id) + client = hass.data[DOMAIN][config_entry_id][DATA_CLIENT] + node = client.driver.controller.nodes.get(int(node_id)) + if not node: + raise web_exceptions.HTTPNotFound + + # Increase max payload + request._client_max_size = 1024 * 1024 * 10 # pylint: disable=protected-access + + data = await request.post() + + if "file" not in data or not isinstance(data["file"], web_request.FileField): + raise web_exceptions.HTTPBadRequest + + uploaded_file: web_request.FileField = data["file"] + + try: + await begin_firmware_update( + entry.data[CONF_URL], + node, + uploaded_file.filename, + await hass.async_add_executor_job(uploaded_file.file.read), + async_get_clientsession(hass), + ) + except BaseZwaveJSServerError as err: + raise web_exceptions.HTTPBadRequest from err + + return self.json(None) diff --git a/tests/components/zwave_js/conftest.py b/tests/components/zwave_js/conftest.py index 2b6abacbf91..a2a712b59f1 100644 --- a/tests/components/zwave_js/conftest.py +++ b/tests/components/zwave_js/conftest.py @@ -1,6 +1,7 @@ """Provide common Z-Wave JS fixtures.""" import asyncio import copy +import io import json from unittest.mock import AsyncMock, patch @@ -717,3 +718,9 @@ def wallmote_central_scene_fixture(client, wallmote_central_scene_state): node = Node(client, copy.deepcopy(wallmote_central_scene_state)) client.driver.controller.nodes[node.node_id] = node return node + + +@pytest.fixture(name="firmware_file") +def firmware_file_fixture(): + """Return mock firmware file stream.""" + return io.BytesIO(bytes(10)) diff --git a/tests/components/zwave_js/test_api.py b/tests/components/zwave_js/test_api.py index 596ed9c0ed9..fd6161b6f00 100644 --- a/tests/components/zwave_js/test_api.py +++ b/tests/components/zwave_js/test_api.py @@ -2,9 +2,15 @@ import json from unittest.mock import patch +import pytest from zwave_js_server.const import LogLevel from zwave_js_server.event import Event -from zwave_js_server.exceptions import InvalidNewValue, NotFoundError, SetValueFailed +from zwave_js_server.exceptions import ( + FailedCommand, + InvalidNewValue, + NotFoundError, + SetValueFailed, +) from homeassistant.components.websocket_api.const import ERR_NOT_FOUND from homeassistant.components.zwave_js.api import ( @@ -1123,13 +1129,74 @@ async def test_dump_view(integration, hass_client): assert json.loads(await resp.text()) == [{"hello": "world"}, {"second": "msg"}] -async def test_dump_view_invalid_entry_id(integration, hass_client): +async def test_firmware_upload_view( + hass, multisensor_6, integration, hass_client, firmware_file +): + """Test the HTTP firmware upload view.""" + client = await hass_client() + with patch( + "homeassistant.components.zwave_js.api.begin_firmware_update", + ) as mock_cmd: + resp = await client.post( + f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}", + data={"file": firmware_file}, + ) + assert mock_cmd.call_args[0][1:4] == (multisensor_6, "file", bytes(10)) + assert json.loads(await resp.text()) is None + + +async def test_firmware_upload_view_failed_command( + hass, multisensor_6, integration, hass_client, firmware_file +): + """Test failed command for the HTTP firmware upload view.""" + client = await hass_client() + with patch( + "homeassistant.components.zwave_js.api.begin_firmware_update", + side_effect=FailedCommand("test", "test"), + ): + resp = await client.post( + f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}", + data={"file": firmware_file}, + ) + assert resp.status == 400 + + +async def test_firmware_upload_view_invalid_payload( + hass, multisensor_6, integration, hass_client +): + """Test an invalid payload for the HTTP firmware upload view.""" + client = await hass_client() + resp = await client.post( + f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}", + data={"wrong_key": bytes(10)}, + ) + assert resp.status == 400 + + +@pytest.mark.parametrize( + "method, url", + [ + ("get", "/api/zwave_js/dump/INVALID"), + ("post", "/api/zwave_js/firmware/upload/INVALID/1"), + ], +) +async def test_view_invalid_entry_id(integration, hass_client, method, url): """Test an invalid config entry id parameter.""" client = await hass_client() - resp = await client.get("/api/zwave_js/dump/INVALID") + resp = await client.request(method, url) assert resp.status == 400 +@pytest.mark.parametrize( + "method, url", [("post", "/api/zwave_js/firmware/upload/{}/111")] +) +async def test_view_invalid_node_id(integration, hass_client, method, url): + """Test an invalid config entry id parameter.""" + client = await hass_client() + resp = await client.request(method, url.format(integration.entry_id)) + assert resp.status == 404 + + async def test_subscribe_logs(hass, integration, client, hass_ws_client): """Test the subscribe_logs websocket command.""" entry = integration @@ -1468,3 +1535,192 @@ async def test_data_collection(hass, client, integration, hass_ws_client): assert not msg["success"] assert msg["error"]["code"] == ERR_NOT_LOADED + + +async def test_abort_firmware_update( + hass, client, multisensor_6, integration, hass_ws_client +): + """Test that the abort_firmware_update WS API call works.""" + entry = integration + ws_client = await hass_ws_client(hass) + + client.async_send_command_no_wait.return_value = {} + await ws_client.send_json( + { + ID: 1, + TYPE: "zwave_js/abort_firmware_update", + ENTRY_ID: entry.entry_id, + NODE_ID: multisensor_6.node_id, + } + ) + msg = await ws_client.receive_json() + assert msg["success"] + + assert len(client.async_send_command_no_wait.call_args_list) == 1 + args = client.async_send_command_no_wait.call_args[0][0] + assert args["command"] == "node.abort_firmware_update" + assert args["nodeId"] == multisensor_6.node_id + + +async def test_abort_firmware_update_failures( + hass, integration, multisensor_6, client, hass_ws_client +): + """Test failures for the abort_firmware_update websocket command.""" + entry = integration + ws_client = await hass_ws_client(hass) + # Test sending command with improper entry ID fails + await ws_client.send_json( + { + ID: 1, + TYPE: "zwave_js/abort_firmware_update", + ENTRY_ID: "fake_entry_id", + NODE_ID: multisensor_6.node_id, + } + ) + msg = await ws_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == ERR_NOT_FOUND + + # Test sending command with improper node ID fails + await ws_client.send_json( + { + ID: 2, + TYPE: "zwave_js/abort_firmware_update", + ENTRY_ID: entry.entry_id, + NODE_ID: multisensor_6.node_id + 100, + } + ) + msg = await ws_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == ERR_NOT_FOUND + + # Test sending command with not loaded entry fails + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + + await ws_client.send_json( + { + ID: 3, + TYPE: "zwave_js/abort_firmware_update", + ENTRY_ID: entry.entry_id, + NODE_ID: multisensor_6.node_id, + } + ) + msg = await ws_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == ERR_NOT_LOADED + + +async def test_subscribe_firmware_update_status( + hass, integration, multisensor_6, client, hass_ws_client +): + """Test the subscribe_firmware_update_status websocket command.""" + entry = integration + ws_client = await hass_ws_client(hass) + + client.async_send_command_no_wait.return_value = {} + + await ws_client.send_json( + { + ID: 1, + TYPE: "zwave_js/subscribe_firmware_update_status", + ENTRY_ID: entry.entry_id, + NODE_ID: multisensor_6.node_id, + } + ) + + msg = await ws_client.receive_json() + assert msg["success"] + + event = Event( + type="firmware update progress", + data={ + "source": "node", + "event": "firmware update progress", + "nodeId": multisensor_6.node_id, + "sentFragments": 1, + "totalFragments": 10, + }, + ) + multisensor_6.receive_event(event) + + msg = await ws_client.receive_json() + assert msg["event"] == { + "event": "firmware update progress", + "sent_fragments": 1, + "total_fragments": 10, + } + + event = Event( + type="firmware update finished", + data={ + "source": "node", + "event": "firmware update finished", + "nodeId": multisensor_6.node_id, + "status": 255, + "waitTime": 10, + }, + ) + multisensor_6.receive_event(event) + + msg = await ws_client.receive_json() + assert msg["event"] == { + "event": "firmware update finished", + "status": 255, + "wait_time": 10, + } + + +async def test_subscribe_firmware_update_status_failures( + hass, integration, multisensor_6, client, hass_ws_client +): + """Test failures for the subscribe_firmware_update_status websocket command.""" + entry = integration + ws_client = await hass_ws_client(hass) + # Test sending command with improper entry ID fails + await ws_client.send_json( + { + ID: 1, + TYPE: "zwave_js/subscribe_firmware_update_status", + ENTRY_ID: "fake_entry_id", + NODE_ID: multisensor_6.node_id, + } + ) + msg = await ws_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == ERR_NOT_FOUND + + # Test sending command with improper node ID fails + await ws_client.send_json( + { + ID: 2, + TYPE: "zwave_js/subscribe_firmware_update_status", + ENTRY_ID: entry.entry_id, + NODE_ID: multisensor_6.node_id + 100, + } + ) + msg = await ws_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == ERR_NOT_FOUND + + # Test sending command with not loaded entry fails + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + + await ws_client.send_json( + { + ID: 3, + TYPE: "zwave_js/subscribe_firmware_update_status", + ENTRY_ID: entry.entry_id, + NODE_ID: multisensor_6.node_id, + } + ) + msg = await ws_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == ERR_NOT_LOADED