Clean zwave_js services typing (#72485)

Fix services
This commit is contained in:
Martin Hjelmare 2022-05-25 18:39:42 +02:00 committed by GitHub
parent f9f87c607e
commit 10f0509ca3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Generator
from collections.abc import Generator, Sequence
import logging
from typing import Any
@ -12,7 +12,7 @@ from zwave_js_server.const import CommandClass, CommandStatus
from zwave_js_server.exceptions import SetValueFailed
from zwave_js_server.model.endpoint import Endpoint
from zwave_js_server.model.node import Node as ZwaveNode
from zwave_js_server.model.value import get_value_id
from zwave_js_server.model.value import ValueDataType, get_value_id
from zwave_js_server.util.multicast import async_multicast_set_value
from zwave_js_server.util.node import (
async_bulk_set_partial_config_parameters,
@ -72,7 +72,7 @@ def broadcast_command(val: dict[str, Any]) -> dict[str, Any]:
def get_valid_responses_from_results(
zwave_objects: set[ZwaveNode | Endpoint], results: tuple[Any, ...]
zwave_objects: Sequence[ZwaveNode | Endpoint], results: Sequence[Any]
) -> Generator[tuple[ZwaveNode | Endpoint, Any], None, None]:
"""Return valid responses from a list of results."""
for zwave_object, result in zip(zwave_objects, results):
@ -81,8 +81,8 @@ def get_valid_responses_from_results(
def raise_exceptions_from_results(
zwave_objects: set[ZwaveNode | Endpoint] | tuple[ZwaveNode | str, ...],
results: tuple[Any, ...],
zwave_objects: Sequence[ZwaveNode | Endpoint],
results: Sequence[Any],
) -> None:
"""Raise list of exceptions from a list of results."""
if errors := [
@ -153,12 +153,20 @@ class ZWaveServices:
first_node = next((node for node in nodes), None)
if first_node and not all(node.client.driver is not None for node in nodes):
raise vol.Invalid(f"Driver not ready for all nodes: {nodes}")
# If any nodes don't have matching home IDs, we can't run the command because
# we can't multicast across multiple networks
if first_node and any(
if (
first_node
and first_node.client.driver # We checked the driver was ready above.
and any(
node.client.driver.controller.home_id
!= first_node.client.driver.controller.home_id
for node in nodes
if node.client.driver is not None
)
):
raise vol.Invalid(
"Multicast commands only work on devices in the same network"
@ -417,7 +425,8 @@ class ZWaveServices:
),
return_exceptions=True,
)
for node, result in get_valid_responses_from_results(nodes, results):
nodes_list = list(nodes)
for node, result in get_valid_responses_from_results(nodes_list, results):
zwave_value = result[0]
cmd_status = result[1]
if cmd_status == CommandStatus.ACCEPTED:
@ -428,7 +437,7 @@ class ZWaveServices:
"%s with value %s. Parameter will be set when the device wakes up"
)
_LOGGER.info(msg, zwave_value, node, new_value)
raise_exceptions_from_results(nodes, results)
raise_exceptions_from_results(nodes_list, results)
async def async_bulk_set_partial_config_parameters(
self, service: ServiceCall
@ -450,7 +459,8 @@ class ZWaveServices:
return_exceptions=True,
)
for node, cmd_status in get_valid_responses_from_results(nodes, results):
nodes_list = list(nodes)
for node, cmd_status in get_valid_responses_from_results(nodes_list, results):
if cmd_status == CommandStatus.ACCEPTED:
msg = "Bulk set partials for configuration parameter %s on Node %s"
else:
@ -461,7 +471,7 @@ class ZWaveServices:
_LOGGER.info(msg, property_, node)
raise_exceptions_from_results(nodes, results)
raise_exceptions_from_results(nodes_list, results)
async def async_poll_value(self, service: ServiceCall) -> None:
"""Poll value on a node."""
@ -477,10 +487,10 @@ class ZWaveServices:
async def async_set_value(self, service: ServiceCall) -> None:
"""Set a value on a node."""
nodes: set[ZwaveNode] = service.data[const.ATTR_NODES]
command_class = service.data[const.ATTR_COMMAND_CLASS]
property_ = service.data[const.ATTR_PROPERTY]
property_key = service.data.get(const.ATTR_PROPERTY_KEY)
endpoint = service.data.get(const.ATTR_ENDPOINT)
command_class: CommandClass = service.data[const.ATTR_COMMAND_CLASS]
property_: int | str = service.data[const.ATTR_PROPERTY]
property_key: int | str | None = service.data.get(const.ATTR_PROPERTY_KEY)
endpoint: int | None = service.data.get(const.ATTR_ENDPOINT)
new_value = service.data[const.ATTR_VALUE]
wait_for_result = service.data.get(const.ATTR_WAIT_FOR_RESULT)
options = service.data.get(const.ATTR_OPTIONS)
@ -515,17 +525,18 @@ class ZWaveServices:
)
results = await asyncio.gather(*coros, return_exceptions=True)
nodes_list = list(nodes)
# multiple set_values my fail so we will track the entire list
set_value_failed_nodes_list = []
for node, success in get_valid_responses_from_results(nodes, results):
set_value_failed_nodes_list: list[ZwaveNode | Endpoint] = []
for node_, success in get_valid_responses_from_results(nodes_list, results):
if success is False:
# If we failed to set a value, add node to SetValueFailed exception list
set_value_failed_nodes_list.append(node)
set_value_failed_nodes_list.append(node_)
# Add the SetValueFailed exception to the results and the nodes to the node
# list. No-op if there are no SetValueFailed exceptions
raise_exceptions_from_results(
(*nodes, *set_value_failed_nodes_list),
(*nodes_list, *set_value_failed_nodes_list),
(*results, *([SET_VALUE_FAILED_EXC] * len(set_value_failed_nodes_list))),
)
@ -543,17 +554,17 @@ class ZWaveServices:
await self.async_set_value(service)
return
command_class = service.data[const.ATTR_COMMAND_CLASS]
property_ = service.data[const.ATTR_PROPERTY]
property_key = service.data.get(const.ATTR_PROPERTY_KEY)
endpoint = service.data.get(const.ATTR_ENDPOINT)
command_class: CommandClass = service.data[const.ATTR_COMMAND_CLASS]
property_: int | str = service.data[const.ATTR_PROPERTY]
property_key: int | str | None = service.data.get(const.ATTR_PROPERTY_KEY)
endpoint: int | None = service.data.get(const.ATTR_ENDPOINT)
value = ValueDataType(commandClass=command_class, property=property_)
if property_key is not None:
value["propertyKey"] = property_key
if endpoint is not None:
value["endpoint"] = endpoint
value = {
"commandClass": command_class,
"property": property_,
"propertyKey": property_key,
"endpoint": endpoint,
}
new_value = service.data[const.ATTR_VALUE]
# If there are no nodes, we can assume there is only one config entry due to
@ -590,7 +601,7 @@ class ZWaveServices:
success = await async_multicast_set_value(
client=client,
new_value=new_value,
value_data={k: v for k, v in value.items() if v is not None},
value_data=value,
nodes=None if broadcast else list(nodes),
options=options,
)
@ -627,8 +638,9 @@ class ZWaveServices:
),
return_exceptions=True,
)
endpoints_list = list(endpoints)
for endpoint, result in get_valid_responses_from_results(
endpoints, results
endpoints_list, results
):
_LOGGER.info(
(
@ -640,7 +652,7 @@ class ZWaveServices:
endpoint,
result,
)
raise_exceptions_from_results(endpoints, results)
raise_exceptions_from_results(endpoints_list, results)
# If an endpoint is provided, we assume the user wants to call the CC API on
# that endpoint for all target nodes