From eb7a521232791f6728079d5583dfad8b651ef1a1 Mon Sep 17 00:00:00 2001 From: Raman Gupta <7243222+raman325@users.noreply.github.com> Date: Wed, 25 May 2022 01:50:25 -0400 Subject: [PATCH] Fix more typing for zwave_js (#72472) * Fix more typing for zwave_js * Revert one change * reduce lines * Fix tests --- .../components/zwave_js/device_action.py | 9 ++- .../zwave_js/device_automation_helpers.py | 2 + .../components/zwave_js/device_trigger.py | 2 +- .../components/zwave_js/diagnostics.py | 55 +++++++++++++------ .../zwave_js/discovery_data_template.py | 6 +- homeassistant/components/zwave_js/helpers.py | 11 +--- homeassistant/components/zwave_js/services.py | 21 ++++--- .../components/zwave_js/triggers/event.py | 34 +++++++----- .../zwave_js/triggers/value_updated.py | 3 +- tests/components/zwave_js/test_diagnostics.py | 11 +++- tests/components/zwave_js/test_helpers.py | 10 ---- 11 files changed, 97 insertions(+), 67 deletions(-) delete mode 100644 tests/components/zwave_js/test_helpers.py diff --git a/homeassistant/components/zwave_js/device_action.py b/homeassistant/components/zwave_js/device_action.py index f001accd196..2494cb8216d 100644 --- a/homeassistant/components/zwave_js/device_action.py +++ b/homeassistant/components/zwave_js/device_action.py @@ -146,7 +146,7 @@ async def async_get_actions( ) -> list[dict[str, Any]]: """List device actions for Z-Wave JS devices.""" registry = entity_registry.async_get(hass) - actions = [] + actions: list[dict] = [] node = async_get_node_from_device_id(hass, device_id) @@ -207,10 +207,13 @@ async def async_get_actions( # If the value has the meterType CC specific value, we can add a reset_meter # action for it if CC_SPECIFIC_METER_TYPE in value.metadata.cc_specific: - meter_endpoints[value.endpoint].setdefault( + endpoint_idx = value.endpoint + if endpoint_idx is None: + endpoint_idx = 0 + meter_endpoints[endpoint_idx].setdefault( CONF_ENTITY_ID, entry.entity_id ) - meter_endpoints[value.endpoint].setdefault(ATTR_METER_TYPE, set()).add( + meter_endpoints[endpoint_idx].setdefault(ATTR_METER_TYPE, set()).add( get_meter_type(value) ) diff --git a/homeassistant/components/zwave_js/device_automation_helpers.py b/homeassistant/components/zwave_js/device_automation_helpers.py index f17ddccf03c..25cce978df1 100644 --- a/homeassistant/components/zwave_js/device_automation_helpers.py +++ b/homeassistant/components/zwave_js/device_automation_helpers.py @@ -44,6 +44,8 @@ def generate_config_parameter_subtype(config_value: ConfigurationValue) -> str: """Generate the config parameter name used in a device automation subtype.""" parameter = str(config_value.property_) if config_value.property_key: + # Property keys for config values are always an int + assert isinstance(config_value.property_key, int) parameter = f"{parameter}[{hex(config_value.property_key)}]" return f"{parameter} ({config_value.property_name})" diff --git a/homeassistant/components/zwave_js/device_trigger.py b/homeassistant/components/zwave_js/device_trigger.py index 75a9647a5ca..ad860089d1d 100644 --- a/homeassistant/components/zwave_js/device_trigger.py +++ b/homeassistant/components/zwave_js/device_trigger.py @@ -254,7 +254,7 @@ async def async_get_triggers( dev_reg = device_registry.async_get(hass) node = async_get_node_from_device_id(hass, device_id, dev_reg) - triggers = [] + triggers: list[dict] = [] base_trigger = { CONF_PLATFORM: "device", CONF_DEVICE_ID: device_id, diff --git a/homeassistant/components/zwave_js/diagnostics.py b/homeassistant/components/zwave_js/diagnostics.py index dfb6661b5c0..4e1abe37b1b 100644 --- a/homeassistant/components/zwave_js/diagnostics.py +++ b/homeassistant/components/zwave_js/diagnostics.py @@ -1,7 +1,8 @@ """Provides diagnostics for Z-Wave JS.""" from __future__ import annotations -from dataclasses import astuple +from copy import deepcopy +from dataclasses import astuple, dataclass from typing import Any from zwave_js_server.client import Client @@ -20,25 +21,43 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession from .const import DATA_CLIENT, DOMAIN from .helpers import ( - ZwaveValueID, get_home_and_node_id_from_device_entry, get_state_key_from_unique_id, get_value_id_from_unique_id, ) + +@dataclass +class ZwaveValueMatcher: + """Class to allow matching a Z-Wave Value.""" + + property_: str | int | None = None + command_class: int | None = None + endpoint: int | None = None + property_key: str | int | None = None + + def __post_init__(self) -> None: + """Post initialization check.""" + if all(val is None for val in astuple(self)): + raise ValueError("At least one of the fields must be set.") + + KEYS_TO_REDACT = {"homeId", "location"} VALUES_TO_REDACT = ( - ZwaveValueID(property_="userCode", command_class=CommandClass.USER_CODE), + ZwaveValueMatcher(property_="userCode", command_class=CommandClass.USER_CODE), ) def redact_value_of_zwave_value(zwave_value: ValueDataType) -> ValueDataType: """Redact value of a Z-Wave value.""" for value_to_redact in VALUES_TO_REDACT: - zwave_value_id = ZwaveValueID( - property_=zwave_value["property"], - command_class=CommandClass(zwave_value["commandClass"]), + command_class = None + if "commandClass" in zwave_value: + command_class = CommandClass(zwave_value["commandClass"]) + zwave_value_id = ZwaveValueMatcher( + property_=zwave_value.get("property"), + command_class=command_class, endpoint=zwave_value.get("endpoint"), property_key=zwave_value.get("propertyKey"), ) @@ -48,19 +67,19 @@ def redact_value_of_zwave_value(zwave_value: ValueDataType) -> ValueDataType: astuple(value_to_redact), astuple(zwave_value_id) ) ): - return {**zwave_value, "value": REDACTED} + redacted_value: ValueDataType = deepcopy(zwave_value) + redacted_value["value"] = REDACTED + return redacted_value return zwave_value def redact_node_state(node_state: NodeDataType) -> NodeDataType: """Redact node state.""" - return { - **node_state, - "values": [ - redact_value_of_zwave_value(zwave_value) - for zwave_value in node_state["values"] - ], - } + redacted_state: NodeDataType = deepcopy(node_state) + redacted_state["values"] = [ + redact_value_of_zwave_value(zwave_value) for zwave_value in node_state["values"] + ] + return redacted_state def get_device_entities( @@ -125,15 +144,17 @@ async def async_get_config_entry_diagnostics( async def async_get_device_diagnostics( hass: HomeAssistant, config_entry: ConfigEntry, device: dr.DeviceEntry -) -> NodeDataType: +) -> dict: """Return diagnostics for a device.""" client: Client = hass.data[DOMAIN][config_entry.entry_id][DATA_CLIENT] identifiers = get_home_and_node_id_from_device_entry(device) node_id = identifiers[1] if identifiers else None - if node_id is None or node_id not in client.driver.controller.nodes: + assert (driver := client.driver) + if node_id is None or node_id not in driver.controller.nodes: raise ValueError(f"Node for device {device.id} can't be found") - node = client.driver.controller.nodes[node_id] + node = driver.controller.nodes[node_id] entities = get_device_entities(hass, node, device) + assert client.version return { "versionInfo": { "driverVersion": client.version.driver_version, diff --git a/homeassistant/components/zwave_js/discovery_data_template.py b/homeassistant/components/zwave_js/discovery_data_template.py index 9dc90a43f3d..512cfafa63a 100644 --- a/homeassistant/components/zwave_js/discovery_data_template.py +++ b/homeassistant/components/zwave_js/discovery_data_template.py @@ -406,7 +406,7 @@ class TiltValueMix: class CoverTiltDataTemplate(BaseDiscoverySchemaDataTemplate, TiltValueMix): """Tilt data template class for Z-Wave Cover entities.""" - def resolve_data(self, value: ZwaveValue) -> dict[str, Any]: + def resolve_data(self, value: ZwaveValue) -> dict[str, ZwaveValue | None]: """Resolve helper class data for a discovered value.""" return {"tilt_value": self._get_value_from_id(value.node, self.tilt_value_id)} @@ -415,7 +415,9 @@ class CoverTiltDataTemplate(BaseDiscoverySchemaDataTemplate, TiltValueMix): return [resolved_data["tilt_value"]] @staticmethod - def current_tilt_value(resolved_data: dict[str, Any]) -> ZwaveValue | None: + def current_tilt_value( + resolved_data: dict[str, ZwaveValue | None] + ) -> ZwaveValue | None: """Get current tilt ZwaveValue from resolved data.""" return resolved_data["tilt_value"] diff --git a/homeassistant/components/zwave_js/helpers.py b/homeassistant/components/zwave_js/helpers.py index 807ae0287eb..c047a3a9903 100644 --- a/homeassistant/components/zwave_js/helpers.py +++ b/homeassistant/components/zwave_js/helpers.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable -from dataclasses import astuple, dataclass +from dataclasses import dataclass import logging from typing import Any, cast @@ -48,16 +48,11 @@ from .const import ( class ZwaveValueID: """Class to represent a value ID.""" - property_: str | int | None = None - command_class: int | None = None + property_: str | int + command_class: int endpoint: int | None = None property_key: str | int | None = None - def __post_init__(self) -> None: - """Post initialization check.""" - if all(val is None for val in astuple(self)): - raise ValueError("At least one of the fields must be set.") - @callback def get_value_id_from_unique_id(unique_id: str) -> str | None: diff --git a/homeassistant/components/zwave_js/services.py b/homeassistant/components/zwave_js/services.py index af8f3c4c4e7..3b56e0a073c 100644 --- a/homeassistant/components/zwave_js/services.py +++ b/homeassistant/components/zwave_js/services.py @@ -400,7 +400,7 @@ class ZWaveServices: async def async_set_config_parameter(self, service: ServiceCall) -> None: """Set a config value on a node.""" - nodes = service.data[const.ATTR_NODES] + nodes: set[ZwaveNode] = service.data[const.ATTR_NODES] property_or_property_name = service.data[const.ATTR_CONFIG_PARAMETER] property_key = service.data.get(const.ATTR_CONFIG_PARAMETER_BITMASK) new_value = service.data[const.ATTR_CONFIG_VALUE] @@ -434,7 +434,7 @@ class ZWaveServices: self, service: ServiceCall ) -> None: """Bulk set multiple partial config values on a node.""" - nodes = service.data[const.ATTR_NODES] + nodes: set[ZwaveNode] = service.data[const.ATTR_NODES] property_ = service.data[const.ATTR_CONFIG_PARAMETER] new_value = service.data[const.ATTR_CONFIG_VALUE] @@ -531,7 +531,7 @@ class ZWaveServices: async def async_multicast_set_value(self, service: ServiceCall) -> None: """Set a value via multicast to multiple nodes.""" - nodes = service.data[const.ATTR_NODES] + nodes: set[ZwaveNode] = service.data[const.ATTR_NODES] broadcast: bool = service.data[const.ATTR_BROADCAST] options = service.data.get(const.ATTR_OPTIONS) @@ -559,13 +559,15 @@ class ZWaveServices: # If there are no nodes, we can assume there is only one config entry due to # schema validation and can use that to get the client, otherwise we can just # get the client from the node. - client: ZwaveClient = None - first_node: ZwaveNode = next((node for node in nodes), None) - if first_node: + client: ZwaveClient + first_node: ZwaveNode + try: + first_node = next(node for node in nodes) client = first_node.client - else: + except StopIteration: entry_id = self._hass.config_entries.async_entries(const.DOMAIN)[0].entry_id client = self._hass.data[const.DOMAIN][entry_id][const.DATA_CLIENT] + assert client.driver first_node = next( node for node in client.driver.controller.nodes.values() @@ -688,6 +690,9 @@ class ZWaveServices: _LOGGER.warning("Skipping entity %s as it has no value ID", entity_id) continue - endpoints.add(node.endpoints[node.values[value_id].endpoint]) + endpoint_idx = node.values[value_id].endpoint + endpoints.add( + node.endpoints[endpoint_idx if endpoint_idx is not None else 0] + ) await _async_invoke_cc_api(endpoints) diff --git a/homeassistant/components/zwave_js/triggers/event.py b/homeassistant/components/zwave_js/triggers/event.py index 83fd7570ab9..94afd1e9117 100644 --- a/homeassistant/components/zwave_js/triggers/event.py +++ b/homeassistant/components/zwave_js/triggers/event.py @@ -39,12 +39,6 @@ from homeassistant.helpers.typing import ConfigType # Platform type should be . PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}" -EVENT_MODEL_MAP = { - "controller": CONTROLLER_EVENT_MODEL_MAP, - "driver": DRIVER_EVENT_MODEL_MAP, - "node": NODE_EVENT_MODEL_MAP, -} - def validate_non_node_event_source(obj: dict) -> dict: """Validate that a trigger for a non node event source has a config entry.""" @@ -58,7 +52,12 @@ def validate_event_name(obj: dict) -> dict: event_source = obj[ATTR_EVENT_SOURCE] event_name = obj[ATTR_EVENT] # the keys to the event source's model map are the event names - vol.In(EVENT_MODEL_MAP[event_source])(event_name) + if event_source == "controller": + vol.In(CONTROLLER_EVENT_MODEL_MAP)(event_name) + elif event_source == "driver": + vol.In(DRIVER_EVENT_MODEL_MAP)(event_name) + else: + vol.In(NODE_EVENT_MODEL_MAP)(event_name) return obj @@ -68,11 +67,16 @@ def validate_event_data(obj: dict) -> dict: if ATTR_EVENT_DATA not in obj: return obj - event_source = obj[ATTR_EVENT_SOURCE] - event_name = obj[ATTR_EVENT] - event_data = obj[ATTR_EVENT_DATA] + event_source: str = obj[ATTR_EVENT_SOURCE] + event_name: str = obj[ATTR_EVENT] + event_data: dict = obj[ATTR_EVENT_DATA] try: - EVENT_MODEL_MAP[event_source][event_name](**event_data) + if event_source == "controller": + CONTROLLER_EVENT_MODEL_MAP[event_name](**event_data) + elif event_source == "driver": + DRIVER_EVENT_MODEL_MAP[event_name](**event_data) + else: + NODE_EVENT_MODEL_MAP[event_name](**event_data) except ValidationError as exc: # Filter out required field errors if keys can be missing, and if there are # still errors, raise an exception @@ -90,7 +94,7 @@ TRIGGER_SCHEMA = vol.All( vol.Optional(ATTR_CONFIG_ENTRY_ID): str, vol.Optional(ATTR_DEVICE_ID): vol.All(cv.ensure_list, [cv.string]), vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, - vol.Required(ATTR_EVENT_SOURCE): vol.In(EVENT_MODEL_MAP), + vol.Required(ATTR_EVENT_SOURCE): vol.In(["controller", "driver", "node"]), vol.Required(ATTR_EVENT): cv.string, vol.Optional(ATTR_EVENT_DATA): dict, vol.Optional(ATTR_PARTIAL_DICT_MATCH, default=False): bool, @@ -200,11 +204,11 @@ async def async_attach_trigger( if not nodes: entry_id = config[ATTR_CONFIG_ENTRY_ID] client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT] + assert client.driver if event_source == "controller": - source = client.driver.controller + unsubs.append(client.driver.controller.on(event_name, async_on_event)) else: - source = client.driver - unsubs.append(source.on(event_name, async_on_event)) + unsubs.append(client.driver.on(event_name, async_on_event)) for node in nodes: driver = node.client.driver diff --git a/homeassistant/components/zwave_js/triggers/value_updated.py b/homeassistant/components/zwave_js/triggers/value_updated.py index 38a19eaa377..ac3aae1efed 100644 --- a/homeassistant/components/zwave_js/triggers/value_updated.py +++ b/homeassistant/components/zwave_js/triggers/value_updated.py @@ -5,7 +5,6 @@ import functools import voluptuous as vol from zwave_js_server.const import CommandClass -from zwave_js_server.event import Event from zwave_js_server.model.node import Node from zwave_js_server.model.value import Value, get_value_id @@ -108,7 +107,7 @@ async def async_attach_trigger( @callback def async_on_value_updated( - value: Value, device: dr.DeviceEntry, event: Event + value: Value, device: dr.DeviceEntry, event: dict ) -> None: """Handle value update.""" event_value: Value = event["value"] diff --git a/tests/components/zwave_js/test_diagnostics.py b/tests/components/zwave_js/test_diagnostics.py index 8d00c9a2f64..3fe3bdfeb89 100644 --- a/tests/components/zwave_js/test_diagnostics.py +++ b/tests/components/zwave_js/test_diagnostics.py @@ -5,7 +5,10 @@ import pytest from zwave_js_server.event import Event from homeassistant.components.diagnostics.const import REDACTED -from homeassistant.components.zwave_js.diagnostics import async_get_device_diagnostics +from homeassistant.components.zwave_js.diagnostics import ( + ZwaveValueMatcher, + async_get_device_diagnostics, +) from homeassistant.components.zwave_js.discovery import async_discover_node_values from homeassistant.components.zwave_js.helpers import get_device_id from homeassistant.helpers.device_registry import async_get @@ -100,3 +103,9 @@ async def test_device_diagnostics_error(hass, integration): ) with pytest.raises(ValueError): await async_get_device_diagnostics(hass, integration, device) + + +async def test_empty_zwave_value_matcher(): + """Test empty ZwaveValueMatcher is invalid.""" + with pytest.raises(ValueError): + ZwaveValueMatcher() diff --git a/tests/components/zwave_js/test_helpers.py b/tests/components/zwave_js/test_helpers.py deleted file mode 100644 index 290c93fa084..00000000000 --- a/tests/components/zwave_js/test_helpers.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Test Z-Wave JS helpers module.""" -import pytest - -from homeassistant.components.zwave_js.helpers import ZwaveValueID - - -async def test_empty_zwave_value_id(): - """Test empty ZwaveValueID is invalid.""" - with pytest.raises(ValueError): - ZwaveValueID()