From 05a7df5629f93dd21dff26b5a85ec3628d5cbcd8 Mon Sep 17 00:00:00 2001 From: Raman Gupta <7243222+raman325@users.noreply.github.com> Date: Wed, 22 Feb 2023 12:08:57 -0500 Subject: [PATCH] Add controller support to `zwave_js/subscribe_firmware_update_status` (#87348) --- homeassistant/components/zwave_js/api.py | 93 +++++++++++++++--- tests/components/zwave_js/test_api.py | 115 +++++++++++++++++++++++ 2 files changed, 195 insertions(+), 13 deletions(-) diff --git a/homeassistant/components/zwave_js/api.py b/homeassistant/components/zwave_js/api.py index 4051c1f8f7a..2612d2d4f68 100644 --- a/homeassistant/components/zwave_js/api.py +++ b/homeassistant/components/zwave_js/api.py @@ -34,7 +34,11 @@ from zwave_js_server.model.controller import ( ProvisioningEntry, QRProvisioningInformation, ) -from zwave_js_server.model.controller.firmware import ControllerFirmwareUpdateData +from zwave_js_server.model.controller.firmware import ( + ControllerFirmwareUpdateData, + ControllerFirmwareUpdateProgress, + ControllerFirmwareUpdateResult, +) from zwave_js_server.model.driver import Driver from zwave_js_server.model.log_config import LogConfig from zwave_js_server.model.log_message import LogMessage @@ -1939,10 +1943,10 @@ async def websocket_is_node_firmware_update_in_progress( connection.send_result(msg[ID], await node.async_is_firmware_update_in_progress()) -def _get_firmware_update_progress_dict( +def _get_node_firmware_update_progress_dict( progress: NodeFirmwareUpdateProgress, ) -> dict[str, int | float]: - """Get a dictionary of firmware update progress.""" + """Get a dictionary of a node's firmware update progress.""" return { "current_file": progress.current_file, "total_files": progress.total_files, @@ -1952,6 +1956,19 @@ def _get_firmware_update_progress_dict( } +def _get_controller_firmware_update_progress_dict( + progress: ControllerFirmwareUpdateProgress, +) -> dict[str, int | float]: + """Get a dictionary of a controller's firmware update progress.""" + return { + "current_file": 1, + "total_files": 1, + "sent_fragments": progress.sent_fragments, + "total_fragments": progress.total_fragments, + "progress": progress.progress, + } + + @websocket_api.require_admin @websocket_api.websocket_command( { @@ -1968,6 +1985,8 @@ async def websocket_subscribe_firmware_update_status( node: Node, ) -> None: """Subscribe to the status of a firmware update.""" + assert node.client.driver + controller = node.client.driver.controller @callback def async_cleanup() -> None: @@ -1976,20 +1995,20 @@ async def websocket_subscribe_firmware_update_status( unsub() @callback - def forward_progress(event: dict) -> None: + def forward_node_progress(event: dict) -> None: progress: NodeFirmwareUpdateProgress = event["firmware_update_progress"] connection.send_message( websocket_api.event_message( msg[ID], { "event": event["event"], - **_get_firmware_update_progress_dict(progress), + **_get_node_firmware_update_progress_dict(progress), }, ) ) @callback - def forward_finished(event: dict) -> None: + def forward_node_finished(event: dict) -> None: finished: NodeFirmwareUpdateResult = event["firmware_update_finished"] connection.send_message( websocket_api.event_message( @@ -2004,21 +2023,69 @@ async def websocket_subscribe_firmware_update_status( ) ) - msg[DATA_UNSUBSCRIBE] = unsubs = [ - node.on("firmware update progress", forward_progress), - node.on("firmware update finished", forward_finished), - ] + @callback + def forward_controller_progress(event: dict) -> None: + progress: ControllerFirmwareUpdateProgress = event["firmware_update_progress"] + connection.send_message( + websocket_api.event_message( + msg[ID], + { + "event": event["event"], + **_get_controller_firmware_update_progress_dict(progress), + }, + ) + ) + + @callback + def forward_controller_finished(event: dict) -> None: + finished: ControllerFirmwareUpdateResult = event["firmware_update_finished"] + connection.send_message( + websocket_api.event_message( + msg[ID], + { + "event": event["event"], + "status": finished.status, + "success": finished.success, + }, + ) + ) + + if controller.own_node == node: + msg[DATA_UNSUBSCRIBE] = unsubs = [ + controller.on("firmware update progress", forward_controller_progress), + controller.on("firmware update finished", forward_controller_finished), + ] + else: + msg[DATA_UNSUBSCRIBE] = unsubs = [ + node.on("firmware update progress", forward_node_progress), + node.on("firmware update finished", forward_node_finished), + ] connection.subscriptions[msg["id"]] = async_cleanup - progress = node.firmware_update_progress connection.send_result(msg[ID]) - if progress: + if node.is_controller_node and ( + controller_progress := controller.firmware_update_progress + ): connection.send_message( websocket_api.event_message( msg[ID], { "event": "firmware update progress", - **_get_firmware_update_progress_dict(progress), + **_get_controller_firmware_update_progress_dict( + controller_progress + ), + }, + ) + ) + elif controller.own_node != node and ( + node_progress := node.firmware_update_progress + ): + connection.send_message( + websocket_api.event_message( + msg[ID], + { + "event": "firmware update progress", + **_get_node_firmware_update_progress_dict(node_progress), }, ) ) diff --git a/tests/components/zwave_js/test_api.py b/tests/components/zwave_js/test_api.py index 9de99be4c0c..f988e72e70b 100644 --- a/tests/components/zwave_js/test_api.py +++ b/tests/components/zwave_js/test_api.py @@ -3872,6 +3872,121 @@ async def test_subscribe_firmware_update_status_initial_value( } +async def test_subscribe_controller_firmware_update_status( + hass, integration, client, hass_ws_client +): + """Test the subscribe_firmware_update_status websocket command for a node.""" + ws_client = await hass_ws_client(hass) + device = get_device(hass, client.driver.controller.nodes[1]) + + client.async_send_command_no_wait.return_value = {} + + await ws_client.send_json( + { + ID: 1, + TYPE: "zwave_js/subscribe_firmware_update_status", + DEVICE_ID: device.id, + } + ) + + msg = await ws_client.receive_json() + assert msg["success"] + assert msg["result"] is None + + event = Event( + type="firmware update progress", + data={ + "source": "controller", + "event": "firmware update progress", + "progress": { + "sentFragments": 1, + "totalFragments": 10, + "progress": 10.0, + }, + }, + ) + client.driver.controller.receive_event(event) + + msg = await ws_client.receive_json() + assert msg["event"] == { + "event": "firmware update progress", + "current_file": 1, + "total_files": 1, + "sent_fragments": 1, + "total_fragments": 10, + "progress": 10.0, + } + + event = Event( + type="firmware update finished", + data={ + "source": "controller", + "event": "firmware update finished", + "result": { + "status": 255, + "success": True, + }, + }, + ) + client.driver.controller.receive_event(event) + + msg = await ws_client.receive_json() + assert msg["event"] == { + "event": "firmware update finished", + "status": 255, + "success": True, + } + + +async def test_subscribe_controller_firmware_update_status_initial_value( + hass, client, integration, hass_ws_client +): + """Test subscribe_firmware_update_status cmd with in progress update for node.""" + ws_client = await hass_ws_client(hass) + device = get_device(hass, client.driver.controller.nodes[1]) + + assert client.driver.controller.firmware_update_progress is None + + # Send a firmware update progress event before the WS command + event = Event( + type="firmware update progress", + data={ + "source": "controller", + "event": "firmware update progress", + "progress": { + "sentFragments": 1, + "totalFragments": 10, + "progress": 10.0, + }, + }, + ) + client.driver.controller.receive_event(event) + + client.async_send_command_no_wait.return_value = {} + + await ws_client.send_json( + { + ID: 1, + TYPE: "zwave_js/subscribe_firmware_update_status", + DEVICE_ID: device.id, + } + ) + + msg = await ws_client.receive_json() + assert msg["success"] + assert msg["result"] is None + + msg = await ws_client.receive_json() + assert msg["event"] == { + "event": "firmware update progress", + "current_file": 1, + "total_files": 1, + "sent_fragments": 1, + "total_fragments": 10, + "progress": 10.0, + } + + async def test_subscribe_firmware_update_status_failures( hass: HomeAssistant, multisensor_6,