Update zwave_js FirmwareUploadView to support controller updates (#87239)

* Update zwave_js FirmwareUploadView to support controller updates

* Add coverage

* Change None check to assertion
This commit is contained in:
Raman Gupta 2023-02-22 11:52:00 -05:00 committed by GitHub
parent 5683d21931
commit 1f9f6ab1f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 85 additions and 19 deletions

View File

@ -27,13 +27,14 @@ from zwave_js_server.exceptions import (
NotFoundError, NotFoundError,
SetValueFailed, SetValueFailed,
) )
from zwave_js_server.firmware import update_firmware from zwave_js_server.firmware import controller_firmware_update_otw, update_firmware
from zwave_js_server.model.controller import ( from zwave_js_server.model.controller import (
ControllerStatistics, ControllerStatistics,
InclusionGrant, InclusionGrant,
ProvisioningEntry, ProvisioningEntry,
QRProvisioningInformation, QRProvisioningInformation,
) )
from zwave_js_server.model.controller.firmware import ControllerFirmwareUpdateData
from zwave_js_server.model.driver import Driver from zwave_js_server.model.driver import Driver
from zwave_js_server.model.log_config import LogConfig from zwave_js_server.model.log_config import LogConfig
from zwave_js_server.model.log_message import LogMessage from zwave_js_server.model.log_message import LogMessage
@ -445,7 +446,7 @@ def async_register_api(hass: HomeAssistant) -> None:
hass, websocket_subscribe_controller_statistics hass, websocket_subscribe_controller_statistics
) )
websocket_api.async_register_command(hass, websocket_subscribe_node_statistics) websocket_api.async_register_command(hass, websocket_subscribe_node_statistics)
hass.http.register_view(FirmwareUploadView()) hass.http.register_view(FirmwareUploadView(dr.async_get(hass)))
@websocket_api.require_admin @websocket_api.require_admin
@ -2071,10 +2072,10 @@ class FirmwareUploadView(HomeAssistantView):
url = r"/api/zwave_js/firmware/upload/{device_id}" url = r"/api/zwave_js/firmware/upload/{device_id}"
name = "api:zwave_js:firmware:upload" name = "api:zwave_js:firmware:upload"
def __init__(self) -> None: def __init__(self, dev_reg: dr.DeviceRegistry) -> None:
"""Initialize view.""" """Initialize view."""
super().__init__() super().__init__()
self._dev_reg: dr.DeviceRegistry | None = None self._dev_reg = dev_reg
async def post(self, request: web.Request, device_id: str) -> web.Response: async def post(self, request: web.Request, device_id: str) -> web.Response:
"""Handle upload.""" """Handle upload."""
@ -2083,12 +2084,16 @@ class FirmwareUploadView(HomeAssistantView):
hass = request.app["hass"] hass = request.app["hass"]
try: try:
node = async_get_node_from_device_id(hass, device_id) node = async_get_node_from_device_id(hass, device_id, self._dev_reg)
except ValueError as err: except ValueError as err:
if "not loaded" in err.args[0]: if "not loaded" in err.args[0]:
raise web_exceptions.HTTPBadRequest raise web_exceptions.HTTPBadRequest
raise web_exceptions.HTTPNotFound raise web_exceptions.HTTPNotFound
# If this was not true, we wouldn't have been able to get the node from the
# device ID above
assert node.client.driver
# Increase max payload # Increase max payload
request._client_max_size = 1024 * 1024 * 10 # pylint: disable=protected-access request._client_max_size = 1024 * 1024 * 10 # pylint: disable=protected-access
@ -2100,18 +2105,29 @@ class FirmwareUploadView(HomeAssistantView):
uploaded_file: web_request.FileField = data["file"] uploaded_file: web_request.FileField = data["file"]
try: try:
await update_firmware( if node.client.driver.controller.own_node == node:
node.client.ws_server_url, await controller_firmware_update_otw(
node, node.client.ws_server_url,
[ ControllerFirmwareUpdateData(
NodeFirmwareUpdateData(
uploaded_file.filename, uploaded_file.filename,
await hass.async_add_executor_job(uploaded_file.file.read), await hass.async_add_executor_job(uploaded_file.file.read),
) ),
], async_get_clientsession(hass),
async_get_clientsession(hass), additional_user_agent_components=USER_AGENT,
additional_user_agent_components=USER_AGENT, )
) else:
await update_firmware(
node.client.ws_server_url,
node,
[
NodeFirmwareUpdateData(
uploaded_file.filename,
await hass.async_add_executor_job(uploaded_file.file.read),
)
],
async_get_clientsession(hass),
additional_user_agent_components=USER_AGENT,
)
except BaseZwaveJSServerError as err: except BaseZwaveJSServerError as err:
raise web_exceptions.HTTPBadRequest(reason=str(err)) from err raise web_exceptions.HTTPBadRequest(reason=str(err)) from err

View File

@ -28,6 +28,7 @@ from zwave_js_server.model.controller import (
ProvisioningEntry, ProvisioningEntry,
QRProvisioningInformation, QRProvisioningInformation,
) )
from zwave_js_server.model.controller.firmware import ControllerFirmwareUpdateData
from zwave_js_server.model.node import Node from zwave_js_server.model.node import Node
from zwave_js_server.model.node.firmware import NodeFirmwareUpdateData from zwave_js_server.model.node.firmware import NodeFirmwareUpdateData
@ -84,7 +85,7 @@ from tests.common import MockUser
from tests.typing import ClientSessionGenerator, WebSocketGenerator from tests.typing import ClientSessionGenerator, WebSocketGenerator
def get_device(hass, node): def get_device(hass: HomeAssistant, node):
"""Get device ID for a node.""" """Get device ID for a node."""
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
device_id = get_device_id(node.client.driver, node) device_id = get_device_id(node.client.driver, node)
@ -2968,7 +2969,9 @@ async def test_firmware_upload_view(
device = get_device(hass, multisensor_6) device = get_device(hass, multisensor_6)
with patch( with patch(
"homeassistant.components.zwave_js.api.update_firmware", "homeassistant.components.zwave_js.api.update_firmware",
) as mock_cmd, patch.dict( ) as mock_node_cmd, patch(
"homeassistant.components.zwave_js.api.controller_firmware_update_otw",
) as mock_controller_cmd, patch.dict(
"homeassistant.components.zwave_js.api.USER_AGENT", "homeassistant.components.zwave_js.api.USER_AGENT",
{"HomeAssistant": "0.0.0"}, {"HomeAssistant": "0.0.0"},
): ):
@ -2976,11 +2979,40 @@ async def test_firmware_upload_view(
f"/api/zwave_js/firmware/upload/{device.id}", f"/api/zwave_js/firmware/upload/{device.id}",
data={"file": firmware_file}, data={"file": firmware_file},
) )
assert mock_cmd.call_args[0][1:3] == ( mock_controller_cmd.assert_not_called()
assert mock_node_cmd.call_args[0][1:3] == (
multisensor_6, multisensor_6,
[NodeFirmwareUpdateData("file", bytes(10))], [NodeFirmwareUpdateData("file", bytes(10))],
) )
assert mock_cmd.call_args[1] == { assert mock_node_cmd.call_args[1] == {
"additional_user_agent_components": {"HomeAssistant": "0.0.0"},
}
assert json.loads(await resp.text()) is None
async def test_firmware_upload_view_controller(
hass, client, integration, hass_client: ClientSessionGenerator, firmware_file
) -> None:
"""Test the HTTP firmware upload view for a controller."""
hass_client = await hass_client()
device = get_device(hass, client.driver.controller.nodes[1])
with patch(
"homeassistant.components.zwave_js.api.update_firmware",
) as mock_node_cmd, patch(
"homeassistant.components.zwave_js.api.controller_firmware_update_otw",
) as mock_controller_cmd, patch.dict(
"homeassistant.components.zwave_js.api.USER_AGENT",
{"HomeAssistant": "0.0.0"},
):
resp = await hass_client.post(
f"/api/zwave_js/firmware/upload/{device.id}",
data={"file": firmware_file},
)
mock_node_cmd.assert_not_called()
assert mock_controller_cmd.call_args[0][1:2] == (
ControllerFirmwareUpdateData("file", bytes(10)),
)
assert mock_controller_cmd.call_args[1] == {
"additional_user_agent_components": {"HomeAssistant": "0.0.0"}, "additional_user_agent_components": {"HomeAssistant": "0.0.0"},
} }
assert json.loads(await resp.text()) is None assert json.loads(await resp.text()) is None
@ -3020,6 +3052,24 @@ async def test_firmware_upload_view_invalid_payload(
assert resp.status == HTTPStatus.BAD_REQUEST assert resp.status == HTTPStatus.BAD_REQUEST
async def test_firmware_upload_view_no_driver(
hass: HomeAssistant,
client,
multisensor_6,
integration,
hass_client: ClientSessionGenerator,
) -> None:
"""Test the HTTP firmware upload view when the driver doesn't exist."""
device = get_device(hass, multisensor_6)
client.driver = None
aiohttp_client = await hass_client()
resp = await aiohttp_client.post(
f"/api/zwave_js/firmware/upload/{device.id}",
data={"wrong_key": bytes(10)},
)
assert resp.status == HTTPStatus.NOT_FOUND
@pytest.mark.parametrize( @pytest.mark.parametrize(
("method", "url"), ("method", "url"),
[("post", "/api/zwave_js/firmware/upload/{}")], [("post", "/api/zwave_js/firmware/upload/{}")],