From 9cbcf5f2a5ad40656f44d77491236d6229a53266 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 21 May 2024 07:42:07 +0200 Subject: [PATCH] Improve zwave_js TypeVar usage (#117810) * Improve zwave_js TypeVar usage * Use underscore for TypeVar name --- .../zwave_js/discovery_data_template.py | 17 ++++------------ homeassistant/components/zwave_js/services.py | 20 +++++++++---------- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/homeassistant/components/zwave_js/discovery_data_template.py b/homeassistant/components/zwave_js/discovery_data_template.py index 7eb85e0ea4d..e619c6afc7c 100644 --- a/homeassistant/components/zwave_js/discovery_data_template.py +++ b/homeassistant/components/zwave_js/discovery_data_template.py @@ -4,8 +4,9 @@ from __future__ import annotations from collections.abc import Iterable, Mapping from dataclasses import dataclass, field +from enum import Enum import logging -from typing import Any, TypeVar, cast +from typing import Any, cast from zwave_js_server.const import CommandClass from zwave_js_server.const.command_class.energy_production import ( @@ -357,22 +358,12 @@ class NumericSensorDataTemplateData: unit_of_measurement: str | None = None -T = TypeVar( - "T", - MultilevelSensorType, - MultilevelSensorScaleType, - MeterScaleType, - EnergyProductionParameter, - EnergyProductionScaleType, -) - - class NumericSensorDataTemplate(BaseDiscoverySchemaDataTemplate): """Data template class for Z-Wave Sensor entities.""" @staticmethod - def find_key_from_matching_set( - enum_value: T, set_map: Mapping[str, list[T]] + def find_key_from_matching_set[_T: Enum]( + enum_value: _T, set_map: Mapping[str, list[_T]] ) -> str | None: """Find a key in a set map that matches a given enum value.""" for key, value_set in set_map.items(): diff --git a/homeassistant/components/zwave_js/services.py b/homeassistant/components/zwave_js/services.py index a25095156ed..ba78777fa51 100644 --- a/homeassistant/components/zwave_js/services.py +++ b/homeassistant/components/zwave_js/services.py @@ -3,10 +3,10 @@ from __future__ import annotations import asyncio -from collections.abc import Generator, Sequence +from collections.abc import Collection, Generator, Sequence import logging import math -from typing import Any, TypeVar +from typing import Any import voluptuous as vol from zwave_js_server.client import Client as ZwaveClient @@ -46,7 +46,7 @@ from .helpers import ( _LOGGER = logging.getLogger(__name__) -T = TypeVar("T", ZwaveNode, Endpoint) +type _NodeOrEndpointType = ZwaveNode | Endpoint def parameter_name_does_not_need_bitmask( @@ -81,9 +81,9 @@ def broadcast_command(val: dict[str, Any]) -> dict[str, Any]: ) -def get_valid_responses_from_results( - zwave_objects: Sequence[T], results: Sequence[Any] -) -> Generator[tuple[T, Any], None, None]: +def get_valid_responses_from_results[_T: ZwaveNode | Endpoint]( + zwave_objects: Sequence[_T], results: Sequence[Any] +) -> Generator[tuple[_T, Any], None, None]: """Return valid responses from a list of results.""" for zwave_object, result in zip(zwave_objects, results, strict=False): if not isinstance(result, Exception): @@ -91,10 +91,10 @@ def get_valid_responses_from_results( def raise_exceptions_from_results( - zwave_objects: Sequence[T], results: Sequence[Any] + zwave_objects: Sequence[_NodeOrEndpointType], results: Sequence[Any] ) -> None: """Raise list of exceptions from a list of results.""" - errors: Sequence[tuple[T, Any]] + errors: Sequence[tuple[_NodeOrEndpointType, Any]] if errors := [ tup for tup in zip(zwave_objects, results, strict=True) @@ -112,7 +112,7 @@ def raise_exceptions_from_results( async def _async_invoke_cc_api( - nodes_or_endpoints: set[T], + nodes_or_endpoints: Collection[_NodeOrEndpointType], command_class: CommandClass, method_name: str, *args: Any, @@ -561,7 +561,7 @@ class ZWaveServices: ) def process_results( - nodes_or_endpoints_list: list[T], _results: list[Any] + nodes_or_endpoints_list: Sequence[_NodeOrEndpointType], _results: list[Any] ) -> None: """Process results for given nodes or endpoints.""" for node_or_endpoint, result in get_valid_responses_from_results(