From 006fa9b7009835fc1b4a1b46de8ea82b4790fa21 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 30 Mar 2022 15:54:31 +0200 Subject: [PATCH] Improve zha typing [api] (5) (#68684) --- homeassistant/components/zha/api.py | 99 ++++++++++++--------- homeassistant/components/zha/core/device.py | 38 +++++--- homeassistant/components/zha/core/group.py | 12 ++- 3 files changed, 93 insertions(+), 56 deletions(-) diff --git a/homeassistant/components/zha/api.py b/homeassistant/components/zha/api.py index da094363e0a..552be260e8b 100644 --- a/homeassistant/components/zha/api.py +++ b/homeassistant/components/zha/api.py @@ -2,10 +2,8 @@ from __future__ import annotations import asyncio -import collections -from collections.abc import Mapping import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NamedTuple import voluptuous as vol from zigpy.config.validators import cv_boolean @@ -14,7 +12,7 @@ from zigpy.zcl.clusters.security import IasAce import zigpy.zdo.types as zdo_types from homeassistant.components import websocket_api -from homeassistant.const import ATTR_COMMAND, ATTR_NAME +from homeassistant.const import ATTR_COMMAND, ATTR_ID, ATTR_NAME from homeassistant.core import HomeAssistant, ServiceCall, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.dispatcher import async_dispatcher_connect @@ -31,6 +29,7 @@ from .core.const import ( ATTR_LEVEL, ATTR_MANUFACTURER, ATTR_MEMBERS, + ATTR_TYPE, ATTR_VALUE, ATTR_WARNING_DEVICE_DURATION, ATTR_WARNING_DEVICE_MODE, @@ -201,7 +200,56 @@ SERVICE_SCHEMAS = { ), } -ClusterBinding = collections.namedtuple("ClusterBinding", "id endpoint_id type name") + +class ClusterBinding(NamedTuple): + """Describes a cluster binding.""" + + name: str + type: str + id: int + endpoint_id: int + + +def _cv_group_member(value: dict[str, Any]) -> GroupMember: + """Transform a group member.""" + return GroupMember( + ieee=value[ATTR_IEEE], + endpoint_id=value[ATTR_ENDPOINT_ID], + ) + + +def _cv_cluster_binding(value: dict[str, Any]) -> ClusterBinding: + """Transform a cluster binding.""" + return ClusterBinding( + name=value[ATTR_NAME], + type=value[ATTR_TYPE], + id=value[ATTR_ID], + endpoint_id=value[ATTR_ENDPOINT_ID], + ) + + +GROUP_MEMBER_SCHEMA = vol.All( + vol.Schema( + { + vol.Required(ATTR_IEEE): IEEE_SCHEMA, + vol.Required(ATTR_ENDPOINT_ID): int, + } + ), + _cv_group_member, +) + + +CLUSTER_BINDING_SCHEMA = vol.All( + vol.Schema( + { + vol.Required(ATTR_NAME): cv.string, + vol.Required(ATTR_TYPE): cv.string, + vol.Required(ATTR_ID): int, + vol.Required(ATTR_ENDPOINT_ID): int, + } + ), + _cv_cluster_binding, +) @websocket_api.require_admin @@ -374,27 +422,13 @@ async def websocket_get_group( connection.send_result(msg[ID], group_info) -def cv_group_member(value: Any) -> GroupMember: - """Validate and transform a group member.""" - if not isinstance(value, Mapping): - raise vol.Invalid("Not a group member") - try: - group_member = GroupMember( - ieee=EUI64.convert(value["ieee"]), endpoint_id=value["endpoint_id"] - ) - except KeyError as err: - raise vol.Invalid("Not a group member") from err - - return group_member - - @websocket_api.require_admin @websocket_api.websocket_command( { vol.Required(TYPE): "zha/group/add", vol.Required(GROUP_NAME): cv.string, vol.Optional(GROUP_ID): cv.positive_int, - vol.Optional(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]), + vol.Optional(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]), } ) @websocket_api.async_response @@ -441,7 +475,7 @@ async def websocket_remove_groups( { vol.Required(TYPE): "zha/group/members/add", vol.Required(GROUP_ID): cv.positive_int, - vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]), + vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]), } ) @websocket_api.async_response @@ -471,7 +505,7 @@ async def websocket_add_group_members( { vol.Required(TYPE): "zha/group/members/remove", vol.Required(GROUP_ID): cv.positive_int, - vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [cv_group_member]), + vol.Required(ATTR_MEMBERS): vol.All(cv.ensure_list, [GROUP_MEMBER_SCHEMA]), } ) @websocket_api.async_response @@ -837,30 +871,13 @@ async def websocket_unbind_devices( ) -def is_cluster_binding(value: Any) -> ClusterBinding: - """Validate and transform a cluster binding.""" - if not isinstance(value, Mapping): - raise vol.Invalid("Not a cluster binding") - try: - cluster_binding = ClusterBinding( - name=value["name"], - type=value["type"], - id=value["id"], - endpoint_id=value["endpoint_id"], - ) - except KeyError as err: - raise vol.Invalid("Not a cluster binding") from err - - return cluster_binding - - @websocket_api.require_admin @websocket_api.websocket_command( { vol.Required(TYPE): "zha/groups/bind", vol.Required(ATTR_SOURCE_IEEE): IEEE_SCHEMA, vol.Required(GROUP_ID): cv.positive_int, - vol.Required(BINDINGS): vol.All(cv.ensure_list, [is_cluster_binding]), + vol.Required(BINDINGS): vol.All(cv.ensure_list, [CLUSTER_BINDING_SCHEMA]), } ) @websocket_api.async_response @@ -882,7 +899,7 @@ async def websocket_bind_group( vol.Required(TYPE): "zha/groups/unbind", vol.Required(ATTR_SOURCE_IEEE): IEEE_SCHEMA, vol.Required(GROUP_ID): cv.positive_int, - vol.Required(BINDINGS): vol.All(cv.ensure_list, [is_cluster_binding]), + vol.Required(BINDINGS): vol.All(cv.ensure_list, [CLUSTER_BINDING_SCHEMA]), } ) @websocket_api.async_response diff --git a/homeassistant/components/zha/core/device.py b/homeassistant/components/zha/core/device.py index 91ac905707f..79cc54c4829 100644 --- a/homeassistant/components/zha/core/device.py +++ b/homeassistant/components/zha/core/device.py @@ -8,7 +8,7 @@ from enum import Enum import logging import random import time -from typing import Any +from typing import TYPE_CHECKING, Any from zigpy import types import zigpy.exceptions @@ -75,6 +75,9 @@ from .const import ( ) from .helpers import LogMixin, async_get_zha_config_value +if TYPE_CHECKING: + from ..api import ClusterBinding + _LOGGER = logging.getLogger(__name__) _UPDATE_ALIVE_INTERVAL = (60, 90) _CHECKIN_GRACE_PERIODS = 2 @@ -655,7 +658,7 @@ class ZHADevice(LogMixin): ) return response - async def async_add_to_group(self, group_id): + async def async_add_to_group(self, group_id: int) -> None: """Add this device to the provided zigbee group.""" try: await self._zigpy_device.add_to_group(group_id) @@ -667,7 +670,7 @@ class ZHADevice(LogMixin): str(ex), ) - async def async_remove_from_group(self, group_id): + async def async_remove_from_group(self, group_id: int) -> None: """Remove this device from the provided zigbee group.""" try: await self._zigpy_device.remove_from_group(group_id) @@ -679,10 +682,12 @@ class ZHADevice(LogMixin): str(ex), ) - async def async_add_endpoint_to_group(self, endpoint_id, group_id): + async def async_add_endpoint_to_group( + self, endpoint_id: int, group_id: int + ) -> None: """Add the device endpoint to the provided zigbee group.""" try: - await self._zigpy_device.endpoints[int(endpoint_id)].add_to_group(group_id) + await self._zigpy_device.endpoints[endpoint_id].add_to_group(group_id) except (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError) as ex: self.debug( "Failed to add endpoint: %s for device: '%s' to group: 0x%04x ex: %s", @@ -692,12 +697,12 @@ class ZHADevice(LogMixin): str(ex), ) - async def async_remove_endpoint_from_group(self, endpoint_id, group_id): + async def async_remove_endpoint_from_group( + self, endpoint_id: int, group_id: int + ) -> None: """Remove the device endpoint from the provided zigbee group.""" try: - await self._zigpy_device.endpoints[int(endpoint_id)].remove_from_group( - group_id - ) + await self._zigpy_device.endpoints[endpoint_id].remove_from_group(group_id) except (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError) as ex: self.debug( "Failed to remove endpoint: %s for device '%s' from group: 0x%04x ex: %s", @@ -707,21 +712,28 @@ class ZHADevice(LogMixin): str(ex), ) - async def async_bind_to_group(self, group_id, cluster_bindings): + async def async_bind_to_group( + self, group_id: int, cluster_bindings: list[ClusterBinding] + ) -> None: """Directly bind this device to a group for the given clusters.""" await self._async_group_binding_operation( group_id, zdo_types.ZDOCmd.Bind_req, cluster_bindings ) - async def async_unbind_from_group(self, group_id, cluster_bindings): + async def async_unbind_from_group( + self, group_id: int, cluster_bindings: list[ClusterBinding] + ) -> None: """Unbind this device from a group for the given clusters.""" await self._async_group_binding_operation( group_id, zdo_types.ZDOCmd.Unbind_req, cluster_bindings ) async def _async_group_binding_operation( - self, group_id, operation, cluster_bindings - ): + self, + group_id: int, + operation: zdo_types.ZDOCmd, + cluster_bindings: list[ClusterBinding], + ) -> None: """Create or remove a direct zigbee binding between a device and a group.""" zdo = self._zigpy_device.zdo diff --git a/homeassistant/components/zha/core/group.py b/homeassistant/components/zha/core/group.py index 16202291860..93e96c7565b 100644 --- a/homeassistant/components/zha/core/group.py +++ b/homeassistant/components/zha/core/group.py @@ -4,11 +4,12 @@ from __future__ import annotations import asyncio import collections import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NamedTuple import zigpy.endpoint import zigpy.exceptions import zigpy.group +from zigpy.types.named import EUI64 from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_registry import async_entries_for_device @@ -21,7 +22,14 @@ if TYPE_CHECKING: _LOGGER = logging.getLogger(__name__) -GroupMember = collections.namedtuple("GroupMember", "ieee endpoint_id") + +class GroupMember(NamedTuple): + """Describes a group member.""" + + ieee: EUI64 + endpoint_id: int + + GroupEntityReference = collections.namedtuple( "GroupEntityReference", "name original_name entity_id" )