Clean zwave_js services typing (#72485)

Fix services
This commit is contained in:
Martin Hjelmare 2022-05-25 18:39:42 +02:00 committed by GitHub
parent f9f87c607e
commit 10f0509ca3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Generator from collections.abc import Generator, Sequence
import logging import logging
from typing import Any from typing import Any
@ -12,7 +12,7 @@ from zwave_js_server.const import CommandClass, CommandStatus
from zwave_js_server.exceptions import SetValueFailed from zwave_js_server.exceptions import SetValueFailed
from zwave_js_server.model.endpoint import Endpoint from zwave_js_server.model.endpoint import Endpoint
from zwave_js_server.model.node import Node as ZwaveNode from zwave_js_server.model.node import Node as ZwaveNode
from zwave_js_server.model.value import get_value_id from zwave_js_server.model.value import ValueDataType, get_value_id
from zwave_js_server.util.multicast import async_multicast_set_value from zwave_js_server.util.multicast import async_multicast_set_value
from zwave_js_server.util.node import ( from zwave_js_server.util.node import (
async_bulk_set_partial_config_parameters, async_bulk_set_partial_config_parameters,
@ -72,7 +72,7 @@ def broadcast_command(val: dict[str, Any]) -> dict[str, Any]:
def get_valid_responses_from_results( def get_valid_responses_from_results(
zwave_objects: set[ZwaveNode | Endpoint], results: tuple[Any, ...] zwave_objects: Sequence[ZwaveNode | Endpoint], results: Sequence[Any]
) -> Generator[tuple[ZwaveNode | Endpoint, Any], None, None]: ) -> Generator[tuple[ZwaveNode | Endpoint, Any], None, None]:
"""Return valid responses from a list of results.""" """Return valid responses from a list of results."""
for zwave_object, result in zip(zwave_objects, results): for zwave_object, result in zip(zwave_objects, results):
@ -81,8 +81,8 @@ def get_valid_responses_from_results(
def raise_exceptions_from_results( def raise_exceptions_from_results(
zwave_objects: set[ZwaveNode | Endpoint] | tuple[ZwaveNode | str, ...], zwave_objects: Sequence[ZwaveNode | Endpoint],
results: tuple[Any, ...], results: Sequence[Any],
) -> None: ) -> None:
"""Raise list of exceptions from a list of results.""" """Raise list of exceptions from a list of results."""
if errors := [ if errors := [
@ -153,12 +153,20 @@ class ZWaveServices:
first_node = next((node for node in nodes), None) first_node = next((node for node in nodes), None)
if first_node and not all(node.client.driver is not None for node in nodes):
raise vol.Invalid(f"Driver not ready for all nodes: {nodes}")
# If any nodes don't have matching home IDs, we can't run the command because # If any nodes don't have matching home IDs, we can't run the command because
# we can't multicast across multiple networks # we can't multicast across multiple networks
if first_node and any( if (
first_node
and first_node.client.driver # We checked the driver was ready above.
and any(
node.client.driver.controller.home_id node.client.driver.controller.home_id
!= first_node.client.driver.controller.home_id != first_node.client.driver.controller.home_id
for node in nodes for node in nodes
if node.client.driver is not None
)
): ):
raise vol.Invalid( raise vol.Invalid(
"Multicast commands only work on devices in the same network" "Multicast commands only work on devices in the same network"
@ -417,7 +425,8 @@ class ZWaveServices:
), ),
return_exceptions=True, return_exceptions=True,
) )
for node, result in get_valid_responses_from_results(nodes, results): nodes_list = list(nodes)
for node, result in get_valid_responses_from_results(nodes_list, results):
zwave_value = result[0] zwave_value = result[0]
cmd_status = result[1] cmd_status = result[1]
if cmd_status == CommandStatus.ACCEPTED: if cmd_status == CommandStatus.ACCEPTED:
@ -428,7 +437,7 @@ class ZWaveServices:
"%s with value %s. Parameter will be set when the device wakes up" "%s with value %s. Parameter will be set when the device wakes up"
) )
_LOGGER.info(msg, zwave_value, node, new_value) _LOGGER.info(msg, zwave_value, node, new_value)
raise_exceptions_from_results(nodes, results) raise_exceptions_from_results(nodes_list, results)
async def async_bulk_set_partial_config_parameters( async def async_bulk_set_partial_config_parameters(
self, service: ServiceCall self, service: ServiceCall
@ -450,7 +459,8 @@ class ZWaveServices:
return_exceptions=True, return_exceptions=True,
) )
for node, cmd_status in get_valid_responses_from_results(nodes, results): nodes_list = list(nodes)
for node, cmd_status in get_valid_responses_from_results(nodes_list, results):
if cmd_status == CommandStatus.ACCEPTED: if cmd_status == CommandStatus.ACCEPTED:
msg = "Bulk set partials for configuration parameter %s on Node %s" msg = "Bulk set partials for configuration parameter %s on Node %s"
else: else:
@ -461,7 +471,7 @@ class ZWaveServices:
_LOGGER.info(msg, property_, node) _LOGGER.info(msg, property_, node)
raise_exceptions_from_results(nodes, results) raise_exceptions_from_results(nodes_list, results)
async def async_poll_value(self, service: ServiceCall) -> None: async def async_poll_value(self, service: ServiceCall) -> None:
"""Poll value on a node.""" """Poll value on a node."""
@ -477,10 +487,10 @@ class ZWaveServices:
async def async_set_value(self, service: ServiceCall) -> None: async def async_set_value(self, service: ServiceCall) -> None:
"""Set a value on a node.""" """Set a value on a node."""
nodes: set[ZwaveNode] = service.data[const.ATTR_NODES] nodes: set[ZwaveNode] = service.data[const.ATTR_NODES]
command_class = service.data[const.ATTR_COMMAND_CLASS] command_class: CommandClass = service.data[const.ATTR_COMMAND_CLASS]
property_ = service.data[const.ATTR_PROPERTY] property_: int | str = service.data[const.ATTR_PROPERTY]
property_key = service.data.get(const.ATTR_PROPERTY_KEY) property_key: int | str | None = service.data.get(const.ATTR_PROPERTY_KEY)
endpoint = service.data.get(const.ATTR_ENDPOINT) endpoint: int | None = service.data.get(const.ATTR_ENDPOINT)
new_value = service.data[const.ATTR_VALUE] new_value = service.data[const.ATTR_VALUE]
wait_for_result = service.data.get(const.ATTR_WAIT_FOR_RESULT) wait_for_result = service.data.get(const.ATTR_WAIT_FOR_RESULT)
options = service.data.get(const.ATTR_OPTIONS) options = service.data.get(const.ATTR_OPTIONS)
@ -515,17 +525,18 @@ class ZWaveServices:
) )
results = await asyncio.gather(*coros, return_exceptions=True) results = await asyncio.gather(*coros, return_exceptions=True)
nodes_list = list(nodes)
# multiple set_values my fail so we will track the entire list # multiple set_values my fail so we will track the entire list
set_value_failed_nodes_list = [] set_value_failed_nodes_list: list[ZwaveNode | Endpoint] = []
for node, success in get_valid_responses_from_results(nodes, results): for node_, success in get_valid_responses_from_results(nodes_list, results):
if success is False: if success is False:
# If we failed to set a value, add node to SetValueFailed exception list # If we failed to set a value, add node to SetValueFailed exception list
set_value_failed_nodes_list.append(node) set_value_failed_nodes_list.append(node_)
# Add the SetValueFailed exception to the results and the nodes to the node # Add the SetValueFailed exception to the results and the nodes to the node
# list. No-op if there are no SetValueFailed exceptions # list. No-op if there are no SetValueFailed exceptions
raise_exceptions_from_results( raise_exceptions_from_results(
(*nodes, *set_value_failed_nodes_list), (*nodes_list, *set_value_failed_nodes_list),
(*results, *([SET_VALUE_FAILED_EXC] * len(set_value_failed_nodes_list))), (*results, *([SET_VALUE_FAILED_EXC] * len(set_value_failed_nodes_list))),
) )
@ -543,17 +554,17 @@ class ZWaveServices:
await self.async_set_value(service) await self.async_set_value(service)
return return
command_class = service.data[const.ATTR_COMMAND_CLASS] command_class: CommandClass = service.data[const.ATTR_COMMAND_CLASS]
property_ = service.data[const.ATTR_PROPERTY] property_: int | str = service.data[const.ATTR_PROPERTY]
property_key = service.data.get(const.ATTR_PROPERTY_KEY) property_key: int | str | None = service.data.get(const.ATTR_PROPERTY_KEY)
endpoint = service.data.get(const.ATTR_ENDPOINT) endpoint: int | None = service.data.get(const.ATTR_ENDPOINT)
value = ValueDataType(commandClass=command_class, property=property_)
if property_key is not None:
value["propertyKey"] = property_key
if endpoint is not None:
value["endpoint"] = endpoint
value = {
"commandClass": command_class,
"property": property_,
"propertyKey": property_key,
"endpoint": endpoint,
}
new_value = service.data[const.ATTR_VALUE] new_value = service.data[const.ATTR_VALUE]
# If there are no nodes, we can assume there is only one config entry due to # If there are no nodes, we can assume there is only one config entry due to
@ -590,7 +601,7 @@ class ZWaveServices:
success = await async_multicast_set_value( success = await async_multicast_set_value(
client=client, client=client,
new_value=new_value, new_value=new_value,
value_data={k: v for k, v in value.items() if v is not None}, value_data=value,
nodes=None if broadcast else list(nodes), nodes=None if broadcast else list(nodes),
options=options, options=options,
) )
@ -627,8 +638,9 @@ class ZWaveServices:
), ),
return_exceptions=True, return_exceptions=True,
) )
endpoints_list = list(endpoints)
for endpoint, result in get_valid_responses_from_results( for endpoint, result in get_valid_responses_from_results(
endpoints, results endpoints_list, results
): ):
_LOGGER.info( _LOGGER.info(
( (
@ -640,7 +652,7 @@ class ZWaveServices:
endpoint, endpoint,
result, result,
) )
raise_exceptions_from_results(endpoints, results) raise_exceptions_from_results(endpoints_list, results)
# If an endpoint is provided, we assume the user wants to call the CC API on # If an endpoint is provided, we assume the user wants to call the CC API on
# that endpoint for all target nodes # that endpoint for all target nodes