Fix more typing for zwave_js (#72472)

* Fix more typing for zwave_js

* Revert one change

* reduce lines

* Fix tests
This commit is contained in:
Raman Gupta 2022-05-25 01:50:25 -04:00 committed by GitHub
parent 209f37196e
commit eb7a521232
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 97 additions and 67 deletions

View File

@ -146,7 +146,7 @@ async def async_get_actions(
) -> list[dict[str, Any]]:
"""List device actions for Z-Wave JS devices."""
registry = entity_registry.async_get(hass)
actions = []
actions: list[dict] = []
node = async_get_node_from_device_id(hass, device_id)
@ -207,10 +207,13 @@ async def async_get_actions(
# If the value has the meterType CC specific value, we can add a reset_meter
# action for it
if CC_SPECIFIC_METER_TYPE in value.metadata.cc_specific:
meter_endpoints[value.endpoint].setdefault(
endpoint_idx = value.endpoint
if endpoint_idx is None:
endpoint_idx = 0
meter_endpoints[endpoint_idx].setdefault(
CONF_ENTITY_ID, entry.entity_id
)
meter_endpoints[value.endpoint].setdefault(ATTR_METER_TYPE, set()).add(
meter_endpoints[endpoint_idx].setdefault(ATTR_METER_TYPE, set()).add(
get_meter_type(value)
)

View File

@ -44,6 +44,8 @@ def generate_config_parameter_subtype(config_value: ConfigurationValue) -> str:
"""Generate the config parameter name used in a device automation subtype."""
parameter = str(config_value.property_)
if config_value.property_key:
# Property keys for config values are always an int
assert isinstance(config_value.property_key, int)
parameter = f"{parameter}[{hex(config_value.property_key)}]"
return f"{parameter} ({config_value.property_name})"

View File

@ -254,7 +254,7 @@ async def async_get_triggers(
dev_reg = device_registry.async_get(hass)
node = async_get_node_from_device_id(hass, device_id, dev_reg)
triggers = []
triggers: list[dict] = []
base_trigger = {
CONF_PLATFORM: "device",
CONF_DEVICE_ID: device_id,

View File

@ -1,7 +1,8 @@
"""Provides diagnostics for Z-Wave JS."""
from __future__ import annotations
from dataclasses import astuple
from copy import deepcopy
from dataclasses import astuple, dataclass
from typing import Any
from zwave_js_server.client import Client
@ -20,25 +21,43 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import DATA_CLIENT, DOMAIN
from .helpers import (
ZwaveValueID,
get_home_and_node_id_from_device_entry,
get_state_key_from_unique_id,
get_value_id_from_unique_id,
)
@dataclass
class ZwaveValueMatcher:
"""Class to allow matching a Z-Wave Value."""
property_: str | int | None = None
command_class: int | None = None
endpoint: int | None = None
property_key: str | int | None = None
def __post_init__(self) -> None:
"""Post initialization check."""
if all(val is None for val in astuple(self)):
raise ValueError("At least one of the fields must be set.")
KEYS_TO_REDACT = {"homeId", "location"}
VALUES_TO_REDACT = (
ZwaveValueID(property_="userCode", command_class=CommandClass.USER_CODE),
ZwaveValueMatcher(property_="userCode", command_class=CommandClass.USER_CODE),
)
def redact_value_of_zwave_value(zwave_value: ValueDataType) -> ValueDataType:
"""Redact value of a Z-Wave value."""
for value_to_redact in VALUES_TO_REDACT:
zwave_value_id = ZwaveValueID(
property_=zwave_value["property"],
command_class=CommandClass(zwave_value["commandClass"]),
command_class = None
if "commandClass" in zwave_value:
command_class = CommandClass(zwave_value["commandClass"])
zwave_value_id = ZwaveValueMatcher(
property_=zwave_value.get("property"),
command_class=command_class,
endpoint=zwave_value.get("endpoint"),
property_key=zwave_value.get("propertyKey"),
)
@ -48,19 +67,19 @@ def redact_value_of_zwave_value(zwave_value: ValueDataType) -> ValueDataType:
astuple(value_to_redact), astuple(zwave_value_id)
)
):
return {**zwave_value, "value": REDACTED}
redacted_value: ValueDataType = deepcopy(zwave_value)
redacted_value["value"] = REDACTED
return redacted_value
return zwave_value
def redact_node_state(node_state: NodeDataType) -> NodeDataType:
"""Redact node state."""
return {
**node_state,
"values": [
redact_value_of_zwave_value(zwave_value)
for zwave_value in node_state["values"]
],
}
redacted_state: NodeDataType = deepcopy(node_state)
redacted_state["values"] = [
redact_value_of_zwave_value(zwave_value) for zwave_value in node_state["values"]
]
return redacted_state
def get_device_entities(
@ -125,15 +144,17 @@ async def async_get_config_entry_diagnostics(
async def async_get_device_diagnostics(
hass: HomeAssistant, config_entry: ConfigEntry, device: dr.DeviceEntry
) -> NodeDataType:
) -> dict:
"""Return diagnostics for a device."""
client: Client = hass.data[DOMAIN][config_entry.entry_id][DATA_CLIENT]
identifiers = get_home_and_node_id_from_device_entry(device)
node_id = identifiers[1] if identifiers else None
if node_id is None or node_id not in client.driver.controller.nodes:
assert (driver := client.driver)
if node_id is None or node_id not in driver.controller.nodes:
raise ValueError(f"Node for device {device.id} can't be found")
node = client.driver.controller.nodes[node_id]
node = driver.controller.nodes[node_id]
entities = get_device_entities(hass, node, device)
assert client.version
return {
"versionInfo": {
"driverVersion": client.version.driver_version,

View File

@ -406,7 +406,7 @@ class TiltValueMix:
class CoverTiltDataTemplate(BaseDiscoverySchemaDataTemplate, TiltValueMix):
"""Tilt data template class for Z-Wave Cover entities."""
def resolve_data(self, value: ZwaveValue) -> dict[str, Any]:
def resolve_data(self, value: ZwaveValue) -> dict[str, ZwaveValue | None]:
"""Resolve helper class data for a discovered value."""
return {"tilt_value": self._get_value_from_id(value.node, self.tilt_value_id)}
@ -415,7 +415,9 @@ class CoverTiltDataTemplate(BaseDiscoverySchemaDataTemplate, TiltValueMix):
return [resolved_data["tilt_value"]]
@staticmethod
def current_tilt_value(resolved_data: dict[str, Any]) -> ZwaveValue | None:
def current_tilt_value(
resolved_data: dict[str, ZwaveValue | None]
) -> ZwaveValue | None:
"""Get current tilt ZwaveValue from resolved data."""
return resolved_data["tilt_value"]

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import Callable
from dataclasses import astuple, dataclass
from dataclasses import dataclass
import logging
from typing import Any, cast
@ -48,16 +48,11 @@ from .const import (
class ZwaveValueID:
"""Class to represent a value ID."""
property_: str | int | None = None
command_class: int | None = None
property_: str | int
command_class: int
endpoint: int | None = None
property_key: str | int | None = None
def __post_init__(self) -> None:
"""Post initialization check."""
if all(val is None for val in astuple(self)):
raise ValueError("At least one of the fields must be set.")
@callback
def get_value_id_from_unique_id(unique_id: str) -> str | None:

View File

@ -400,7 +400,7 @@ class ZWaveServices:
async def async_set_config_parameter(self, service: ServiceCall) -> None:
"""Set a config value on a node."""
nodes = service.data[const.ATTR_NODES]
nodes: set[ZwaveNode] = service.data[const.ATTR_NODES]
property_or_property_name = service.data[const.ATTR_CONFIG_PARAMETER]
property_key = service.data.get(const.ATTR_CONFIG_PARAMETER_BITMASK)
new_value = service.data[const.ATTR_CONFIG_VALUE]
@ -434,7 +434,7 @@ class ZWaveServices:
self, service: ServiceCall
) -> None:
"""Bulk set multiple partial config values on a node."""
nodes = service.data[const.ATTR_NODES]
nodes: set[ZwaveNode] = service.data[const.ATTR_NODES]
property_ = service.data[const.ATTR_CONFIG_PARAMETER]
new_value = service.data[const.ATTR_CONFIG_VALUE]
@ -531,7 +531,7 @@ class ZWaveServices:
async def async_multicast_set_value(self, service: ServiceCall) -> None:
"""Set a value via multicast to multiple nodes."""
nodes = service.data[const.ATTR_NODES]
nodes: set[ZwaveNode] = service.data[const.ATTR_NODES]
broadcast: bool = service.data[const.ATTR_BROADCAST]
options = service.data.get(const.ATTR_OPTIONS)
@ -559,13 +559,15 @@ class ZWaveServices:
# If there are no nodes, we can assume there is only one config entry due to
# schema validation and can use that to get the client, otherwise we can just
# get the client from the node.
client: ZwaveClient = None
first_node: ZwaveNode = next((node for node in nodes), None)
if first_node:
client: ZwaveClient
first_node: ZwaveNode
try:
first_node = next(node for node in nodes)
client = first_node.client
else:
except StopIteration:
entry_id = self._hass.config_entries.async_entries(const.DOMAIN)[0].entry_id
client = self._hass.data[const.DOMAIN][entry_id][const.DATA_CLIENT]
assert client.driver
first_node = next(
node
for node in client.driver.controller.nodes.values()
@ -688,6 +690,9 @@ class ZWaveServices:
_LOGGER.warning("Skipping entity %s as it has no value ID", entity_id)
continue
endpoints.add(node.endpoints[node.values[value_id].endpoint])
endpoint_idx = node.values[value_id].endpoint
endpoints.add(
node.endpoints[endpoint_idx if endpoint_idx is not None else 0]
)
await _async_invoke_cc_api(endpoints)

View File

@ -39,12 +39,6 @@ from homeassistant.helpers.typing import ConfigType
# Platform type should be <DOMAIN>.<SUBMODULE_NAME>
PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}"
EVENT_MODEL_MAP = {
"controller": CONTROLLER_EVENT_MODEL_MAP,
"driver": DRIVER_EVENT_MODEL_MAP,
"node": NODE_EVENT_MODEL_MAP,
}
def validate_non_node_event_source(obj: dict) -> dict:
"""Validate that a trigger for a non node event source has a config entry."""
@ -58,7 +52,12 @@ def validate_event_name(obj: dict) -> dict:
event_source = obj[ATTR_EVENT_SOURCE]
event_name = obj[ATTR_EVENT]
# the keys to the event source's model map are the event names
vol.In(EVENT_MODEL_MAP[event_source])(event_name)
if event_source == "controller":
vol.In(CONTROLLER_EVENT_MODEL_MAP)(event_name)
elif event_source == "driver":
vol.In(DRIVER_EVENT_MODEL_MAP)(event_name)
else:
vol.In(NODE_EVENT_MODEL_MAP)(event_name)
return obj
@ -68,11 +67,16 @@ def validate_event_data(obj: dict) -> dict:
if ATTR_EVENT_DATA not in obj:
return obj
event_source = obj[ATTR_EVENT_SOURCE]
event_name = obj[ATTR_EVENT]
event_data = obj[ATTR_EVENT_DATA]
event_source: str = obj[ATTR_EVENT_SOURCE]
event_name: str = obj[ATTR_EVENT]
event_data: dict = obj[ATTR_EVENT_DATA]
try:
EVENT_MODEL_MAP[event_source][event_name](**event_data)
if event_source == "controller":
CONTROLLER_EVENT_MODEL_MAP[event_name](**event_data)
elif event_source == "driver":
DRIVER_EVENT_MODEL_MAP[event_name](**event_data)
else:
NODE_EVENT_MODEL_MAP[event_name](**event_data)
except ValidationError as exc:
# Filter out required field errors if keys can be missing, and if there are
# still errors, raise an exception
@ -90,7 +94,7 @@ TRIGGER_SCHEMA = vol.All(
vol.Optional(ATTR_CONFIG_ENTRY_ID): str,
vol.Optional(ATTR_DEVICE_ID): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
vol.Required(ATTR_EVENT_SOURCE): vol.In(EVENT_MODEL_MAP),
vol.Required(ATTR_EVENT_SOURCE): vol.In(["controller", "driver", "node"]),
vol.Required(ATTR_EVENT): cv.string,
vol.Optional(ATTR_EVENT_DATA): dict,
vol.Optional(ATTR_PARTIAL_DICT_MATCH, default=False): bool,
@ -200,11 +204,11 @@ async def async_attach_trigger(
if not nodes:
entry_id = config[ATTR_CONFIG_ENTRY_ID]
client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT]
assert client.driver
if event_source == "controller":
source = client.driver.controller
unsubs.append(client.driver.controller.on(event_name, async_on_event))
else:
source = client.driver
unsubs.append(source.on(event_name, async_on_event))
unsubs.append(client.driver.on(event_name, async_on_event))
for node in nodes:
driver = node.client.driver

View File

@ -5,7 +5,6 @@ import functools
import voluptuous as vol
from zwave_js_server.const import CommandClass
from zwave_js_server.event import Event
from zwave_js_server.model.node import Node
from zwave_js_server.model.value import Value, get_value_id
@ -108,7 +107,7 @@ async def async_attach_trigger(
@callback
def async_on_value_updated(
value: Value, device: dr.DeviceEntry, event: Event
value: Value, device: dr.DeviceEntry, event: dict
) -> None:
"""Handle value update."""
event_value: Value = event["value"]

View File

@ -5,7 +5,10 @@ import pytest
from zwave_js_server.event import Event
from homeassistant.components.diagnostics.const import REDACTED
from homeassistant.components.zwave_js.diagnostics import async_get_device_diagnostics
from homeassistant.components.zwave_js.diagnostics import (
ZwaveValueMatcher,
async_get_device_diagnostics,
)
from homeassistant.components.zwave_js.discovery import async_discover_node_values
from homeassistant.components.zwave_js.helpers import get_device_id
from homeassistant.helpers.device_registry import async_get
@ -100,3 +103,9 @@ async def test_device_diagnostics_error(hass, integration):
)
with pytest.raises(ValueError):
await async_get_device_diagnostics(hass, integration, device)
async def test_empty_zwave_value_matcher():
"""Test empty ZwaveValueMatcher is invalid."""
with pytest.raises(ValueError):
ZwaveValueMatcher()

View File

@ -1,10 +0,0 @@
"""Test Z-Wave JS helpers module."""
import pytest
from homeassistant.components.zwave_js.helpers import ZwaveValueID
async def test_empty_zwave_value_id():
"""Test empty ZwaveValueID is invalid."""
with pytest.raises(ValueError):
ZwaveValueID()