From ffcc02e93dd210b86090d135b157f7b81982a583 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Fri, 18 Mar 2022 19:06:44 +0100 Subject: [PATCH] Add zha typing [api] (2) (#68335) --- homeassistant/components/zha/api.py | 210 +++++++++++++++------------- 1 file changed, 112 insertions(+), 98 deletions(-) diff --git a/homeassistant/components/zha/api.py b/homeassistant/components/zha/api.py index feebff87c8b..ac028597ea8 100644 --- a/homeassistant/components/zha/api.py +++ b/homeassistant/components/zha/api.py @@ -69,11 +69,13 @@ from .core.helpers import ( get_matched_clusters, qr_to_install_code, ) -from .core.typing import ZhaDeviceType, ZhaGatewayType +from .core.typing import ZhaDeviceType if TYPE_CHECKING: from homeassistant.components.websocket_api.connection import ActiveConnection + from .core.gateway import ZHAGateway + _LOGGER = logging.getLogger(__name__) TYPE = "type" @@ -210,9 +212,9 @@ async def websocket_permit_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Permit ZHA zigbee devices.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - duration = msg.get(ATTR_DURATION) - ieee = msg.get(ATTR_IEEE) + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + duration: int = msg[ATTR_DURATION] + ieee: EUI64 | None = msg.get(ATTR_IEEE) async def forward_messages(data): """Forward events to websocket.""" @@ -230,6 +232,8 @@ async def websocket_permit_devices( connection.subscriptions[msg["id"]] = async_cleanup zha_gateway.async_enable_debug_mode() + src_ieee: EUI64 + code: bytes if ATTR_SOURCE_IEEE in msg: src_ieee = msg[ATTR_SOURCE_IEEE] code = msg[ATTR_INSTALL_CODE] @@ -255,10 +259,8 @@ async def websocket_get_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA devices.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] devices = [device.zha_device_info for device in zha_gateway.devices.values()] - connection.send_result(msg[ID], devices) @@ -269,7 +271,7 @@ async def websocket_get_groupable_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA devices that can be grouped.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] devices = [device for device in zha_gateway.devices.values() if device.is_groupable] groupable_devices = [] @@ -309,7 +311,7 @@ async def websocket_get_groups( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA groups.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] groups = [group.group_info for group in zha_gateway.groups.values()] connection.send_result(msg[ID], groups) @@ -326,8 +328,8 @@ async def websocket_get_device( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA devices.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ieee = msg[ATTR_IEEE] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + ieee: EUI64 = msg[ATTR_IEEE] device = None if ieee in zha_gateway.devices: device = zha_gateway.devices[ieee].zha_device_info @@ -353,8 +355,8 @@ async def websocket_get_group( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA group.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - group_id = msg[GROUP_ID] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + group_id: int = msg[GROUP_ID] group = None if group_id in zha_gateway.groups: @@ -397,10 +399,10 @@ async def websocket_add_group( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Add a new ZHA group.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - group_name = msg[GROUP_NAME] - members = msg.get(ATTR_MEMBERS) - group_id = msg.get(GROUP_ID) + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + group_name: str = msg[GROUP_NAME] + group_id: int | None = msg.get(GROUP_ID) + members: list[GroupMember] | None = msg.get(ATTR_MEMBERS) group = await zha_gateway.async_create_zigpy_group(group_name, members, group_id) connection.send_result(msg[ID], group.group_info) @@ -417,8 +419,8 @@ async def websocket_remove_groups( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Remove the specified ZHA groups.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - group_ids = msg[GROUP_IDS] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + group_ids: list[int] = msg[GROUP_IDS] if len(group_ids) > 1: tasks = [] @@ -444,9 +446,9 @@ async def websocket_add_group_members( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Add members to a ZHA group.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - group_id = msg[GROUP_ID] - members = msg[ATTR_MEMBERS] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + group_id: int = msg[GROUP_ID] + members: list[GroupMember] = msg[ATTR_MEMBERS] zha_group = None if group_id in zha_gateway.groups: @@ -476,9 +478,9 @@ async def websocket_remove_group_members( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Remove members from a ZHA group.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - group_id = msg[GROUP_ID] - members = msg[ATTR_MEMBERS] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + group_id: int = msg[GROUP_ID] + members: list[GroupMember] = msg[ATTR_MEMBERS] zha_group = None if group_id in zha_gateway.groups: @@ -507,8 +509,8 @@ async def websocket_reconfigure_node( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Reconfigure a ZHA nodes entities by its ieee address.""" - zha_gateway: ZhaGatewayType = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ieee = msg[ATTR_IEEE] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + ieee: EUI64 = msg[ATTR_IEEE] device: ZhaDeviceType = zha_gateway.get_device(ieee) async def forward_messages(data): @@ -541,21 +543,24 @@ async def websocket_update_topology( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Update the ZHA network topology.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] hass.async_create_task(zha_gateway.application_controller.topology.scan()) @websocket_api.require_admin @websocket_api.websocket_command( - {vol.Required(TYPE): "zha/devices/clusters", vol.Required(ATTR_IEEE): EUI64.convert} + { + vol.Required(TYPE): "zha/devices/clusters", + vol.Required(ATTR_IEEE): EUI64.convert, + } ) @websocket_api.async_response async def websocket_device_clusters( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Return a list of device clusters.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ieee = msg[ATTR_IEEE] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + ieee: EUI64 = msg[ATTR_IEEE] zha_device = zha_gateway.get_device(ieee) response_clusters = [] if zha_device is not None: @@ -598,12 +603,12 @@ async def websocket_device_cluster_attributes( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Return a list of cluster attributes.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ieee = msg[ATTR_IEEE] - endpoint_id = msg[ATTR_ENDPOINT_ID] - cluster_id = msg[ATTR_CLUSTER_ID] - cluster_type = msg[ATTR_CLUSTER_TYPE] - cluster_attributes = [] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + ieee: EUI64 = msg[ATTR_IEEE] + endpoint_id: int = msg[ATTR_ENDPOINT_ID] + cluster_id: int = msg[ATTR_CLUSTER_ID] + cluster_type: str = msg[ATTR_CLUSTER_TYPE] + cluster_attributes: list[dict[str, Any]] = [] zha_device = zha_gateway.get_device(ieee) attributes = None if zha_device is not None: @@ -645,11 +650,11 @@ async def websocket_device_cluster_commands( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Return a list of cluster commands.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - cluster_id = msg[ATTR_CLUSTER_ID] - cluster_type = msg[ATTR_CLUSTER_TYPE] - ieee = msg[ATTR_IEEE] - endpoint_id = msg[ATTR_ENDPOINT_ID] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + ieee: EUI64 = msg[ATTR_IEEE] + endpoint_id: int = msg[ATTR_ENDPOINT_ID] + cluster_id: int = msg[ATTR_CLUSTER_ID] + cluster_type: str = msg[ATTR_CLUSTER_TYPE] zha_device = zha_gateway.get_device(ieee) cluster_commands = [] commands = None @@ -707,13 +712,13 @@ async def websocket_read_zigbee_cluster_attributes( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Read zigbee attribute for cluster on zha entity.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ieee = msg[ATTR_IEEE] - endpoint_id = msg[ATTR_ENDPOINT_ID] - cluster_id = msg[ATTR_CLUSTER_ID] - cluster_type = msg[ATTR_CLUSTER_TYPE] - attribute = msg[ATTR_ATTRIBUTE] - manufacturer = msg.get(ATTR_MANUFACTURER) or None + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + ieee: EUI64 = msg[ATTR_IEEE] + endpoint_id: int = msg[ATTR_ENDPOINT_ID] + cluster_id: int = msg[ATTR_CLUSTER_ID] + cluster_type: str = msg[ATTR_CLUSTER_TYPE] + attribute: int = msg[ATTR_ATTRIBUTE] + manufacturer: Any | None = msg.get(ATTR_MANUFACTURER) zha_device = zha_gateway.get_device(ieee) if cluster_id >= MFG_CLUSTER_ID_START and manufacturer is None: manufacturer = zha_device.manufacturer_code @@ -747,15 +752,18 @@ async def websocket_read_zigbee_cluster_attributes( @websocket_api.require_admin @websocket_api.websocket_command( - {vol.Required(TYPE): "zha/devices/bindable", vol.Required(ATTR_IEEE): EUI64.convert} + { + vol.Required(TYPE): "zha/devices/bindable", + vol.Required(ATTR_IEEE): EUI64.convert, + } ) @websocket_api.async_response async def websocket_get_bindable_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Directly bind devices.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - source_ieee = msg[ATTR_IEEE] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + source_ieee: EUI64 = msg[ATTR_IEEE] source_device = zha_gateway.get_device(source_ieee) devices = [ @@ -788,9 +796,9 @@ async def websocket_bind_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Directly bind devices.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - source_ieee = msg[ATTR_SOURCE_IEEE] - target_ieee = msg[ATTR_TARGET_IEEE] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] + target_ieee: EUI64 = msg[ATTR_TARGET_IEEE] await async_binding_operation( zha_gateway, source_ieee, target_ieee, zdo_types.ZDOCmd.Bind_req ) @@ -816,9 +824,9 @@ async def websocket_unbind_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Remove a direct binding between devices.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - source_ieee = msg[ATTR_SOURCE_IEEE] - target_ieee = msg[ATTR_TARGET_IEEE] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] + target_ieee: EUI64 = msg[ATTR_TARGET_IEEE] await async_binding_operation( zha_gateway, source_ieee, target_ieee, zdo_types.ZDOCmd.Unbind_req ) @@ -862,12 +870,11 @@ async def websocket_bind_group( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Directly bind a device to a group.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - source_ieee = msg[ATTR_SOURCE_IEEE] - group_id = msg[GROUP_ID] - bindings = msg[BINDINGS] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] + group_id: int = msg[GROUP_ID] + bindings: list[ClusterBinding] = msg[BINDINGS] source_device = zha_gateway.get_device(source_ieee) - await source_device.async_bind_to_group(group_id, bindings) @@ -885,15 +892,20 @@ async def websocket_unbind_group( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Unbind a device from a group.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - source_ieee = msg[ATTR_SOURCE_IEEE] - group_id = msg[GROUP_ID] - bindings = msg[BINDINGS] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] + group_id: int = msg[GROUP_ID] + bindings: list[ClusterBinding] = msg[BINDINGS] source_device = zha_gateway.get_device(source_ieee) await source_device.async_unbind_from_group(group_id, bindings) -async def async_binding_operation(zha_gateway, source_ieee, target_ieee, operation): +async def async_binding_operation( + zha_gateway: ZHAGateway, + source_ieee: EUI64, + target_ieee: EUI64, + operation: zdo_types.ZDOCmd, +) -> None: """Create or remove a direct zigbee binding between 2 devices.""" source_device = zha_gateway.get_device(source_ieee) @@ -982,7 +994,7 @@ async def websocket_update_zha_configuration( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Update the ZHA configuration.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] options = zha_gateway.config_entry.options data_to_save = {**options, **{CUSTOM_CONFIGURATION: msg["data"]}} @@ -1002,13 +1014,15 @@ async def websocket_update_zha_configuration( @callback def async_load_api(hass: HomeAssistant) -> None: """Set up the web socket API.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] application_controller = zha_gateway.application_controller async def permit(service: ServiceCall) -> None: """Allow devices to join this network.""" - duration = service.data[ATTR_DURATION] - ieee = service.data.get(ATTR_IEEE) + duration: int = service.data[ATTR_DURATION] + ieee: EUI64 | None = service.data.get(ATTR_IEEE) + src_ieee: EUI64 + code: bytes if ATTR_SOURCE_IEEE in service.data: src_ieee = service.data[ATTR_SOURCE_IEEE] code = service.data[ATTR_INSTALL_CODE] @@ -1038,8 +1052,8 @@ def async_load_api(hass: HomeAssistant) -> None: async def remove(service: ServiceCall) -> None: """Remove a node from the network.""" - ieee = service.data[ATTR_IEEE] - zha_gateway: ZhaGatewayType = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + ieee: EUI64 = service.data[ATTR_IEEE] zha_device: ZhaDeviceType = zha_gateway.get_device(ieee) if zha_device is not None and ( zha_device.is_coordinator @@ -1056,13 +1070,13 @@ def async_load_api(hass: HomeAssistant) -> None: async def set_zigbee_cluster_attributes(service: ServiceCall) -> None: """Set zigbee attribute for cluster on zha entity.""" - ieee = service.data.get(ATTR_IEEE) - endpoint_id = service.data.get(ATTR_ENDPOINT_ID) - cluster_id = service.data.get(ATTR_CLUSTER_ID) - cluster_type = service.data.get(ATTR_CLUSTER_TYPE) - attribute = service.data.get(ATTR_ATTRIBUTE) - value = service.data.get(ATTR_VALUE) - manufacturer = service.data.get(ATTR_MANUFACTURER) or None + ieee: EUI64 = service.data[ATTR_IEEE] + endpoint_id: int = service.data[ATTR_ENDPOINT_ID] + cluster_id: int = service.data[ATTR_CLUSTER_ID] + cluster_type: str = service.data[ATTR_CLUSTER_TYPE] + attribute: int | str = service.data[ATTR_ATTRIBUTE] + value: int | bool | str = service.data[ATTR_VALUE] + manufacturer: int | None = service.data.get(ATTR_MANUFACTURER) zha_device = zha_gateway.get_device(ieee) if cluster_id >= MFG_CLUSTER_ID_START and manufacturer is None: manufacturer = zha_device.manufacturer_code @@ -1104,14 +1118,14 @@ def async_load_api(hass: HomeAssistant) -> None: async def issue_zigbee_cluster_command(service: ServiceCall) -> None: """Issue command on zigbee cluster on zha entity.""" - ieee = service.data.get(ATTR_IEEE) - endpoint_id = service.data.get(ATTR_ENDPOINT_ID) - cluster_id = service.data.get(ATTR_CLUSTER_ID) - cluster_type = service.data.get(ATTR_CLUSTER_TYPE) - command = service.data.get(ATTR_COMMAND) - command_type = service.data.get(ATTR_COMMAND_TYPE) - args = service.data.get(ATTR_ARGS) - manufacturer = service.data.get(ATTR_MANUFACTURER) or None + ieee: EUI64 = service.data[ATTR_IEEE] + endpoint_id: int = service.data[ATTR_ENDPOINT_ID] + cluster_id: int = service.data[ATTR_CLUSTER_ID] + cluster_type: str = service.data[ATTR_CLUSTER_TYPE] + command: int = service.data[ATTR_COMMAND] + command_type: str = service.data[ATTR_COMMAND_TYPE] + args: list = service.data[ATTR_ARGS] + manufacturer: int | None = service.data.get(ATTR_MANUFACTURER) zha_device = zha_gateway.get_device(ieee) if cluster_id >= MFG_CLUSTER_ID_START and manufacturer is None: manufacturer = zha_device.manufacturer_code @@ -1156,11 +1170,11 @@ def async_load_api(hass: HomeAssistant) -> None: async def issue_zigbee_group_command(service: ServiceCall) -> None: """Issue command on zigbee cluster on a zigbee group.""" - group_id = service.data.get(ATTR_GROUP) - cluster_id = service.data.get(ATTR_CLUSTER_ID) - command = service.data.get(ATTR_COMMAND) - args = service.data.get(ATTR_ARGS) - manufacturer = service.data.get(ATTR_MANUFACTURER) or None + group_id: int = service.data[ATTR_GROUP] + cluster_id: int = service.data[ATTR_CLUSTER_ID] + command: int = service.data[ATTR_COMMAND] + args: list = service.data[ATTR_ARGS] + manufacturer: int | None = service.data.get(ATTR_MANUFACTURER) group = zha_gateway.get_group(group_id) if cluster_id >= MFG_CLUSTER_ID_START and manufacturer is None: _LOGGER.error("Missing manufacturer attribute for cluster: %d", cluster_id) @@ -1203,10 +1217,10 @@ def async_load_api(hass: HomeAssistant) -> None: async def warning_device_squawk(service: ServiceCall) -> None: """Issue the squawk command for an IAS warning device.""" - ieee = service.data[ATTR_IEEE] - mode = service.data.get(ATTR_WARNING_DEVICE_MODE) - strobe = service.data.get(ATTR_WARNING_DEVICE_STROBE) - level = service.data.get(ATTR_LEVEL) + ieee: EUI64 = service.data[ATTR_IEEE] + mode: int = service.data[ATTR_WARNING_DEVICE_MODE] + strobe: int = service.data[ATTR_WARNING_DEVICE_STROBE] + level: int = service.data[ATTR_LEVEL] if (zha_device := zha_gateway.get_device(ieee)) is not None: if channel := _get_ias_wd_channel(zha_device):