"""Helper functions for Z-Wave JS integration.""" from __future__ import annotations import asyncio from collections.abc import Callable, Coroutine from dataclasses import astuple, dataclass import logging from typing import Any, cast import aiohttp import voluptuous as vol from zwave_js_server.const import ( LOG_LEVEL_MAP, CommandClass, ConfigurationValueType, LogLevel, ) from zwave_js_server.model.controller import Controller, ProvisioningEntry from zwave_js_server.model.driver import Driver from zwave_js_server.model.log_config import LogConfig from zwave_js_server.model.node import Node as ZwaveNode from zwave_js_server.model.value import ( ConfigurationValue, Value as ZwaveValue, ValueDataType, get_value_id_str, ) from zwave_js_server.version import VersionInfo, get_server_version from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN from homeassistant.config_entries import ConfigEntryState from homeassistant.const import ( ATTR_AREA_ID, ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_TYPE, __version__ as HA_VERSION, ) from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.group import expand_entity_ids from homeassistant.helpers.typing import ConfigType, VolSchemaType from .const import ( ATTR_COMMAND_CLASS, ATTR_ENDPOINT, ATTR_PROPERTY, ATTR_PROPERTY_KEY, DOMAIN, LIB_LOGGER, LOGGER, ) from .models import ZwaveJSConfigEntry DRIVER_READY_EVENT_TIMEOUT = 60 SERVER_VERSION_TIMEOUT = 10 @dataclass class ZwaveValueMatcher: """Class to allow matching a Z-Wave Value.""" property_: str | int | None = None command_class: int | None = None endpoint: int | None = None property_key: str | int | None = None def __post_init__(self) -> None: """Post initialization check.""" if all(val is None for val in astuple(self)): raise ValueError("At least one of the fields must be set.") def value_matches_matcher( matcher: ZwaveValueMatcher, value_data: ValueDataType ) -> bool: """Return whether value matches matcher.""" command_class = None if "commandClass" in value_data: command_class = CommandClass(value_data["commandClass"]) zwave_value_id = ZwaveValueMatcher( property_=value_data.get("property"), command_class=command_class, endpoint=value_data.get("endpoint"), property_key=value_data.get("propertyKey"), ) return all( redacted_field_val is None or redacted_field_val == zwave_value_field_val for redacted_field_val, zwave_value_field_val in zip( astuple(matcher), astuple(zwave_value_id), strict=False ) ) def get_value_id_from_unique_id(unique_id: str) -> str | None: """Get the value ID and optional state key from a unique ID. Raises ValueError """ split_unique_id = unique_id.split(".") # If the unique ID contains a `-` in its second part, the unique ID contains # a value ID and we can return it. if "-" in (value_id := split_unique_id[1]): return value_id return None def get_state_key_from_unique_id(unique_id: str) -> int | None: """Get the state key from a unique ID.""" # If the unique ID has more than two parts, it's a special unique ID. If the last # part of the unique ID is an int, then it's a state key and we return it. if len(split_unique_id := unique_id.split(".")) > 2: try: return int(split_unique_id[-1]) except ValueError: pass return None def get_value_of_zwave_value(value: ZwaveValue | None) -> Any | None: """Return the value of a ZwaveValue.""" return value.value if value else None async def async_enable_statistics(driver: Driver) -> None: """Enable statistics on the driver.""" await driver.async_enable_statistics("Home Assistant", HA_VERSION) async def async_enable_server_logging_if_needed( hass: HomeAssistant, entry: ZwaveJSConfigEntry, driver: Driver ) -> None: """Enable logging of zwave-js-server in the lib.""" # If lib log level is set to debug, we want to enable server logging. First we # check if server log level is less verbose than library logging, and if so, set it # to debug to match library logging. We will store the old server log level in # hass.data so we can reset it later if ( not driver or not driver.client.connected or driver.client.server_logging_enabled ): return LOGGER.info("Enabling zwave-js-server logging") if (curr_server_log_level := driver.log_config.level) and ( LOG_LEVEL_MAP[curr_server_log_level] ) > LIB_LOGGER.getEffectiveLevel(): entry.runtime_data.old_server_log_level = curr_server_log_level await driver.async_update_log_config(LogConfig(level=LogLevel.DEBUG)) await driver.client.enable_server_logging() LOGGER.info("Zwave-js-server logging is enabled") async def async_disable_server_logging_if_needed( hass: HomeAssistant, entry: ZwaveJSConfigEntry, driver: Driver ) -> None: """Disable logging of zwave-js-server in the lib if still connected to server.""" if ( not driver or not driver.client.connected or not driver.client.server_logging_enabled ): return LOGGER.info("Disabling zwave_js server logging") if ( old_server_log_level := entry.runtime_data.old_server_log_level ) is not None and old_server_log_level != driver.log_config.level: LOGGER.info( ( "Server logging is currently set to %s as a result of server logging " "being enabled. It is now being reset to %s" ), driver.log_config.level, old_server_log_level, ) await driver.async_update_log_config(LogConfig(level=old_server_log_level)) entry.runtime_data.old_server_log_level = None driver.client.disable_server_logging() LOGGER.info("Zwave-js-server logging is enabled") def get_valueless_base_unique_id(driver: Driver, node: ZwaveNode) -> str: """Return the base unique ID for an entity that is not based on a value.""" return f"{driver.controller.home_id}.{node.node_id}" def get_unique_id(driver: Driver, value_id: str) -> str: """Get unique ID from client and value ID.""" return f"{driver.controller.home_id}.{value_id}" def get_device_id(driver: Driver, node: ZwaveNode) -> tuple[str, str]: """Get device registry identifier for Z-Wave node.""" return (DOMAIN, f"{driver.controller.home_id}-{node.node_id}") def get_device_id_ext(driver: Driver, node: ZwaveNode) -> tuple[str, str] | None: """Get extended device registry identifier for Z-Wave node.""" if None in (node.manufacturer_id, node.product_type, node.product_id): return None domain, dev_id = get_device_id(driver, node) return ( domain, f"{dev_id}-{node.manufacturer_id}:{node.product_type}:{node.product_id}", ) def get_home_and_node_id_from_device_entry( device_entry: dr.DeviceEntry, ) -> tuple[str, int] | None: """Get home ID and node ID for Z-Wave device registry entry. Returns (home_id, node_id) or None if not found. """ device_id = next( ( identifier[1] for identifier in device_entry.identifiers if identifier[0] == DOMAIN ), None, ) if device_id is None or device_id.startswith("provision_"): return None id_ = device_id.split("-") return (id_[0], int(id_[1])) @callback def async_get_node_from_device_id( hass: HomeAssistant, device_id: str, dev_reg: dr.DeviceRegistry | None = None ) -> ZwaveNode: """Get node from a device ID. Raises ValueError if device is invalid or node can't be found. """ if not dev_reg: dev_reg = dr.async_get(hass) if not (device_entry := dev_reg.async_get(device_id)): raise ValueError(f"Device ID {device_id} is not valid") # Use device config entry ID's to validate that this is a valid zwave_js device # and to get the client config_entry_ids = device_entry.config_entries entry: ZwaveJSConfigEntry | None = next( ( entry for entry in hass.config_entries.async_entries(DOMAIN) if entry.entry_id in config_entry_ids ), None, ) if entry is None: raise ValueError( f"Device {device_id} is not from an existing zwave_js config entry" ) if entry.state != ConfigEntryState.LOADED: raise ValueError(f"Device {device_id} config entry is not loaded") client = entry.runtime_data.client driver = client.driver if driver is None: raise ValueError("Driver is not ready.") # Get node ID from device identifier, perform some validation, and then get the # node identifiers = get_home_and_node_id_from_device_entry(device_entry) node_id = identifiers[1] if identifiers else None if node_id is None or node_id not in driver.controller.nodes: raise ValueError(f"Node for device {device_id} can't be found") return driver.controller.nodes[node_id] async def async_get_provisioning_entry_from_device_id( hass: HomeAssistant, device_id: str ) -> ProvisioningEntry | None: """Get provisioning entry from a device ID. Raises ValueError if device is invalid """ dev_reg = dr.async_get(hass) if not (device_entry := dev_reg.async_get(device_id)): raise ValueError(f"Device ID {device_id} is not valid") # Use device config entry ID's to validate that this is a valid zwave_js device # and to get the client config_entry_ids = device_entry.config_entries entry: ZwaveJSConfigEntry | None = next( ( entry for entry in hass.config_entries.async_entries(DOMAIN) if entry.entry_id in config_entry_ids ), None, ) if entry is None: raise ValueError( f"Device {device_id} is not from an existing zwave_js config entry" ) if entry.state != ConfigEntryState.LOADED: raise ValueError(f"Device {device_id} config entry is not loaded") client = entry.runtime_data.client driver = client.driver if driver is None: raise ValueError("Driver is not ready.") provisioning_entries = await driver.controller.async_get_provisioning_entries() for provisioning_entry in provisioning_entries: if ( provisioning_entry.additional_properties and provisioning_entry.additional_properties.get("device_id") == device_id ): return provisioning_entry return None @callback def async_get_node_from_entity_id( hass: HomeAssistant, entity_id: str, ent_reg: er.EntityRegistry | None = None, dev_reg: dr.DeviceRegistry | None = None, ) -> ZwaveNode: """Get node from an entity ID. Raises ValueError if entity is invalid. """ if not ent_reg: ent_reg = er.async_get(hass) entity_entry = ent_reg.async_get(entity_id) if entity_entry is None or entity_entry.platform != DOMAIN: raise ValueError(f"Entity {entity_id} is not a valid {DOMAIN} entity") # Assert for mypy, safe because we know that zwave_js entities are always # tied to a device assert entity_entry.device_id return async_get_node_from_device_id(hass, entity_entry.device_id, dev_reg) @callback def async_get_nodes_from_area_id( hass: HomeAssistant, area_id: str, ent_reg: er.EntityRegistry | None = None, dev_reg: dr.DeviceRegistry | None = None, ) -> set[ZwaveNode]: """Get nodes for all Z-Wave JS devices and entities that are in an area.""" nodes: set[ZwaveNode] = set() if ent_reg is None: ent_reg = er.async_get(hass) if dev_reg is None: dev_reg = dr.async_get(hass) # Add devices for all entities in an area that are Z-Wave JS entities nodes.update( { async_get_node_from_device_id(hass, entity.device_id, dev_reg) for entity in er.async_entries_for_area(ent_reg, area_id) if entity.platform == DOMAIN and entity.device_id is not None } ) # Add devices in an area that are Z-Wave JS devices nodes.update( async_get_node_from_device_id(hass, device.id, dev_reg) for device in dr.async_entries_for_area(dev_reg, area_id) if any( cast( ZwaveJSConfigEntry, hass.config_entries.async_get_entry(config_entry_id), ).domain == DOMAIN for config_entry_id in device.config_entries ) ) return nodes @callback def async_get_nodes_from_targets( hass: HomeAssistant, val: dict[str, Any], ent_reg: er.EntityRegistry | None = None, dev_reg: dr.DeviceRegistry | None = None, logger: logging.Logger = LOGGER, ) -> set[ZwaveNode]: """Get nodes for all targets. Supports entity_id with group expansion, area_id, and device_id. """ nodes: set[ZwaveNode] = set() # Convert all entity IDs to nodes for entity_id in expand_entity_ids(hass, val.get(ATTR_ENTITY_ID, [])): try: nodes.add(async_get_node_from_entity_id(hass, entity_id, ent_reg, dev_reg)) except ValueError as err: logger.warning(err.args[0]) # Convert all area IDs to nodes for area_id in val.get(ATTR_AREA_ID, []): nodes.update(async_get_nodes_from_area_id(hass, area_id, ent_reg, dev_reg)) # Convert all device IDs to nodes for device_id in val.get(ATTR_DEVICE_ID, []): try: nodes.add(async_get_node_from_device_id(hass, device_id, dev_reg)) except ValueError as err: logger.warning(err.args[0]) return nodes def get_zwave_value_from_config(node: ZwaveNode, config: ConfigType) -> ZwaveValue: """Get a Z-Wave JS Value from a config.""" endpoint = None if config.get(ATTR_ENDPOINT): endpoint = config[ATTR_ENDPOINT] property_key = None if config.get(ATTR_PROPERTY_KEY): property_key = config[ATTR_PROPERTY_KEY] value_id = get_value_id_str( node, config[ATTR_COMMAND_CLASS], config[ATTR_PROPERTY], endpoint, property_key, ) if value_id not in node.values: raise vol.Invalid(f"Value {value_id} can't be found on node {node}") return node.values[value_id] def _zwave_js_config_entry(hass: HomeAssistant, device: dr.DeviceEntry) -> str | None: """Find zwave_js config entry from a device.""" for entry_id in device.config_entries: entry = hass.config_entries.async_get_entry(entry_id) if entry and entry.domain == DOMAIN: return entry_id return None @callback def async_get_node_status_sensor_entity_id( hass: HomeAssistant, device_id: str, ent_reg: er.EntityRegistry | None = None, dev_reg: dr.DeviceRegistry | None = None, ) -> str | None: """Get the node status sensor entity ID for a given Z-Wave JS device.""" if not ent_reg: ent_reg = er.async_get(hass) if not dev_reg: dev_reg = dr.async_get(hass) if not (device := dev_reg.async_get(device_id)): raise HomeAssistantError("Invalid Device ID provided") if not (entry_id := _zwave_js_config_entry(hass, device)): return None entry = hass.config_entries.async_get_entry(entry_id) assert entry client = entry.runtime_data.client node = async_get_node_from_device_id(hass, device_id, dev_reg) return ent_reg.async_get_entity_id( SENSOR_DOMAIN, DOMAIN, f"{client.driver.controller.home_id}.{node.node_id}.node_status", ) def remove_keys_with_empty_values(config: ConfigType) -> ConfigType: """Remove keys from config where the value is an empty string or None.""" return {key: value for key, value in config.items() if value not in ("", None)} def check_type_schema_map( schema_map: dict[str, vol.Schema], ) -> Callable[[ConfigType], ConfigType]: """Check type specific schema against config.""" def _check_type_schema(config: ConfigType) -> ConfigType: """Check type specific schema against config.""" return cast(ConfigType, schema_map[str(config[CONF_TYPE])](config)) return _check_type_schema def copy_available_params( input_dict: dict[str, Any], output_dict: dict[str, Any], params: list[str] ) -> None: """Copy available params from input into output.""" output_dict.update( {param: input_dict[param] for param in params if param in input_dict} ) def get_value_state_schema( value: ZwaveValue, ) -> VolSchemaType | vol.Coerce | vol.In | None: """Return device automation schema for a config entry.""" if isinstance(value, ConfigurationValue): min_ = value.metadata.min max_ = value.metadata.max if value.configuration_value_type in ( ConfigurationValueType.RANGE, ConfigurationValueType.MANUAL_ENTRY, ): return vol.All(vol.Coerce(int), vol.Range(min=min_, max=max_)) if value.configuration_value_type == ConfigurationValueType.BOOLEAN: return vol.Coerce(bool) if value.configuration_value_type == ConfigurationValueType.ENUMERATED: return vol.In({int(k): v for k, v in value.metadata.states.items()}) return None if value.metadata.states: return vol.In({int(k): v for k, v in value.metadata.states.items()}) return vol.All( vol.Coerce(int), vol.Range(min=value.metadata.min, max=value.metadata.max), ) def get_device_info(driver: Driver, node: ZwaveNode) -> DeviceInfo: """Get DeviceInfo for node.""" return DeviceInfo( identifiers={get_device_id(driver, node)}, sw_version=node.firmware_version, name=node.name or node.device_config.description or f"Node {node.node_id}", model=node.device_config.label, manufacturer=node.device_config.manufacturer, suggested_area=node.location if node.location else None, ) def get_network_identifier_for_notification( hass: HomeAssistant, config_entry: ZwaveJSConfigEntry, controller: Controller ) -> str: """Return the network identifier string for persistent notifications.""" home_id = str(controller.home_id) if len(hass.config_entries.async_entries(DOMAIN)) > 1: if str(home_id) != config_entry.title: return f"`{config_entry.title}`, with the home ID `{home_id}`," return f"with the home ID `{home_id}`" return "" async def async_get_version_info(hass: HomeAssistant, ws_address: str) -> VersionInfo: """Return Z-Wave JS version info.""" try: async with asyncio.timeout(SERVER_VERSION_TIMEOUT): version_info: VersionInfo = await get_server_version( ws_address, async_get_clientsession(hass) ) except (TimeoutError, aiohttp.ClientError) as err: # We don't want to spam the log if the add-on isn't started # or takes a long time to start. LOGGER.debug("Failed to connect to Z-Wave JS server: %s", err) raise CannotConnect from err return version_info @callback def async_wait_for_driver_ready_event( config_entry: ZwaveJSConfigEntry, driver: Driver, ) -> Callable[[], Coroutine[Any, Any, None]]: """Wait for the driver ready event and the config entry reload. When the driver ready event is received the config entry will be reloaded by the integration. This function helps wait for that to happen before proceeding with further actions. If the config entry is reloaded for another reason, this function will not wait for it to be reloaded again. Raises TimeoutError if the driver ready event and reload is not received within the specified timeout. """ driver_ready_event_received = asyncio.Event() config_entry_reloaded = asyncio.Event() unsubscribers: list[Callable[[], None]] = [] @callback def driver_ready_received(event: dict) -> None: """Receive the driver ready event.""" driver_ready_event_received.set() unsubscribers.append(driver.once("driver ready", driver_ready_received)) @callback def on_config_entry_state_change() -> None: """Check config entry was loaded after driver ready event.""" if config_entry.state is ConfigEntryState.LOADED: config_entry_reloaded.set() unsubscribers.append( config_entry.async_on_state_change(on_config_entry_state_change) ) async def wait_for_events() -> None: try: async with asyncio.timeout(DRIVER_READY_EVENT_TIMEOUT): await asyncio.gather( driver_ready_event_received.wait(), config_entry_reloaded.wait() ) finally: for unsubscribe in unsubscribers: unsubscribe() return wait_for_events class CannotConnect(HomeAssistantError): """Indicate connection error."""