Parallelize zwave_js service calls (#71662)

* Parallelize zwave_js service calls to speed them up and handle exceptions properly

* Fix bug

* Add tests

* Fix comments

* Additional comment fixes

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Raman Gupta 2022-05-12 03:07:58 -04:00 committed by GitHub
parent 0d69adb404
commit 533257021c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 274 additions and 44 deletions

View File

@ -2,13 +2,14 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Generator
import logging import logging
from typing import Any, cast from typing import Any
import voluptuous as vol import voluptuous as vol
from zwave_js_server.client import Client as ZwaveClient from zwave_js_server.client import Client as ZwaveClient
from zwave_js_server.const import CommandClass, CommandStatus from zwave_js_server.const import CommandClass, CommandStatus
from zwave_js_server.exceptions import FailedCommand, SetValueFailed from zwave_js_server.exceptions import SetValueFailed
from zwave_js_server.model.endpoint import Endpoint from zwave_js_server.model.endpoint import Endpoint
from zwave_js_server.model.node import Node as ZwaveNode from zwave_js_server.model.node import Node as ZwaveNode
from zwave_js_server.model.value import get_value_id from zwave_js_server.model.value import get_value_id
@ -38,6 +39,12 @@ from .helpers import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SET_VALUE_FAILED_EXC = SetValueFailed(
"Unable to set value, refer to "
"https://zwave-js.github.io/node-zwave-js/#/api/node?id=setvalue for "
"possible reasons"
)
def parameter_name_does_not_need_bitmask( def parameter_name_does_not_need_bitmask(
val: dict[str, int | str | list[str]] val: dict[str, int | str | list[str]]
@ -64,6 +71,33 @@ def broadcast_command(val: dict[str, Any]) -> dict[str, Any]:
) )
def get_valid_responses_from_results(
zwave_objects: set[ZwaveNode | Endpoint], results: tuple[Any, ...]
) -> Generator[tuple[ZwaveNode | Endpoint, Any], None, None]:
"""Return valid responses from a list of results."""
for zwave_object, result in zip(zwave_objects, results):
if not isinstance(result, Exception):
yield zwave_object, result
def raise_exceptions_from_results(
zwave_objects: set[ZwaveNode | Endpoint] | tuple[ZwaveNode | str, ...],
results: tuple[Any, ...],
) -> None:
"""Raise list of exceptions from a list of results."""
if errors := [
tup for tup in zip(zwave_objects, results) if isinstance(tup[1], Exception)
]:
lines = (
f"{len(errors)} error(s):",
*(
f"{zwave_object} - {error.__class__.__name__}: {error.args[0]}"
for zwave_object, error in errors
),
)
raise HomeAssistantError("\n".join(lines))
class ZWaveServices: class ZWaveServices:
"""Class that holds our services (Zwave Commands) that should be published to hass.""" """Class that holds our services (Zwave Commands) that should be published to hass."""
@ -371,14 +405,21 @@ class ZWaveServices:
property_key = service.data.get(const.ATTR_CONFIG_PARAMETER_BITMASK) property_key = service.data.get(const.ATTR_CONFIG_PARAMETER_BITMASK)
new_value = service.data[const.ATTR_CONFIG_VALUE] new_value = service.data[const.ATTR_CONFIG_VALUE]
for node in nodes: results = await asyncio.gather(
zwave_value, cmd_status = await async_set_config_parameter( *(
node, async_set_config_parameter(
new_value, node,
property_or_property_name, new_value,
property_key=property_key, property_or_property_name,
) property_key=property_key,
)
for node in nodes
),
return_exceptions=True,
)
for node, result in get_valid_responses_from_results(nodes, results):
zwave_value = result[0]
cmd_status = result[1]
if cmd_status == CommandStatus.ACCEPTED: if cmd_status == CommandStatus.ACCEPTED:
msg = "Set configuration parameter %s on Node %s with value %s" msg = "Set configuration parameter %s on Node %s with value %s"
else: else:
@ -386,8 +427,8 @@ class ZWaveServices:
"Added command to queue to set configuration parameter %s on Node " "Added command to queue to set configuration parameter %s on Node "
"%s with value %s. Parameter will be set when the device wakes up" "%s with value %s. Parameter will be set when the device wakes up"
) )
_LOGGER.info(msg, zwave_value, node, new_value) _LOGGER.info(msg, zwave_value, node, new_value)
raise_exceptions_from_results(nodes, results)
async def async_bulk_set_partial_config_parameters( async def async_bulk_set_partial_config_parameters(
self, service: ServiceCall self, service: ServiceCall
@ -397,23 +438,31 @@ class ZWaveServices:
property_ = service.data[const.ATTR_CONFIG_PARAMETER] property_ = service.data[const.ATTR_CONFIG_PARAMETER]
new_value = service.data[const.ATTR_CONFIG_VALUE] new_value = service.data[const.ATTR_CONFIG_VALUE]
for node in nodes: results = await asyncio.gather(
cmd_status = await async_bulk_set_partial_config_parameters( *(
node, async_bulk_set_partial_config_parameters(
property_, node,
new_value, property_,
) new_value,
)
for node in nodes
),
return_exceptions=True,
)
for node, cmd_status in get_valid_responses_from_results(nodes, results):
if cmd_status == CommandStatus.ACCEPTED: if cmd_status == CommandStatus.ACCEPTED:
msg = "Bulk set partials for configuration parameter %s on Node %s" msg = "Bulk set partials for configuration parameter %s on Node %s"
else: else:
msg = ( msg = (
"Added command to queue to bulk set partials for configuration " "Queued command to bulk set partials for configuration parameter "
"parameter %s on Node %s" "%s on Node %s"
) )
_LOGGER.info(msg, property_, node) _LOGGER.info(msg, property_, node)
raise_exceptions_from_results(nodes, results)
async def async_poll_value(self, service: ServiceCall) -> None: async def async_poll_value(self, service: ServiceCall) -> None:
"""Poll value on a node.""" """Poll value on a node."""
for entity_id in service.data[ATTR_ENTITY_ID]: for entity_id in service.data[ATTR_ENTITY_ID]:
@ -436,6 +485,7 @@ class ZWaveServices:
wait_for_result = service.data.get(const.ATTR_WAIT_FOR_RESULT) wait_for_result = service.data.get(const.ATTR_WAIT_FOR_RESULT)
options = service.data.get(const.ATTR_OPTIONS) options = service.data.get(const.ATTR_OPTIONS)
coros = []
for node in nodes: for node in nodes:
value_id = get_value_id( value_id = get_value_id(
node, node,
@ -455,19 +505,29 @@ class ZWaveServices:
new_value_ = str(new_value) new_value_ = str(new_value)
else: else:
new_value_ = new_value new_value_ = new_value
success = await node.async_set_value( coros.append(
value_id, node.async_set_value(
new_value_, value_id,
options=options, new_value_,
wait_for_result=wait_for_result, options=options,
wait_for_result=wait_for_result,
)
) )
results = await asyncio.gather(*coros, return_exceptions=True)
# multiple set_values my fail so we will track the entire list
set_value_failed_nodes_list = []
for node, success in get_valid_responses_from_results(nodes, results):
if success is False: if success is False:
raise HomeAssistantError( # If we failed to set a value, add node to SetValueFailed exception list
"Unable to set value, refer to " set_value_failed_nodes_list.append(node)
"https://zwave-js.github.io/node-zwave-js/#/api/node?id=setvalue "
"for possible reasons" # Add the SetValueFailed exception to the results and the nodes to the node
) from SetValueFailed # list. No-op if there are no SetValueFailed exceptions
raise_exceptions_from_results(
(*nodes, *set_value_failed_nodes_list),
(*results, *([SET_VALUE_FAILED_EXC] * len(set_value_failed_nodes_list))),
)
async def async_multicast_set_value(self, service: ServiceCall) -> None: async def async_multicast_set_value(self, service: ServiceCall) -> None:
"""Set a value via multicast to multiple nodes.""" """Set a value via multicast to multiple nodes."""
@ -556,24 +616,29 @@ class ZWaveServices:
async def _async_invoke_cc_api(endpoints: set[Endpoint]) -> None: async def _async_invoke_cc_api(endpoints: set[Endpoint]) -> None:
"""Invoke the CC API on a node endpoint.""" """Invoke the CC API on a node endpoint."""
errors: list[str] = [] results = await asyncio.gather(
for endpoint in endpoints: *(
endpoint.async_invoke_cc_api(
command_class, method_name, *parameters
)
for endpoint in endpoints
),
return_exceptions=True,
)
for endpoint, result in get_valid_responses_from_results(
endpoints, results
):
_LOGGER.info( _LOGGER.info(
"Invoking %s CC API method %s on endpoint %s", (
"Invoked %s CC API method %s on endpoint %s with the following "
"result: %s"
),
command_class.name, command_class.name,
method_name, method_name,
endpoint, endpoint,
result,
) )
try: raise_exceptions_from_results(endpoints, results)
await endpoint.async_invoke_cc_api(
command_class, method_name, *parameters
)
except FailedCommand as err:
errors.append(cast(str, err.args[0]))
if errors:
raise HomeAssistantError(
"\n".join([f"{len(errors)} error(s):", *errors])
)
# If an endpoint is provided, we assume the user wants to call the CC API on # If an endpoint is provided, we assume the user wants to call the CC API on
# that endpoint for all target nodes # that endpoint for all target nodes

View File

@ -518,6 +518,71 @@ async def test_set_config_parameter(hass, client, multisensor_6, integration):
) )
async def test_set_config_parameter_gather(
hass,
client,
multisensor_6,
climate_radio_thermostat_ct100_plus_different_endpoints,
integration,
):
"""Test the set_config_parameter service gather functionality."""
# Test setting config parameter by property and validate that the first node
# which triggers an error doesn't prevent the second one to be called.
with pytest.raises(HomeAssistantError):
await hass.services.async_call(
DOMAIN,
SERVICE_SET_CONFIG_PARAMETER,
{
ATTR_ENTITY_ID: [
AIR_TEMPERATURE_SENSOR,
CLIMATE_RADIO_THERMOSTAT_ENTITY,
],
ATTR_CONFIG_PARAMETER: 1,
ATTR_CONFIG_VALUE: 1,
},
blocking=True,
)
assert len(client.async_send_command_no_wait.call_args_list) == 0
assert len(client.async_send_command.call_args_list) == 1
args = client.async_send_command.call_args[0][0]
assert args["command"] == "node.set_value"
assert args["nodeId"] == 26
assert args["valueId"] == {
"endpoint": 0,
"commandClass": 112,
"commandClassName": "Configuration",
"property": 1,
"propertyName": "Temperature Reporting Threshold",
"ccVersion": 1,
"metadata": {
"type": "number",
"readable": True,
"writeable": True,
"description": "Reporting threshold for changes in the ambient temperature",
"label": "Temperature Reporting Threshold",
"default": 2,
"min": 0,
"max": 4,
"states": {
"0": "Disabled",
"1": "0.5\u00b0 F",
"2": "1.0\u00b0 F",
"3": "1.5\u00b0 F",
"4": "2.0\u00b0 F",
},
"valueSize": 1,
"format": 0,
"allowManualEntry": False,
"isFromConfig": True,
},
"value": 1,
}
assert args["value"] == 1
client.async_send_command.reset_mock()
async def test_bulk_set_config_parameters(hass, client, multisensor_6, integration): async def test_bulk_set_config_parameters(hass, client, multisensor_6, integration):
"""Test the bulk_set_partial_config_parameters service.""" """Test the bulk_set_partial_config_parameters service."""
dev_reg = async_get_dev_reg(hass) dev_reg = async_get_dev_reg(hass)
@ -726,6 +791,45 @@ async def test_bulk_set_config_parameters(hass, client, multisensor_6, integrati
client.async_send_command.reset_mock() client.async_send_command.reset_mock()
async def test_bulk_set_config_parameters_gather(
hass,
client,
multisensor_6,
climate_radio_thermostat_ct100_plus_different_endpoints,
integration,
):
"""Test the bulk_set_partial_config_parameters service gather functionality."""
# Test bulk setting config parameter by property and validate that the first node
# which triggers an error doesn't prevent the second one to be called.
with pytest.raises(HomeAssistantError):
await hass.services.async_call(
DOMAIN,
SERVICE_BULK_SET_PARTIAL_CONFIG_PARAMETERS,
{
ATTR_ENTITY_ID: [
CLIMATE_RADIO_THERMOSTAT_ENTITY,
AIR_TEMPERATURE_SENSOR,
],
ATTR_CONFIG_PARAMETER: 102,
ATTR_CONFIG_VALUE: 241,
},
blocking=True,
)
assert len(client.async_send_command.call_args_list) == 0
assert len(client.async_send_command_no_wait.call_args_list) == 1
args = client.async_send_command_no_wait.call_args[0][0]
assert args["command"] == "node.set_value"
assert args["nodeId"] == 52
assert args["valueId"] == {
"commandClass": 112,
"property": 102,
}
assert args["value"] == 241
client.async_send_command_no_wait.reset_mock()
async def test_refresh_value( async def test_refresh_value(
hass, client, climate_radio_thermostat_ct100_plus_different_endpoints, integration hass, client, climate_radio_thermostat_ct100_plus_different_endpoints, integration
): ):
@ -1126,6 +1230,66 @@ async def test_set_value_options(hass, client, aeon_smart_switch_6, integration)
client.async_send_command.reset_mock() client.async_send_command.reset_mock()
async def test_set_value_gather(
hass,
client,
multisensor_6,
climate_radio_thermostat_ct100_plus_different_endpoints,
integration,
):
"""Test the set_value service gather functionality."""
# Test setting value by property and validate that the first node
# which triggers an error doesn't prevent the second one to be called.
with pytest.raises(HomeAssistantError):
await hass.services.async_call(
DOMAIN,
SERVICE_SET_VALUE,
{
ATTR_ENTITY_ID: [
CLIMATE_RADIO_THERMOSTAT_ENTITY,
AIR_TEMPERATURE_SENSOR,
],
ATTR_COMMAND_CLASS: 112,
ATTR_PROPERTY: 102,
ATTR_PROPERTY_KEY: 1,
ATTR_VALUE: 1,
},
blocking=True,
)
assert len(client.async_send_command.call_args_list) == 0
assert len(client.async_send_command_no_wait.call_args_list) == 1
args = client.async_send_command_no_wait.call_args[0][0]
assert args["command"] == "node.set_value"
assert args["nodeId"] == 52
assert args["valueId"] == {
"commandClassName": "Configuration",
"commandClass": 112,
"endpoint": 0,
"property": 102,
"propertyKey": 1,
"propertyName": "Group 2: Send battery reports",
"metadata": {
"type": "number",
"readable": True,
"writeable": True,
"valueSize": 4,
"min": 0,
"max": 1,
"default": 1,
"format": 0,
"allowManualEntry": True,
"label": "Group 2: Send battery reports",
"description": "Include battery information in periodic reports to Group 2",
"isFromConfig": True,
},
"value": 0,
}
assert args["value"] == 1
client.async_send_command_no_wait.reset_mock()
async def test_multicast_set_value( async def test_multicast_set_value(
hass, hass,
client, client,
@ -1728,7 +1892,8 @@ async def test_invoke_cc_api(
client.async_send_command.reset_mock() client.async_send_command.reset_mock()
client.async_send_command_no_wait.reset_mock() client.async_send_command_no_wait.reset_mock()
# Test failed invoke_cc_api call on one node # Test failed invoke_cc_api call on one node. We return the error on
# the first node in the call to make sure that gather works as expected
client.async_send_command.return_value = {"response": True} client.async_send_command.return_value = {"response": True}
client.async_send_command_no_wait.side_effect = FailedZWaveCommand( client.async_send_command_no_wait.side_effect = FailedZWaveCommand(
"test", 12, "test" "test", 12, "test"
@ -1740,8 +1905,8 @@ async def test_invoke_cc_api(
SERVICE_INVOKE_CC_API, SERVICE_INVOKE_CC_API,
{ {
ATTR_DEVICE_ID: [ ATTR_DEVICE_ID: [
device_radio_thermostat.id,
device_danfoss.id, device_danfoss.id,
device_radio_thermostat.id,
], ],
ATTR_COMMAND_CLASS: 132, ATTR_COMMAND_CLASS: 132,
ATTR_ENDPOINT: 0, ATTR_ENDPOINT: 0,