mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Fix more typing for zwave_js (#72472)
* Fix more typing for zwave_js * Revert one change * reduce lines * Fix tests
This commit is contained in:
parent
209f37196e
commit
eb7a521232
@ -146,7 +146,7 @@ async def async_get_actions(
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List device actions for Z-Wave JS devices."""
|
||||
registry = entity_registry.async_get(hass)
|
||||
actions = []
|
||||
actions: list[dict] = []
|
||||
|
||||
node = async_get_node_from_device_id(hass, device_id)
|
||||
|
||||
@ -207,10 +207,13 @@ async def async_get_actions(
|
||||
# If the value has the meterType CC specific value, we can add a reset_meter
|
||||
# action for it
|
||||
if CC_SPECIFIC_METER_TYPE in value.metadata.cc_specific:
|
||||
meter_endpoints[value.endpoint].setdefault(
|
||||
endpoint_idx = value.endpoint
|
||||
if endpoint_idx is None:
|
||||
endpoint_idx = 0
|
||||
meter_endpoints[endpoint_idx].setdefault(
|
||||
CONF_ENTITY_ID, entry.entity_id
|
||||
)
|
||||
meter_endpoints[value.endpoint].setdefault(ATTR_METER_TYPE, set()).add(
|
||||
meter_endpoints[endpoint_idx].setdefault(ATTR_METER_TYPE, set()).add(
|
||||
get_meter_type(value)
|
||||
)
|
||||
|
||||
|
@ -44,6 +44,8 @@ def generate_config_parameter_subtype(config_value: ConfigurationValue) -> str:
|
||||
"""Generate the config parameter name used in a device automation subtype."""
|
||||
parameter = str(config_value.property_)
|
||||
if config_value.property_key:
|
||||
# Property keys for config values are always an int
|
||||
assert isinstance(config_value.property_key, int)
|
||||
parameter = f"{parameter}[{hex(config_value.property_key)}]"
|
||||
|
||||
return f"{parameter} ({config_value.property_name})"
|
||||
|
@ -254,7 +254,7 @@ async def async_get_triggers(
|
||||
dev_reg = device_registry.async_get(hass)
|
||||
node = async_get_node_from_device_id(hass, device_id, dev_reg)
|
||||
|
||||
triggers = []
|
||||
triggers: list[dict] = []
|
||||
base_trigger = {
|
||||
CONF_PLATFORM: "device",
|
||||
CONF_DEVICE_ID: device_id,
|
||||
|
@ -1,7 +1,8 @@
|
||||
"""Provides diagnostics for Z-Wave JS."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import astuple
|
||||
from copy import deepcopy
|
||||
from dataclasses import astuple, dataclass
|
||||
from typing import Any
|
||||
|
||||
from zwave_js_server.client import Client
|
||||
@ -20,25 +21,43 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
|
||||
from .const import DATA_CLIENT, DOMAIN
|
||||
from .helpers import (
|
||||
ZwaveValueID,
|
||||
get_home_and_node_id_from_device_entry,
|
||||
get_state_key_from_unique_id,
|
||||
get_value_id_from_unique_id,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZwaveValueMatcher:
|
||||
"""Class to allow matching a Z-Wave Value."""
|
||||
|
||||
property_: str | int | None = None
|
||||
command_class: int | None = None
|
||||
endpoint: int | None = None
|
||||
property_key: str | int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Post initialization check."""
|
||||
if all(val is None for val in astuple(self)):
|
||||
raise ValueError("At least one of the fields must be set.")
|
||||
|
||||
|
||||
KEYS_TO_REDACT = {"homeId", "location"}
|
||||
|
||||
VALUES_TO_REDACT = (
|
||||
ZwaveValueID(property_="userCode", command_class=CommandClass.USER_CODE),
|
||||
ZwaveValueMatcher(property_="userCode", command_class=CommandClass.USER_CODE),
|
||||
)
|
||||
|
||||
|
||||
def redact_value_of_zwave_value(zwave_value: ValueDataType) -> ValueDataType:
|
||||
"""Redact value of a Z-Wave value."""
|
||||
for value_to_redact in VALUES_TO_REDACT:
|
||||
zwave_value_id = ZwaveValueID(
|
||||
property_=zwave_value["property"],
|
||||
command_class=CommandClass(zwave_value["commandClass"]),
|
||||
command_class = None
|
||||
if "commandClass" in zwave_value:
|
||||
command_class = CommandClass(zwave_value["commandClass"])
|
||||
zwave_value_id = ZwaveValueMatcher(
|
||||
property_=zwave_value.get("property"),
|
||||
command_class=command_class,
|
||||
endpoint=zwave_value.get("endpoint"),
|
||||
property_key=zwave_value.get("propertyKey"),
|
||||
)
|
||||
@ -48,19 +67,19 @@ def redact_value_of_zwave_value(zwave_value: ValueDataType) -> ValueDataType:
|
||||
astuple(value_to_redact), astuple(zwave_value_id)
|
||||
)
|
||||
):
|
||||
return {**zwave_value, "value": REDACTED}
|
||||
redacted_value: ValueDataType = deepcopy(zwave_value)
|
||||
redacted_value["value"] = REDACTED
|
||||
return redacted_value
|
||||
return zwave_value
|
||||
|
||||
|
||||
def redact_node_state(node_state: NodeDataType) -> NodeDataType:
|
||||
"""Redact node state."""
|
||||
return {
|
||||
**node_state,
|
||||
"values": [
|
||||
redact_value_of_zwave_value(zwave_value)
|
||||
for zwave_value in node_state["values"]
|
||||
],
|
||||
}
|
||||
redacted_state: NodeDataType = deepcopy(node_state)
|
||||
redacted_state["values"] = [
|
||||
redact_value_of_zwave_value(zwave_value) for zwave_value in node_state["values"]
|
||||
]
|
||||
return redacted_state
|
||||
|
||||
|
||||
def get_device_entities(
|
||||
@ -125,15 +144,17 @@ async def async_get_config_entry_diagnostics(
|
||||
|
||||
async def async_get_device_diagnostics(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry, device: dr.DeviceEntry
|
||||
) -> NodeDataType:
|
||||
) -> dict:
|
||||
"""Return diagnostics for a device."""
|
||||
client: Client = hass.data[DOMAIN][config_entry.entry_id][DATA_CLIENT]
|
||||
identifiers = get_home_and_node_id_from_device_entry(device)
|
||||
node_id = identifiers[1] if identifiers else None
|
||||
if node_id is None or node_id not in client.driver.controller.nodes:
|
||||
assert (driver := client.driver)
|
||||
if node_id is None or node_id not in driver.controller.nodes:
|
||||
raise ValueError(f"Node for device {device.id} can't be found")
|
||||
node = client.driver.controller.nodes[node_id]
|
||||
node = driver.controller.nodes[node_id]
|
||||
entities = get_device_entities(hass, node, device)
|
||||
assert client.version
|
||||
return {
|
||||
"versionInfo": {
|
||||
"driverVersion": client.version.driver_version,
|
||||
|
@ -406,7 +406,7 @@ class TiltValueMix:
|
||||
class CoverTiltDataTemplate(BaseDiscoverySchemaDataTemplate, TiltValueMix):
|
||||
"""Tilt data template class for Z-Wave Cover entities."""
|
||||
|
||||
def resolve_data(self, value: ZwaveValue) -> dict[str, Any]:
|
||||
def resolve_data(self, value: ZwaveValue) -> dict[str, ZwaveValue | None]:
|
||||
"""Resolve helper class data for a discovered value."""
|
||||
return {"tilt_value": self._get_value_from_id(value.node, self.tilt_value_id)}
|
||||
|
||||
@ -415,7 +415,9 @@ class CoverTiltDataTemplate(BaseDiscoverySchemaDataTemplate, TiltValueMix):
|
||||
return [resolved_data["tilt_value"]]
|
||||
|
||||
@staticmethod
|
||||
def current_tilt_value(resolved_data: dict[str, Any]) -> ZwaveValue | None:
|
||||
def current_tilt_value(
|
||||
resolved_data: dict[str, ZwaveValue | None]
|
||||
) -> ZwaveValue | None:
|
||||
"""Get current tilt ZwaveValue from resolved data."""
|
||||
return resolved_data["tilt_value"]
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import astuple, dataclass
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
@ -48,16 +48,11 @@ from .const import (
|
||||
class ZwaveValueID:
|
||||
"""Class to represent a value ID."""
|
||||
|
||||
property_: str | int | None = None
|
||||
command_class: int | None = None
|
||||
property_: str | int
|
||||
command_class: int
|
||||
endpoint: int | None = None
|
||||
property_key: str | int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Post initialization check."""
|
||||
if all(val is None for val in astuple(self)):
|
||||
raise ValueError("At least one of the fields must be set.")
|
||||
|
||||
|
||||
@callback
|
||||
def get_value_id_from_unique_id(unique_id: str) -> str | None:
|
||||
|
@ -400,7 +400,7 @@ class ZWaveServices:
|
||||
|
||||
async def async_set_config_parameter(self, service: ServiceCall) -> None:
|
||||
"""Set a config value on a node."""
|
||||
nodes = service.data[const.ATTR_NODES]
|
||||
nodes: set[ZwaveNode] = service.data[const.ATTR_NODES]
|
||||
property_or_property_name = service.data[const.ATTR_CONFIG_PARAMETER]
|
||||
property_key = service.data.get(const.ATTR_CONFIG_PARAMETER_BITMASK)
|
||||
new_value = service.data[const.ATTR_CONFIG_VALUE]
|
||||
@ -434,7 +434,7 @@ class ZWaveServices:
|
||||
self, service: ServiceCall
|
||||
) -> None:
|
||||
"""Bulk set multiple partial config values on a node."""
|
||||
nodes = service.data[const.ATTR_NODES]
|
||||
nodes: set[ZwaveNode] = service.data[const.ATTR_NODES]
|
||||
property_ = service.data[const.ATTR_CONFIG_PARAMETER]
|
||||
new_value = service.data[const.ATTR_CONFIG_VALUE]
|
||||
|
||||
@ -531,7 +531,7 @@ class ZWaveServices:
|
||||
|
||||
async def async_multicast_set_value(self, service: ServiceCall) -> None:
|
||||
"""Set a value via multicast to multiple nodes."""
|
||||
nodes = service.data[const.ATTR_NODES]
|
||||
nodes: set[ZwaveNode] = service.data[const.ATTR_NODES]
|
||||
broadcast: bool = service.data[const.ATTR_BROADCAST]
|
||||
options = service.data.get(const.ATTR_OPTIONS)
|
||||
|
||||
@ -559,13 +559,15 @@ class ZWaveServices:
|
||||
# If there are no nodes, we can assume there is only one config entry due to
|
||||
# schema validation and can use that to get the client, otherwise we can just
|
||||
# get the client from the node.
|
||||
client: ZwaveClient = None
|
||||
first_node: ZwaveNode = next((node for node in nodes), None)
|
||||
if first_node:
|
||||
client: ZwaveClient
|
||||
first_node: ZwaveNode
|
||||
try:
|
||||
first_node = next(node for node in nodes)
|
||||
client = first_node.client
|
||||
else:
|
||||
except StopIteration:
|
||||
entry_id = self._hass.config_entries.async_entries(const.DOMAIN)[0].entry_id
|
||||
client = self._hass.data[const.DOMAIN][entry_id][const.DATA_CLIENT]
|
||||
assert client.driver
|
||||
first_node = next(
|
||||
node
|
||||
for node in client.driver.controller.nodes.values()
|
||||
@ -688,6 +690,9 @@ class ZWaveServices:
|
||||
_LOGGER.warning("Skipping entity %s as it has no value ID", entity_id)
|
||||
continue
|
||||
|
||||
endpoints.add(node.endpoints[node.values[value_id].endpoint])
|
||||
endpoint_idx = node.values[value_id].endpoint
|
||||
endpoints.add(
|
||||
node.endpoints[endpoint_idx if endpoint_idx is not None else 0]
|
||||
)
|
||||
|
||||
await _async_invoke_cc_api(endpoints)
|
||||
|
@ -39,12 +39,6 @@ from homeassistant.helpers.typing import ConfigType
|
||||
# Platform type should be <DOMAIN>.<SUBMODULE_NAME>
|
||||
PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}"
|
||||
|
||||
EVENT_MODEL_MAP = {
|
||||
"controller": CONTROLLER_EVENT_MODEL_MAP,
|
||||
"driver": DRIVER_EVENT_MODEL_MAP,
|
||||
"node": NODE_EVENT_MODEL_MAP,
|
||||
}
|
||||
|
||||
|
||||
def validate_non_node_event_source(obj: dict) -> dict:
|
||||
"""Validate that a trigger for a non node event source has a config entry."""
|
||||
@ -58,7 +52,12 @@ def validate_event_name(obj: dict) -> dict:
|
||||
event_source = obj[ATTR_EVENT_SOURCE]
|
||||
event_name = obj[ATTR_EVENT]
|
||||
# the keys to the event source's model map are the event names
|
||||
vol.In(EVENT_MODEL_MAP[event_source])(event_name)
|
||||
if event_source == "controller":
|
||||
vol.In(CONTROLLER_EVENT_MODEL_MAP)(event_name)
|
||||
elif event_source == "driver":
|
||||
vol.In(DRIVER_EVENT_MODEL_MAP)(event_name)
|
||||
else:
|
||||
vol.In(NODE_EVENT_MODEL_MAP)(event_name)
|
||||
return obj
|
||||
|
||||
|
||||
@ -68,11 +67,16 @@ def validate_event_data(obj: dict) -> dict:
|
||||
if ATTR_EVENT_DATA not in obj:
|
||||
return obj
|
||||
|
||||
event_source = obj[ATTR_EVENT_SOURCE]
|
||||
event_name = obj[ATTR_EVENT]
|
||||
event_data = obj[ATTR_EVENT_DATA]
|
||||
event_source: str = obj[ATTR_EVENT_SOURCE]
|
||||
event_name: str = obj[ATTR_EVENT]
|
||||
event_data: dict = obj[ATTR_EVENT_DATA]
|
||||
try:
|
||||
EVENT_MODEL_MAP[event_source][event_name](**event_data)
|
||||
if event_source == "controller":
|
||||
CONTROLLER_EVENT_MODEL_MAP[event_name](**event_data)
|
||||
elif event_source == "driver":
|
||||
DRIVER_EVENT_MODEL_MAP[event_name](**event_data)
|
||||
else:
|
||||
NODE_EVENT_MODEL_MAP[event_name](**event_data)
|
||||
except ValidationError as exc:
|
||||
# Filter out required field errors if keys can be missing, and if there are
|
||||
# still errors, raise an exception
|
||||
@ -90,7 +94,7 @@ TRIGGER_SCHEMA = vol.All(
|
||||
vol.Optional(ATTR_CONFIG_ENTRY_ID): str,
|
||||
vol.Optional(ATTR_DEVICE_ID): vol.All(cv.ensure_list, [cv.string]),
|
||||
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
|
||||
vol.Required(ATTR_EVENT_SOURCE): vol.In(EVENT_MODEL_MAP),
|
||||
vol.Required(ATTR_EVENT_SOURCE): vol.In(["controller", "driver", "node"]),
|
||||
vol.Required(ATTR_EVENT): cv.string,
|
||||
vol.Optional(ATTR_EVENT_DATA): dict,
|
||||
vol.Optional(ATTR_PARTIAL_DICT_MATCH, default=False): bool,
|
||||
@ -200,11 +204,11 @@ async def async_attach_trigger(
|
||||
if not nodes:
|
||||
entry_id = config[ATTR_CONFIG_ENTRY_ID]
|
||||
client: Client = hass.data[DOMAIN][entry_id][DATA_CLIENT]
|
||||
assert client.driver
|
||||
if event_source == "controller":
|
||||
source = client.driver.controller
|
||||
unsubs.append(client.driver.controller.on(event_name, async_on_event))
|
||||
else:
|
||||
source = client.driver
|
||||
unsubs.append(source.on(event_name, async_on_event))
|
||||
unsubs.append(client.driver.on(event_name, async_on_event))
|
||||
|
||||
for node in nodes:
|
||||
driver = node.client.driver
|
||||
|
@ -5,7 +5,6 @@ import functools
|
||||
|
||||
import voluptuous as vol
|
||||
from zwave_js_server.const import CommandClass
|
||||
from zwave_js_server.event import Event
|
||||
from zwave_js_server.model.node import Node
|
||||
from zwave_js_server.model.value import Value, get_value_id
|
||||
|
||||
@ -108,7 +107,7 @@ async def async_attach_trigger(
|
||||
|
||||
@callback
|
||||
def async_on_value_updated(
|
||||
value: Value, device: dr.DeviceEntry, event: Event
|
||||
value: Value, device: dr.DeviceEntry, event: dict
|
||||
) -> None:
|
||||
"""Handle value update."""
|
||||
event_value: Value = event["value"]
|
||||
|
@ -5,7 +5,10 @@ import pytest
|
||||
from zwave_js_server.event import Event
|
||||
|
||||
from homeassistant.components.diagnostics.const import REDACTED
|
||||
from homeassistant.components.zwave_js.diagnostics import async_get_device_diagnostics
|
||||
from homeassistant.components.zwave_js.diagnostics import (
|
||||
ZwaveValueMatcher,
|
||||
async_get_device_diagnostics,
|
||||
)
|
||||
from homeassistant.components.zwave_js.discovery import async_discover_node_values
|
||||
from homeassistant.components.zwave_js.helpers import get_device_id
|
||||
from homeassistant.helpers.device_registry import async_get
|
||||
@ -100,3 +103,9 @@ async def test_device_diagnostics_error(hass, integration):
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
await async_get_device_diagnostics(hass, integration, device)
|
||||
|
||||
|
||||
async def test_empty_zwave_value_matcher():
|
||||
"""Test empty ZwaveValueMatcher is invalid."""
|
||||
with pytest.raises(ValueError):
|
||||
ZwaveValueMatcher()
|
||||
|
@ -1,10 +0,0 @@
|
||||
"""Test Z-Wave JS helpers module."""
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.zwave_js.helpers import ZwaveValueID
|
||||
|
||||
|
||||
async def test_empty_zwave_value_id():
|
||||
"""Test empty ZwaveValueID is invalid."""
|
||||
with pytest.raises(ValueError):
|
||||
ZwaveValueID()
|
Loading…
x
Reference in New Issue
Block a user