Fix more typing for zwave_js (#72472)

* Fix more typing for zwave_js

* Revert one change

* reduce lines

* Fix tests
This commit is contained in:
Raman Gupta 2022-05-25 01:50:25 -04:00 committed by GitHub
parent 209f37196e
commit eb7a521232
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 97 additions and 67 deletions

View File

@ -146,7 +146,7 @@ async def async_get_actions(
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""List device actions for Z-Wave JS devices.""" """List device actions for Z-Wave JS devices."""
registry = entity_registry.async_get(hass) registry = entity_registry.async_get(hass)
actions = [] actions: list[dict] = []
node = async_get_node_from_device_id(hass, device_id) node = async_get_node_from_device_id(hass, device_id)
@ -207,10 +207,13 @@ async def async_get_actions(
# If the value has the meterType CC specific value, we can add a reset_meter # If the value has the meterType CC specific value, we can add a reset_meter
# action for it # action for it
if CC_SPECIFIC_METER_TYPE in value.metadata.cc_specific: if CC_SPECIFIC_METER_TYPE in value.metadata.cc_specific:
meter_endpoints[value.endpoint].setdefault( endpoint_idx = value.endpoint
if endpoint_idx is None:
endpoint_idx = 0
meter_endpoints[endpoint_idx].setdefault(
CONF_ENTITY_ID, entry.entity_id CONF_ENTITY_ID, entry.entity_id
) )
meter_endpoints[value.endpoint].setdefault(ATTR_METER_TYPE, set()).add( meter_endpoints[endpoint_idx].setdefault(ATTR_METER_TYPE, set()).add(
get_meter_type(value) get_meter_type(value)
) )

View File

@ -44,6 +44,8 @@ def generate_config_parameter_subtype(config_value: ConfigurationValue) -> str:
"""Generate the config parameter name used in a device automation subtype.""" """Generate the config parameter name used in a device automation subtype."""
parameter = str(config_value.property_) parameter = str(config_value.property_)
if config_value.property_key: if config_value.property_key:
# Property keys for config values are always an int
assert isinstance(config_value.property_key, int)
parameter = f"{parameter}[{hex(config_value.property_key)}]" parameter = f"{parameter}[{hex(config_value.property_key)}]"
return f"{parameter} ({config_value.property_name})" return f"{parameter} ({config_value.property_name})"

View File

@ -254,7 +254,7 @@ async def async_get_triggers(
dev_reg = device_registry.async_get(hass) dev_reg = device_registry.async_get(hass)
node = async_get_node_from_device_id(hass, device_id, dev_reg) node = async_get_node_from_device_id(hass, device_id, dev_reg)
triggers = [] triggers: list[dict] = []
base_trigger = { base_trigger = {
CONF_PLATFORM: "device", CONF_PLATFORM: "device",
CONF_DEVICE_ID: device_id, CONF_DEVICE_ID: device_id,

View File

@ -1,7 +1,8 @@
"""Provides diagnostics for Z-Wave JS.""" """Provides diagnostics for Z-Wave JS."""
from __future__ import annotations from __future__ import annotations
from dataclasses import astuple from copy import deepcopy
from dataclasses import astuple, dataclass
from typing import Any from typing import Any
from zwave_js_server.client import Client from zwave_js_server.client import Client
@ -20,25 +21,43 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import DATA_CLIENT, DOMAIN from .const import DATA_CLIENT, DOMAIN
from .helpers import ( from .helpers import (
ZwaveValueID,
get_home_and_node_id_from_device_entry, get_home_and_node_id_from_device_entry,
get_state_key_from_unique_id, get_state_key_from_unique_id,
get_value_id_from_unique_id, get_value_id_from_unique_id,
) )
@dataclass
class ZwaveValueMatcher:
"""Class to allow matching a Z-Wave Value."""
property_: str | int | None = None
command_class: int | None = None
endpoint: int | None = None
property_key: str | int | None = None
def __post_init__(self) -> None:
"""Post initialization check."""
if all(val is None for val in astuple(self)):
raise ValueError("At least one of the fields must be set.")
KEYS_TO_REDACT = {"homeId", "location"} KEYS_TO_REDACT = {"homeId", "location"}
VALUES_TO_REDACT = ( VALUES_TO_REDACT = (
ZwaveValueID(property_="userCode", command_class=CommandClass.USER_CODE), ZwaveValueMatcher(property_="userCode", command_class=CommandClass.USER_CODE),
) )
def redact_value_of_zwave_value(zwave_value: ValueDataType) -> ValueDataType: def redact_value_of_zwave_value(zwave_value: ValueDataType) -> ValueDataType:
"""Redact value of a Z-Wave value.""" """Redact value of a Z-Wave value."""
for value_to_redact in VALUES_TO_REDACT: for value_to_redact in VALUES_TO_REDACT:
zwave_value_id = ZwaveValueID( command_class = None
property_=zwave_value["property"], if "commandClass" in zwave_value:
command_class=CommandClass(zwave_value["commandClass"]), command_class = CommandClass(zwave_value["commandClass"])
zwave_value_id = ZwaveValueMatcher(
property_=zwave_value.get("property"),
command_class=command_class,
endpoint=zwave_value.get("endpoint"), endpoint=zwave_value.get("endpoint"),
property_key=zwave_value.get("propertyKey"), property_key=zwave_value.get("propertyKey"),
) )
@ -48,19 +67,19 @@ def redact_value_of_zwave_value(zwave_value: ValueDataType) -> ValueDataType:
astuple(value_to_redact), astuple(zwave_value_id) astuple(value_to_redact), astuple(zwave_value_id)
) )
): ):
return {**zwave_value, "value": REDACTED} redacted_value: ValueDataType = deepcopy(zwave_value)
redacted_value["value"] = REDACTED
return redacted_value
return zwave_value return zwave_value
def redact_node_state(node_state: NodeDataType) -> NodeDataType: def redact_node_state(node_state: NodeDataType) -> NodeDataType:
"""Redact node state.""" """Redact node state."""
return { redacted_state: NodeDataType = deepcopy(node_state)
**node_state, redacted_state["values"] = [
"values": [ redact_value_of_zwave_value(zwave_value) for zwave_value in node_state["values"]
redact_value_of_zwave_value(zwave_value) ]
for zwave_value in node_state["values"] return redacted_state
],
}
def get_device_entities( def get_device_entities(
@ -125,15 +144,17 @@ async def async_get_config_entry_diagnostics(
async def async_get_device_diagnostics( async def async_get_device_diagnostics(
hass: HomeAssistant, config_entry: ConfigEntry, device: dr.DeviceEntry hass: HomeAssistant, config_entry: ConfigEntry, device: dr.DeviceEntry
) -> NodeDataType: ) -> dict:
"""Return diagnostics for a device.""" """Return diagnostics for a device."""
client: Client = hass.data[DOMAIN][config_entry.entry_id][DATA_CLIENT] client: Client = hass.data[DOMAIN][config_entry.entry_id][DATA_CLIENT]
identifiers = get_home_and_node_id_from_device_entry(device) identifiers = get_home_and_node_id_from_device_entry(device)
node_id = identifiers[1] if identifiers else None node_id = identifiers[1] if identifiers else None
if node_id is None or node_id not in client.driver.controller.nodes: assert (driver := client.driver)
if node_id is None or node_id not in driver.controller.nodes:
raise ValueError(f"Node for device {device.id} can't be found") raise ValueError(f"Node for device {device.id} can't be found")
node = client.driver.controller.nodes[node_id] node = driver.controller.nodes[node_id]
entities = get_device_entities(hass, node, device) entities = get_device_entities(hass, node, device)
assert client.version
return { return {
"versionInfo": { "versionInfo": {
"driverVersion": client.version.driver_version, "driverVersion": client.version.driver_version,

View File

@ -406,7 +406,7 @@ class TiltValueMix:
class CoverTiltDataTemplate(BaseDiscoverySchemaDataTemplate, TiltValueMix): class CoverTiltDataTemplate(BaseDiscoverySchemaDataTemplate, TiltValueMix):
"""Tilt data template class for Z-Wave Cover entities.""" """Tilt data template class for Z-Wave Cover entities."""
def resolve_data(self, value: ZwaveValue) -> dict[str, Any]: def resolve_data(self, value: ZwaveValue) -> dict[str, ZwaveValue | None]:
"""Resolve helper class data for a discovered value.""" """Resolve helper class data for a discovered value."""
return {"tilt_value": self._get_value_from_id(value.node, self.tilt_value_id)} return {"tilt_value": self._get_value_from_id(value.node, self.tilt_value_id)}
@ -415,7 +415,9 @@ class CoverTiltDataTemplate(BaseDiscoverySchemaDataTemplate, TiltValueMix):
return [resolved_data["tilt_value"]] return [resolved_data["tilt_value"]]
@staticmethod @staticmethod
def current_tilt_value(resolved_data: dict[str, Any]) -> ZwaveValue | None: def current_tilt_value(
resolved_data: dict[str, ZwaveValue | None]
) -> ZwaveValue | None:
"""Get current tilt ZwaveValue from resolved data.""" """Get current tilt ZwaveValue from resolved data."""
return resolved_data["tilt_value"] return resolved_data["tilt_value"]

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from dataclasses import astuple, dataclass from dataclasses import dataclass
import logging import logging
from typing import Any, cast from typing import Any, cast
@ -48,16 +48,11 @@ from .const import (
class ZwaveValueID: class ZwaveValueID:
"""Class to represent a value ID.""" """Class to represent a value ID."""
property_: str | int | None = None property_: str | int
command_class: int | None = None command_class: int
endpoint: int | None = None endpoint: int | None = None
property_key: str | int | None = None property_key: str | int | None = None
def __post_init__(self) -> None:
"""Post initialization check."""
if all(val is None for val in astuple(self)):
raise ValueError("At least one of the fields must be set.")
@callback @callback
def get_value_id_from_unique_id(unique_id: str) -> str | None: def get_value_id_from_unique_id(unique_id: str) -> str | None:

View File

@ -400,7 +400,7 @@ class ZWaveServices:
async def async_set_config_parameter(self, service: ServiceCall) -> None: async def async_set_config_parameter(self, service: ServiceCall) -> None:
"""Set a config value on a node.""" """Set a config value on a node."""
nodes = service.data[const.ATTR_NODES] nodes: set[ZwaveNode] = service.data[const.ATTR_NODES]
property_or_property_name = service.data[const.ATTR_CONFIG_PARAMETER] property_or_property_name = service.data[const.ATTR_CONFIG_PARAMETER]
property_key = service.data.get(const.ATTR_CONFIG_PARAMETER_BITMASK) property_key = service.data.get(const.ATTR_CONFIG_PARAMETER_BITMASK)
new_value = service.data[const.ATTR_CONFIG_VALUE] new_value = service.data[const.ATTR_CONFIG_VALUE]
@ -434,7 +434,7 @@ class ZWaveServices:
self, service: ServiceCall self, service: ServiceCall
) -> None: ) -> None:
"""Bulk set multiple partial config values on a node.""" """Bulk set multiple partial config values on a node."""
nodes = service.data[const.ATTR_NODES] nodes: set[ZwaveNode] = service.data[const.ATTR_NODES]
property_ = service.data[const.ATTR_CONFIG_PARAMETER] property_ = service.data[const.ATTR_CONFIG_PARAMETER]
new_value = service.data[const.ATTR_CONFIG_VALUE] new_value = service.data[const.ATTR_CONFIG_VALUE]
@ -531,7 +531,7 @@ class ZWaveServices:
async def async_multicast_set_value(self, service: ServiceCall) -> None: async def async_multicast_set_value(self, service: ServiceCall) -> None:
"""Set a value via multicast to multiple nodes.""" """Set a value via multicast to multiple nodes."""
nodes = service.data[const.ATTR_NODES] nodes: set[ZwaveNode] = service.data[const.ATTR_NODES]
broadcast: bool = service.data[const.ATTR_BROADCAST] broadcast: bool = service.data[const.ATTR_BROADCAST]
options = service.data.get(const.ATTR_OPTIONS) options = service.data.get(const.ATTR_OPTIONS)
@ -559,13 +559,15 @@ class ZWaveServices:
# If there are no nodes, we can assume there is only one config entry due to # If there are no nodes, we can assume there is only one config entry due to
# schema validation and can use that to get the client, otherwise we can just # schema validation and can use that to get the client, otherwise we can just
# get the client from the node. # get the client from the node.
client: ZwaveClient = None client: ZwaveClient
first_node: ZwaveNode = next((node for node in nodes), None) first_node: ZwaveNode
if first_node: try:
first_node = next(node for node in nodes)
client = first_node.client client = first_node.client
else: except StopIteration:
entry_id = self._hass.config_entries.async_entries(const.DOMAIN)[0].entry_id entry_id = self._hass.config_entries.async_entries(const.DOMAIN)[0].entry_id
client = self._hass.data[const.DOMAIN][entry_id][const.DATA_CLIENT] client = self._hass.data[const.DOMAIN][entry_id][const.DATA_CLIENT]
assert client.driver
first_node = next( first_node = next(
node node
for node in client.driver.controller.nodes.values() for node in client.driver.controller.nodes.values()
@ -688,6 +690,9 @@ class ZWaveServices:
_LOGGER.warning("Skipping entity %s as it has no value ID", entity_id) _LOGGER.warning("Skipping entity %s as it has no value ID", entity_id)
continue continue
endpoints.add(node.endpoints[node.values[value_id].endpoint]) endpoint_idx = node.values[value_id].endpoint
endpoints.add(
node.endpoints[endpoint_idx if endpoint_idx is not None else 0]
)
await _async_invoke_cc_api(endpoints) await _async_invoke_cc_api(endpoints)

View File

@ -39,12 +39,6 @@ from homeassistant.helpers.typing import ConfigType
# Platform type should be <DOMAIN>.<SUBMODULE_NAME> # Platform type should be <DOMAIN>.<SUBMODULE_NAME>
PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}" PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}"
EVENT_MODEL_MAP = {
"controller": CONTROLLER_EVENT_MODEL_MAP,
"driver": DRIVER_EVENT_MODEL_MAP,
"node": NODE_EVENT_MODEL_MAP,
}
def validate_non_node_event_source(obj: dict) -> dict: def validate_non_node_event_source(obj: dict) -> dict:
"""Validate that a trigger for a non node event source has a config entry.""" """Validate that a trigger for a non node event source has a config entry."""
@ -58,7 +52,12 @@ def validate_event_name(obj: dict) -> dict:
event_source = obj[ATTR_EVENT_SOURCE] event_source = obj[ATTR_EVENT_SOURCE]
event_name = obj[ATTR_EVENT] event_name = obj[ATTR_EVENT]
# the keys to the event source's model map are the event names # the keys to the event source's model map are the event names
vol.In(EVENT_MODEL_MAP[event_source])(event_name) if event_source == "controller":
vol.In(CONTROLLER_EVENT_MODEL_MAP)(event_name)
elif event_source == "driver":
vol.In(DRIVER_EVENT_MODEL_MAP)(event_name)
else:
vol.In(NODE_EVENT_MODEL_MAP)(event_name)
return obj return obj
@ -68,11 +67,16 @@ def validate_event_data(obj: dict) -> dict:
if ATTR_EVENT_DATA not in obj: if ATTR_EVENT_DATA not in obj:
return obj return obj
event_source = obj[ATTR_EVENT_SOURCE] event_source: str = obj[ATTR_EVENT_SOURCE]
event_name = obj[ATTR_EVENT] event_name: str = obj[ATTR_EVENT]
event_data = obj[ATTR_EVENT_DATA] event_data: dict = obj[ATTR_EVENT_DATA]
try: try:
EVENT_MODEL_MAP[event_source][event_name](**event_data) if event_source == "controller":
CONTROLLER_EVENT_MODEL_MAP[event_name](**event_data)
elif event_source == "driver":
DRIVER_EVENT_MODEL_MAP[event_name](**event_data)
else:
NODE_EVENT_MODEL_MAP[event_name](**event_data)
except ValidationError as exc: except ValidationError as exc:
# Filter out required field errors if keys can be missing, and if there are # Filter out required field errors if keys can be missing, and if there are
# still errors, raise an exception # still errors, raise an exception
@ -90,7 +94,7 @@ TRIGGER_SCHEMA = vol.All(
vol.Optional(ATTR_CONFIG_ENTRY_ID): str, vol.Optional(ATTR_CONFIG_ENTRY_ID): str,
vol.Optional(ATTR_DEVICE_ID): vol.All(cv.ensure_list, [cv.string]), vol.Optional(ATTR_DEVICE_ID): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
vol.Required(ATTR_EVENT_SOURCE): vol.In(EVENT_MODEL_MAP), vol.Required(ATTR_EVENT_SOURCE): vol.In(["controller", "driver", "node"]),
vol.Required(ATTR_EVENT): cv.string, vol.Required(ATTR_EVENT): cv.string,
vol.Optional(ATTR_EVENT_DATA): dict, vol.Optional(ATTR_EVENT_DATA): dict,
vol.Optional(ATTR_PARTIAL_DICT_MATCH, default=False): bool, vol.Optional(ATTR_PARTIAL_DICT_MATCH, default=False): bool,
@ -200,11 +204,11 @@ async def async_attach_trigger(
if not nodes: if not nodes:
entry_id = config[ATTR_CONFIG_ENTRY_ID] entry_id = config[ATTR_CONFIG_ENTRY_ID]
client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT] client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT]
assert client.driver
if event_source == "controller": if event_source == "controller":
source = client.driver.controller unsubs.append(client.driver.controller.on(event_name, async_on_event))
else: else:
source = client.driver unsubs.append(client.driver.on(event_name, async_on_event))
unsubs.append(source.on(event_name, async_on_event))
for node in nodes: for node in nodes:
driver = node.client.driver driver = node.client.driver

View File

@ -5,7 +5,6 @@ import functools
import voluptuous as vol import voluptuous as vol
from zwave_js_server.const import CommandClass from zwave_js_server.const import CommandClass
from zwave_js_server.event import Event
from zwave_js_server.model.node import Node from zwave_js_server.model.node import Node
from zwave_js_server.model.value import Value, get_value_id from zwave_js_server.model.value import Value, get_value_id
@ -108,7 +107,7 @@ async def async_attach_trigger(
@callback @callback
def async_on_value_updated( def async_on_value_updated(
value: Value, device: dr.DeviceEntry, event: Event value: Value, device: dr.DeviceEntry, event: dict
) -> None: ) -> None:
"""Handle value update.""" """Handle value update."""
event_value: Value = event["value"] event_value: Value = event["value"]

View File

@ -5,7 +5,10 @@ import pytest
from zwave_js_server.event import Event from zwave_js_server.event import Event
from homeassistant.components.diagnostics.const import REDACTED from homeassistant.components.diagnostics.const import REDACTED
from homeassistant.components.zwave_js.diagnostics import async_get_device_diagnostics from homeassistant.components.zwave_js.diagnostics import (
ZwaveValueMatcher,
async_get_device_diagnostics,
)
from homeassistant.components.zwave_js.discovery import async_discover_node_values from homeassistant.components.zwave_js.discovery import async_discover_node_values
from homeassistant.components.zwave_js.helpers import get_device_id from homeassistant.components.zwave_js.helpers import get_device_id
from homeassistant.helpers.device_registry import async_get from homeassistant.helpers.device_registry import async_get
@ -100,3 +103,9 @@ async def test_device_diagnostics_error(hass, integration):
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
await async_get_device_diagnostics(hass, integration, device) await async_get_device_diagnostics(hass, integration, device)
async def test_empty_zwave_value_matcher():
"""Test empty ZwaveValueMatcher is invalid."""
with pytest.raises(ValueError):
ZwaveValueMatcher()

View File

@ -1,10 +0,0 @@
"""Test Z-Wave JS helpers module."""
import pytest
from homeassistant.components.zwave_js.helpers import ZwaveValueID
async def test_empty_zwave_value_id():
"""Test empty ZwaveValueID is invalid."""
with pytest.raises(ValueError):
ZwaveValueID()