From 10f0509ca3a2b1fd186bf7a801883b427ed40fc4 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Wed, 25 May 2022 18:39:42 +0200 Subject: [PATCH] Clean zwave_js services typing (#72485) Fix services --- homeassistant/components/zwave_js/services.py | 80 +++++++++++-------- 1 file changed, 46 insertions(+), 34 deletions(-) diff --git a/homeassistant/components/zwave_js/services.py b/homeassistant/components/zwave_js/services.py index 3b56e0a073c..d60532fcf75 100644 --- a/homeassistant/components/zwave_js/services.py +++ b/homeassistant/components/zwave_js/services.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import Generator +from collections.abc import Generator, Sequence import logging from typing import Any @@ -12,7 +12,7 @@ from zwave_js_server.const import CommandClass, CommandStatus from zwave_js_server.exceptions import SetValueFailed from zwave_js_server.model.endpoint import Endpoint from zwave_js_server.model.node import Node as ZwaveNode -from zwave_js_server.model.value import get_value_id +from zwave_js_server.model.value import ValueDataType, get_value_id from zwave_js_server.util.multicast import async_multicast_set_value from zwave_js_server.util.node import ( async_bulk_set_partial_config_parameters, @@ -72,7 +72,7 @@ def broadcast_command(val: dict[str, Any]) -> dict[str, Any]: def get_valid_responses_from_results( - zwave_objects: set[ZwaveNode | Endpoint], results: tuple[Any, ...] + zwave_objects: Sequence[ZwaveNode | Endpoint], results: Sequence[Any] ) -> Generator[tuple[ZwaveNode | Endpoint, Any], None, None]: """Return valid responses from a list of results.""" for zwave_object, result in zip(zwave_objects, results): @@ -81,8 +81,8 @@ def get_valid_responses_from_results( def raise_exceptions_from_results( - zwave_objects: set[ZwaveNode | Endpoint] | tuple[ZwaveNode | str, ...], - results: tuple[Any, ...], + zwave_objects: Sequence[ZwaveNode | Endpoint], + results: Sequence[Any], ) -> None: """Raise list of exceptions from a list of results.""" if errors := [ @@ -153,12 +153,20 @@ class ZWaveServices: first_node = next((node for node in nodes), None) + if first_node and not all(node.client.driver is not None for node in nodes): + raise vol.Invalid(f"Driver not ready for all nodes: {nodes}") + # If any nodes don't have matching home IDs, we can't run the command because # we can't multicast across multiple networks - if first_node and any( - node.client.driver.controller.home_id - != first_node.client.driver.controller.home_id - for node in nodes + if ( + first_node + and first_node.client.driver # We checked the driver was ready above. + and any( + node.client.driver.controller.home_id + != first_node.client.driver.controller.home_id + for node in nodes + if node.client.driver is not None + ) ): raise vol.Invalid( "Multicast commands only work on devices in the same network" @@ -417,7 +425,8 @@ class ZWaveServices: ), return_exceptions=True, ) - for node, result in get_valid_responses_from_results(nodes, results): + nodes_list = list(nodes) + for node, result in get_valid_responses_from_results(nodes_list, results): zwave_value = result[0] cmd_status = result[1] if cmd_status == CommandStatus.ACCEPTED: @@ -428,7 +437,7 @@ class ZWaveServices: "%s with value %s. Parameter will be set when the device wakes up" ) _LOGGER.info(msg, zwave_value, node, new_value) - raise_exceptions_from_results(nodes, results) + raise_exceptions_from_results(nodes_list, results) async def async_bulk_set_partial_config_parameters( self, service: ServiceCall @@ -450,7 +459,8 @@ class ZWaveServices: return_exceptions=True, ) - for node, cmd_status in get_valid_responses_from_results(nodes, results): + nodes_list = list(nodes) + for node, cmd_status in get_valid_responses_from_results(nodes_list, results): if cmd_status == CommandStatus.ACCEPTED: msg = "Bulk set partials for configuration parameter %s on Node %s" else: @@ -461,7 +471,7 @@ class ZWaveServices: _LOGGER.info(msg, property_, node) - raise_exceptions_from_results(nodes, results) + raise_exceptions_from_results(nodes_list, results) async def async_poll_value(self, service: ServiceCall) -> None: """Poll value on a node.""" @@ -477,10 +487,10 @@ class ZWaveServices: async def async_set_value(self, service: ServiceCall) -> None: """Set a value on a node.""" nodes: set[ZwaveNode] = service.data[const.ATTR_NODES] - command_class = service.data[const.ATTR_COMMAND_CLASS] - property_ = service.data[const.ATTR_PROPERTY] - property_key = service.data.get(const.ATTR_PROPERTY_KEY) - endpoint = service.data.get(const.ATTR_ENDPOINT) + command_class: CommandClass = service.data[const.ATTR_COMMAND_CLASS] + property_: int | str = service.data[const.ATTR_PROPERTY] + property_key: int | str | None = service.data.get(const.ATTR_PROPERTY_KEY) + endpoint: int | None = service.data.get(const.ATTR_ENDPOINT) new_value = service.data[const.ATTR_VALUE] wait_for_result = service.data.get(const.ATTR_WAIT_FOR_RESULT) options = service.data.get(const.ATTR_OPTIONS) @@ -515,17 +525,18 @@ class ZWaveServices: ) results = await asyncio.gather(*coros, return_exceptions=True) + nodes_list = list(nodes) # multiple set_values my fail so we will track the entire list - set_value_failed_nodes_list = [] - for node, success in get_valid_responses_from_results(nodes, results): + set_value_failed_nodes_list: list[ZwaveNode | Endpoint] = [] + for node_, success in get_valid_responses_from_results(nodes_list, results): if success is False: # If we failed to set a value, add node to SetValueFailed exception list - set_value_failed_nodes_list.append(node) + set_value_failed_nodes_list.append(node_) # Add the SetValueFailed exception to the results and the nodes to the node # list. No-op if there are no SetValueFailed exceptions raise_exceptions_from_results( - (*nodes, *set_value_failed_nodes_list), + (*nodes_list, *set_value_failed_nodes_list), (*results, *([SET_VALUE_FAILED_EXC] * len(set_value_failed_nodes_list))), ) @@ -543,17 +554,17 @@ class ZWaveServices: await self.async_set_value(service) return - command_class = service.data[const.ATTR_COMMAND_CLASS] - property_ = service.data[const.ATTR_PROPERTY] - property_key = service.data.get(const.ATTR_PROPERTY_KEY) - endpoint = service.data.get(const.ATTR_ENDPOINT) + command_class: CommandClass = service.data[const.ATTR_COMMAND_CLASS] + property_: int | str = service.data[const.ATTR_PROPERTY] + property_key: int | str | None = service.data.get(const.ATTR_PROPERTY_KEY) + endpoint: int | None = service.data.get(const.ATTR_ENDPOINT) + + value = ValueDataType(commandClass=command_class, property=property_) + if property_key is not None: + value["propertyKey"] = property_key + if endpoint is not None: + value["endpoint"] = endpoint - value = { - "commandClass": command_class, - "property": property_, - "propertyKey": property_key, - "endpoint": endpoint, - } new_value = service.data[const.ATTR_VALUE] # If there are no nodes, we can assume there is only one config entry due to @@ -590,7 +601,7 @@ class ZWaveServices: success = await async_multicast_set_value( client=client, new_value=new_value, - value_data={k: v for k, v in value.items() if v is not None}, + value_data=value, nodes=None if broadcast else list(nodes), options=options, ) @@ -627,8 +638,9 @@ class ZWaveServices: ), return_exceptions=True, ) + endpoints_list = list(endpoints) for endpoint, result in get_valid_responses_from_results( - endpoints, results + endpoints_list, results ): _LOGGER.info( ( @@ -640,7 +652,7 @@ class ZWaveServices: endpoint, result, ) - raise_exceptions_from_results(endpoints, results) + raise_exceptions_from_results(endpoints_list, results) # If an endpoint is provided, we assume the user wants to call the CC API on # that endpoint for all target nodes