Switch zwave_js firmware upload view to use device ID (#72219)

* Switch zwave_js firmware upload view to use device ID

* Store device registry in view
This commit is contained in:
Raman Gupta 2022-05-20 01:50:13 -04:00 committed by GitHub
parent 7cad1571a2
commit 5f7594268a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 28 deletions

View File

@ -57,7 +57,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import Unauthorized from homeassistant.exceptions import Unauthorized
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.device_registry import DeviceEntry import homeassistant.helpers.device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from .config_validation import BITMASK_SCHEMA from .config_validation import BITMASK_SCHEMA
@ -607,7 +607,7 @@ async def websocket_add_node(
) )
@callback @callback
def device_registered(device: DeviceEntry) -> None: def device_registered(device: dr.DeviceEntry) -> None:
device_details = { device_details = {
"name": device.name, "name": device.name,
"id": device.id, "id": device.id,
@ -1108,7 +1108,7 @@ async def websocket_replace_failed_node(
) )
@callback @callback
def device_registered(device: DeviceEntry) -> None: def device_registered(device: dr.DeviceEntry) -> None:
device_details = { device_details = {
"name": device.name, "name": device.name,
"id": device.id, "id": device.id,
@ -1819,25 +1819,37 @@ async def websocket_subscribe_firmware_update_status(
class FirmwareUploadView(HomeAssistantView): class FirmwareUploadView(HomeAssistantView):
"""View to upload firmware.""" """View to upload firmware."""
url = r"/api/zwave_js/firmware/upload/{config_entry_id}/{node_id:\d+}" url = r"/api/zwave_js/firmware/upload/{device_id}"
name = "api:zwave_js:firmware:upload" name = "api:zwave_js:firmware:upload"
async def post( def __init__(self) -> None:
self, request: web.Request, config_entry_id: str, node_id: str """Initialize view."""
) -> web.Response: super().__init__()
self._dev_reg: dr.DeviceRegistry | None = None
async def post(self, request: web.Request, device_id: str) -> web.Response:
"""Handle upload.""" """Handle upload."""
if not request["hass_user"].is_admin: if not request["hass_user"].is_admin:
raise Unauthorized() raise Unauthorized()
hass = request.app["hass"] hass = request.app["hass"]
if config_entry_id not in hass.data[DOMAIN]:
raise web_exceptions.HTTPBadRequest
entry = hass.config_entries.async_get_entry(config_entry_id) try:
client: Client = hass.data[DOMAIN][config_entry_id][DATA_CLIENT] node = async_get_node_from_device_id(hass, device_id)
node = client.driver.controller.nodes.get(int(node_id)) except ValueError as err:
if not node: if "not loaded" in err.args[0]:
raise web_exceptions.HTTPBadRequest
raise web_exceptions.HTTPNotFound raise web_exceptions.HTTPNotFound
if not self._dev_reg:
self._dev_reg = dr.async_get(hass)
device = self._dev_reg.async_get(device_id)
assert device
entry = next(
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.entry_id in device.config_entries
)
# Increase max payload # Increase max payload
request._client_max_size = 1024 * 1024 * 10 # pylint: disable=protected-access request._client_max_size = 1024 * 1024 * 10 # pylint: disable=protected-access

View File

@ -2661,11 +2661,12 @@ async def test_firmware_upload_view(
): ):
"""Test the HTTP firmware upload view.""" """Test the HTTP firmware upload view."""
client = await hass_client() client = await hass_client()
device = get_device(hass, multisensor_6)
with patch( with patch(
"homeassistant.components.zwave_js.api.begin_firmware_update", "homeassistant.components.zwave_js.api.begin_firmware_update",
) as mock_cmd: ) as mock_cmd:
resp = await client.post( resp = await client.post(
f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}", f"/api/zwave_js/firmware/upload/{device.id}",
data={"file": firmware_file}, data={"file": firmware_file},
) )
assert mock_cmd.call_args[0][1:4] == (multisensor_6, "file", bytes(10)) assert mock_cmd.call_args[0][1:4] == (multisensor_6, "file", bytes(10))
@ -2677,12 +2678,13 @@ async def test_firmware_upload_view_failed_command(
): ):
"""Test failed command for the HTTP firmware upload view.""" """Test failed command for the HTTP firmware upload view."""
client = await hass_client() client = await hass_client()
device = get_device(hass, multisensor_6)
with patch( with patch(
"homeassistant.components.zwave_js.api.begin_firmware_update", "homeassistant.components.zwave_js.api.begin_firmware_update",
side_effect=FailedCommand("test", "test"), side_effect=FailedCommand("test", "test"),
): ):
resp = await client.post( resp = await client.post(
f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}", f"/api/zwave_js/firmware/upload/{device.id}",
data={"file": firmware_file}, data={"file": firmware_file},
) )
assert resp.status == HTTPStatus.BAD_REQUEST assert resp.status == HTTPStatus.BAD_REQUEST
@ -2692,9 +2694,10 @@ async def test_firmware_upload_view_invalid_payload(
hass, multisensor_6, integration, hass_client hass, multisensor_6, integration, hass_client
): ):
"""Test an invalid payload for the HTTP firmware upload view.""" """Test an invalid payload for the HTTP firmware upload view."""
device = get_device(hass, multisensor_6)
client = await hass_client() client = await hass_client()
resp = await client.post( resp = await client.post(
f"/api/zwave_js/firmware/upload/{integration.entry_id}/{multisensor_6.node_id}", f"/api/zwave_js/firmware/upload/{device.id}",
data={"wrong_key": bytes(10)}, data={"wrong_key": bytes(10)},
) )
assert resp.status == HTTPStatus.BAD_REQUEST assert resp.status == HTTPStatus.BAD_REQUEST
@ -2702,40 +2705,43 @@ async def test_firmware_upload_view_invalid_payload(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"method, url", "method, url",
[("post", "/api/zwave_js/firmware/upload/{}/{}")], [("post", "/api/zwave_js/firmware/upload/{}")],
) )
async def test_node_view_non_admin_user( async def test_node_view_non_admin_user(
multisensor_6, integration, hass_client, hass_admin_user, method, url hass, multisensor_6, integration, hass_client, hass_admin_user, method, url
): ):
"""Test node level views for non-admin users.""" """Test node level views for non-admin users."""
client = await hass_client() client = await hass_client()
device = get_device(hass, multisensor_6)
# Verify we require admin user # Verify we require admin user
hass_admin_user.groups = [] hass_admin_user.groups = []
resp = await client.request( resp = await client.request(method, url.format(device.id))
method, url.format(integration.entry_id, multisensor_6.node_id)
)
assert resp.status == HTTPStatus.UNAUTHORIZED assert resp.status == HTTPStatus.UNAUTHORIZED
@pytest.mark.parametrize( @pytest.mark.parametrize(
"method, url", "method, url",
[ [
("post", "/api/zwave_js/firmware/upload/INVALID/1"), ("post", "/api/zwave_js/firmware/upload/{}"),
], ],
) )
async def test_view_invalid_entry_id(integration, hass_client, method, url): async def test_view_unloaded_config_entry(
"""Test an invalid config entry id parameter.""" hass, multisensor_6, integration, hass_client, method, url
):
"""Test an unloaded config entry raises Bad Request."""
client = await hass_client() client = await hass_client()
resp = await client.request(method, url) device = get_device(hass, multisensor_6)
await hass.config_entries.async_unload(integration.entry_id)
resp = await client.request(method, url.format(device.id))
assert resp.status == HTTPStatus.BAD_REQUEST assert resp.status == HTTPStatus.BAD_REQUEST
@pytest.mark.parametrize( @pytest.mark.parametrize(
"method, url", "method, url",
[("post", "/api/zwave_js/firmware/upload/{}/111")], [("post", "/api/zwave_js/firmware/upload/INVALID")],
) )
async def test_view_invalid_node_id(integration, hass_client, method, url): async def test_view_invalid_device_id(integration, hass_client, method, url):
"""Test an invalid config entry id parameter.""" """Test an invalid device id parameter."""
client = await hass_client() client = await hass_client()
resp = await client.request(method, url.format(integration.entry_id)) resp = await client.request(method, url.format(integration.entry_id))
assert resp.status == HTTPStatus.NOT_FOUND assert resp.status == HTTPStatus.NOT_FOUND