From f4a4c6bea5d3d6908939dd2a6480e8a7163bceb1 Mon Sep 17 00:00:00 2001 From: "David F. Mulcahey" Date: Wed, 29 Jan 2020 12:24:43 -0500 Subject: [PATCH] ZHA group and device cleanup (#31260) * add dispatching of groups to light * added ha device registry device id * added zha group object * add group event listener * add and remove group members * get group by name * api cleanup * clean up get device info * create and remove zigpy groups * clean up create and remove group api * use device id * use device id * cleanup * update test * update tests to allow group events to flow --- homeassistant/components/zha/api.py | 173 ++++--------------- homeassistant/components/zha/core/const.py | 13 +- homeassistant/components/zha/core/device.py | 29 ++++ homeassistant/components/zha/core/gateway.py | 129 ++++++++++++-- homeassistant/components/zha/core/group.py | 95 ++++++++++ homeassistant/components/zha/core/helpers.py | 34 +--- tests/components/zha/conftest.py | 5 +- 7 files changed, 289 insertions(+), 189 deletions(-) create mode 100644 homeassistant/components/zha/core/group.py diff --git a/homeassistant/components/zha/api.py b/homeassistant/components/zha/api.py index ac88b7c1179..fe628d90e90 100644 --- a/homeassistant/components/zha/api.py +++ b/homeassistant/components/zha/api.py @@ -12,7 +12,6 @@ import zigpy.zdo.types as zdo_types from homeassistant.components import websocket_api from homeassistant.core import callback import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.device_registry import async_get_registry from homeassistant.helpers.dispatcher import async_dispatcher_connect from .core.const import ( @@ -53,11 +52,7 @@ from .core.const import ( WARNING_DEVICE_STROBE_HIGH, WARNING_DEVICE_STROBE_YES, ) -from .core.helpers import ( - async_get_device_info, - async_is_bindable_target, - get_matched_clusters, -) +from .core.helpers import async_is_bindable_target, get_matched_clusters _LOGGER = logging.getLogger(__name__) @@ -212,13 +207,9 @@ async def websocket_permit_devices(hass, connection, msg): async def websocket_get_devices(hass, connection, msg): """Get ZHA devices.""" zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ha_device_registry = await async_get_registry(hass) - devices = [] - for device in zha_gateway.devices.values(): - devices.append( - async_get_device_info(hass, device, ha_device_registry=ha_device_registry) - ) + devices = [device.async_get_info() for device in zha_gateway.devices.values()] + connection.send_result(msg[ID], devices) @@ -228,16 +219,13 @@ async def websocket_get_devices(hass, connection, msg): async def websocket_get_groupable_devices(hass, connection, msg): """Get ZHA devices that can be grouped.""" zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ha_device_registry = await async_get_registry(hass) - devices = [] - for device in zha_gateway.devices.values(): - if device.is_groupable: - devices.append( - async_get_device_info( - hass, device, ha_device_registry=ha_device_registry - ) - ) + devices = [ + device.async_get_info() + for device in zha_gateway.devices.values() + if device.is_groupable or device.is_coordinator + ] + connection.send_result(msg[ID], devices) @@ -246,7 +234,8 @@ async def websocket_get_groupable_devices(hass, connection, msg): @websocket_api.websocket_command({vol.Required(TYPE): "zha/groups"}) async def websocket_get_groups(hass, connection, msg): """Get ZHA groups.""" - groups = await get_groups(hass) + zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + groups = [group.async_get_info() for group in zha_gateway.groups.values()] connection.send_result(msg[ID], groups) @@ -258,13 +247,10 @@ async def websocket_get_groups(hass, connection, msg): async def websocket_get_device(hass, connection, msg): """Get ZHA devices.""" zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ha_device_registry = await async_get_registry(hass) ieee = msg[ATTR_IEEE] device = None if ieee in zha_gateway.devices: - device = async_get_device_info( - hass, zha_gateway.devices[ieee], ha_device_registry=ha_device_registry - ) + device = zha_gateway.devices[ieee].async_get_info() if not device: connection.send_message( websocket_api.error_message( @@ -283,17 +269,11 @@ async def websocket_get_device(hass, connection, msg): async def websocket_get_group(hass, connection, msg): """Get ZHA group.""" zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ha_device_registry = await async_get_registry(hass) group_id = msg[GROUP_ID] group = None - if group_id in zha_gateway.application_controller.groups: - group = async_get_group_info( - hass, - zha_gateway, - zha_gateway.application_controller.groups[group_id], - ha_device_registry, - ) + if group_id in zha_gateway.groups: + group = zha_gateway.groups.get(group_id).async_get_info() if not group: connection.send_message( websocket_api.error_message( @@ -316,28 +296,10 @@ async def websocket_get_group(hass, connection, msg): async def websocket_add_group(hass, connection, msg): """Add a new ZHA group.""" zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ha_device_registry = await async_get_registry(hass) group_name = msg[GROUP_NAME] - zigpy_group = async_get_group_by_name(zha_gateway, group_name) - ret_group = None members = msg.get(ATTR_MEMBERS) - # we start with one to fill any gaps from a user removing existing groups - group_id = 1 - while group_id in zha_gateway.application_controller.groups: - group_id += 1 - - # guard against group already existing - if zigpy_group is None: - zigpy_group = zha_gateway.application_controller.groups.add_group( - group_id, group_name - ) - if members is not None: - tasks = [] - for ieee in members: - tasks.append(zha_gateway.devices[ieee].async_add_to_group(group_id)) - await asyncio.gather(*tasks) - ret_group = async_get_group_info(hass, zha_gateway, zigpy_group, ha_device_registry) - connection.send_result(msg[ID], ret_group) + group = await zha_gateway.async_create_zigpy_group(group_name, members) + connection.send_result(msg[ID], group.async_get_info()) @websocket_api.require_admin @@ -351,17 +313,16 @@ async def websocket_add_group(hass, connection, msg): async def websocket_remove_groups(hass, connection, msg): """Remove the specified ZHA groups.""" zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - groups = zha_gateway.application_controller.groups group_ids = msg[GROUP_IDS] if len(group_ids) > 1: tasks = [] for group_id in group_ids: - tasks.append(remove_group(groups[group_id], zha_gateway)) + tasks.append(zha_gateway.async_remove_zigpy_group(group_id)) await asyncio.gather(*tasks) else: - await remove_group(groups[group_ids[0]], zha_gateway) - ret_groups = await get_groups(hass) + await zha_gateway.async_remove_zigpy_group(group_ids[0]) + ret_groups = [group.async_get_info() for group in zha_gateway.groups.values()] connection.send_result(msg[ID], ret_groups) @@ -377,25 +338,21 @@ async def websocket_remove_groups(hass, connection, msg): async def websocket_add_group_members(hass, connection, msg): """Add members to a ZHA group.""" zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ha_device_registry = await async_get_registry(hass) group_id = msg[GROUP_ID] members = msg[ATTR_MEMBERS] - zigpy_group = None + zha_group = None - if group_id in zha_gateway.application_controller.groups: - zigpy_group = zha_gateway.application_controller.groups[group_id] - tasks = [] - for ieee in members: - tasks.append(zha_gateway.devices[ieee].async_add_to_group(group_id)) - await asyncio.gather(*tasks) - if not zigpy_group: + if group_id in zha_gateway.groups: + zha_group = zha_gateway.groups.get(group_id) + await zha_group.async_add_members(members) + if not zha_group: connection.send_message( websocket_api.error_message( msg[ID], websocket_api.const.ERR_NOT_FOUND, "ZHA Group not found" ) ) return - ret_group = async_get_group_info(hass, zha_gateway, zigpy_group, ha_device_registry) + ret_group = zha_group.async_get_info() connection.send_result(msg[ID], ret_group) @@ -411,88 +368,24 @@ async def websocket_add_group_members(hass, connection, msg): async def websocket_remove_group_members(hass, connection, msg): """Remove members from a ZHA group.""" zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ha_device_registry = await async_get_registry(hass) group_id = msg[GROUP_ID] members = msg[ATTR_MEMBERS] - zigpy_group = None + zha_group = None - if group_id in zha_gateway.application_controller.groups: - zigpy_group = zha_gateway.application_controller.groups[group_id] - tasks = [] - for ieee in members: - tasks.append(zha_gateway.devices[ieee].async_remove_from_group(group_id)) - await asyncio.gather(*tasks) - if not zigpy_group: + if group_id in zha_gateway.groups: + zha_group = zha_gateway.groups.get(group_id) + await zha_group.async_remove_members(members) + if not zha_group: connection.send_message( websocket_api.error_message( msg[ID], websocket_api.const.ERR_NOT_FOUND, "ZHA Group not found" ) ) return - ret_group = async_get_group_info(hass, zha_gateway, zigpy_group, ha_device_registry) + ret_group = zha_group.async_get_info() connection.send_result(msg[ID], ret_group) -async def get_groups(hass,): - """Get ZHA Groups.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ha_device_registry = await async_get_registry(hass) - - groups = [] - for group in zha_gateway.application_controller.groups.values(): - groups.append( - async_get_group_info(hass, zha_gateway, group, ha_device_registry) - ) - return groups - - -async def remove_group(group, zha_gateway): - """Remove ZHA Group.""" - if group.members: - tasks = [] - for member_ieee in group.members.keys(): - if member_ieee[0] in zha_gateway.devices: - tasks.append( - zha_gateway.devices[member_ieee[0]].async_remove_from_group( - group.group_id - ) - ) - if tasks: - await asyncio.gather(*tasks) - else: - # we have members but none are tracked by ZHA for whatever reason - zha_gateway.application_controller.groups.pop(group.group_id) - else: - zha_gateway.application_controller.groups.pop(group.group_id) - - -@callback -def async_get_group_info(hass, zha_gateway, group, ha_device_registry): - """Get ZHA group.""" - ret_group = {} - ret_group["group_id"] = group.group_id - ret_group["name"] = group.name - ret_group["members"] = [ - async_get_device_info( - hass, - zha_gateway.get_device(member_ieee[0]), - ha_device_registry=ha_device_registry, - ) - for member_ieee in group.members.keys() - if member_ieee[0] in zha_gateway.devices - ] - return ret_group - - -@callback -def async_get_group_by_name(zha_gateway, group_name): - """Get ZHA group by name.""" - for group in zha_gateway.application_controller.groups.values(): - if group.name == group_name: - return group - return None - - @websocket_api.require_admin @websocket_api.async_response @websocket_api.websocket_command( @@ -712,9 +605,9 @@ async def websocket_get_bindable_devices(hass, connection, msg): zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] source_ieee = msg[ATTR_IEEE] source_device = zha_gateway.get_device(source_ieee) - ha_device_registry = await async_get_registry(hass) + devices = [ - async_get_device_info(hass, device, ha_device_registry=ha_device_registry) + device.async_get_info() for device in zha_gateway.devices.values() if async_is_bindable_target(source_device, device) ] diff --git a/homeassistant/components/zha/core/const.py b/homeassistant/components/zha/core/const.py index 3fbb62f8433..b8782101cd4 100644 --- a/homeassistant/components/zha/core/const.py +++ b/homeassistant/components/zha/core/const.py @@ -225,13 +225,18 @@ WARNING_DEVICE_SQUAWK_MODE_ARMED = 0 WARNING_DEVICE_SQUAWK_MODE_DISARMED = 1 ZHA_DISCOVERY_NEW = "zha_discovery_new_{}" -ZHA_GW_MSG_RAW_INIT = "raw_device_initialized" ZHA_GW_MSG = "zha_gateway_message" -ZHA_GW_MSG_DEVICE_REMOVED = "device_removed" -ZHA_GW_MSG_DEVICE_INFO = "device_info" ZHA_GW_MSG_DEVICE_FULL_INIT = "device_fully_initialized" +ZHA_GW_MSG_DEVICE_INFO = "device_info" ZHA_GW_MSG_DEVICE_JOINED = "device_joined" -ZHA_GW_MSG_LOG_OUTPUT = "log_output" +ZHA_GW_MSG_DEVICE_REMOVED = "device_removed" +ZHA_GW_MSG_GROUP_ADDED = "group_added" +ZHA_GW_MSG_GROUP_INFO = "group_info" +ZHA_GW_MSG_GROUP_MEMBER_ADDED = "group_member_added" +ZHA_GW_MSG_GROUP_MEMBER_REMOVED = "group_member_removed" +ZHA_GW_MSG_GROUP_REMOVED = "group_removed" ZHA_GW_MSG_LOG_ENTRY = "log_entry" +ZHA_GW_MSG_LOG_OUTPUT = "log_output" +ZHA_GW_MSG_RAW_INIT = "raw_device_initialized" ZHA_GW_RADIO = "radio" ZHA_GW_RADIO_DESCRIPTION = "radio_description" diff --git a/homeassistant/components/zha/core/device.py b/homeassistant/components/zha/core/device.py index 3ed44a8f2aa..8810fd77fe7 100644 --- a/homeassistant/components/zha/core/device.py +++ b/homeassistant/components/zha/core/device.py @@ -104,8 +104,18 @@ class ZHADevice(LogMixin): self._available_check = async_track_time_interval( self.hass, self._check_available, _UPDATE_ALIVE_INTERVAL ) + self._ha_device_id = None self.status = DeviceStatus.CREATED + @property + def device_id(self): + """Return the HA device registry device id.""" + return self._ha_device_id + + def set_device_id(self, device_id): + """Set the HA device registry device id.""" + self._ha_device_id = device_id + @property def name(self): """Return device name.""" @@ -406,6 +416,25 @@ class ZHADevice(LogMixin): """Set last seen on the zigpy device.""" self._zigpy_device.last_seen = last_seen + @callback + def async_get_info(self): + """Get ZHA device information.""" + device_info = {} + device_info.update(self.device_info) + device_info["entities"] = [ + { + "entity_id": entity_ref.reference_id, + ATTR_NAME: entity_ref.device_info[ATTR_NAME], + } + for entity_ref in self.gateway.device_registry[self.ieee] + ] + reg_device = self.gateway.ha_device_registry.async_get(self.device_id) + if reg_device is not None: + device_info["user_given_name"] = reg_device.name_by_user + device_info["device_reg_id"] = reg_device.id + device_info["area_id"] = reg_device.area_id + return device_info + @callback def async_get_clusters(self): """Get all clusters for this device.""" diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index 106b77d6602..9456b8e9088 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -21,6 +21,7 @@ from homeassistant.helpers.device_registry import ( async_get_registry as get_dev_reg, ) from homeassistant.helpers.dispatcher import async_dispatcher_send +from homeassistant.helpers.entity_registry import async_get_registry as get_ent_reg from .const import ( ATTR_IEEE, @@ -58,6 +59,11 @@ from .const import ( ZHA_GW_MSG_DEVICE_INFO, ZHA_GW_MSG_DEVICE_JOINED, ZHA_GW_MSG_DEVICE_REMOVED, + ZHA_GW_MSG_GROUP_ADDED, + ZHA_GW_MSG_GROUP_INFO, + ZHA_GW_MSG_GROUP_MEMBER_ADDED, + ZHA_GW_MSG_GROUP_MEMBER_REMOVED, + ZHA_GW_MSG_GROUP_REMOVED, ZHA_GW_MSG_LOG_ENTRY, ZHA_GW_MSG_LOG_OUTPUT, ZHA_GW_MSG_RAW_INIT, @@ -66,7 +72,7 @@ from .const import ( ) from .device import DeviceStatus, ZHADevice from .discovery import async_dispatch_discovery_info, async_process_endpoint -from .helpers import async_get_device_info +from .group import ZHAGroup from .patches import apply_application_controller_patch from .registries import RADIO_TYPES from .store import async_get_registry @@ -87,9 +93,11 @@ class ZHAGateway: self._hass = hass self._config = config self._devices = {} + self._groups = {} self._device_registry = collections.defaultdict(list) self.zha_storage = None self.ha_device_registry = None + self.ha_entity_registry = None self.application_controller = None self.radio_description = None hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] = self @@ -105,6 +113,7 @@ class ZHAGateway: """Initialize controller and connect radio.""" self.zha_storage = await async_get_registry(self._hass) self.ha_device_registry = await get_dev_reg(self._hass) + self.ha_entity_registry = await get_ent_reg(self._hass) usb_path = self._config_entry.data.get(CONF_USB_PATH) baudrate = self._config.get(CONF_BAUDRATE, DEFAULT_BAUDRATE) @@ -123,6 +132,7 @@ class ZHAGateway: self.application_controller = radio_details[CONTROLLER](radio, database) apply_application_controller_patch(self) self.application_controller.add_listener(self) + self.application_controller.groups.add_listener(self) await self.application_controller.startup(auto_form=True) self._hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID] = str( self.application_controller.ieee @@ -142,6 +152,8 @@ class ZHAGateway: ) await asyncio.gather(*init_tasks) + self._initialize_groups() + def device_joined(self, device): """Handle device joined. @@ -182,15 +194,53 @@ class ZHAGateway: """Handle device leaving the network.""" self.async_update_device(device, False) + def group_member_removed(self, zigpy_group, endpoint): + """Handle zigpy group member removed event.""" + # need to handle endpoint correctly on groups + zha_group = self._async_get_or_create_group(zigpy_group) + zha_group.info("group_member_removed - endpoint: %s", endpoint) + self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_REMOVED) + + def group_member_added(self, zigpy_group, endpoint): + """Handle zigpy group member added event.""" + # need to handle endpoint correctly on groups + zha_group = self._async_get_or_create_group(zigpy_group) + zha_group.info("group_member_added - endpoint: %s", endpoint) + self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_ADDED) + + def group_added(self, zigpy_group): + """Handle zigpy group added event.""" + zha_group = self._async_get_or_create_group(zigpy_group) + zha_group.info("group_added") + # need to dispatch for entity creation here + self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED) + + def group_removed(self, zigpy_group): + """Handle zigpy group added event.""" + self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED) + zha_group = self._groups.pop(zigpy_group.group_id, None) + zha_group.info("group_removed") + + def _send_group_gateway_message(self, zigpy_group, gateway_message_type): + """Send the gareway event for a zigpy group event.""" + zha_group = self._groups.get(zigpy_group.group_id, None) + if zha_group is not None: + async_dispatcher_send( + self._hass, + ZHA_GW_MSG, + { + ATTR_TYPE: gateway_message_type, + ZHA_GW_MSG_GROUP_INFO: zha_group.async_get_info(), + }, + ) + async def _async_remove_device(self, device, entity_refs): if entity_refs is not None: remove_tasks = [] for entity_ref in entity_refs: remove_tasks.append(entity_ref.remove_future) await asyncio.wait(remove_tasks) - reg_device = self.ha_device_registry.async_get_device( - {(DOMAIN, str(device.ieee))}, set() - ) + reg_device = self.ha_device_registry.async_get(device.device_id) if reg_device is not None: self.ha_device_registry.async_remove_device(reg_device.id) @@ -199,7 +249,7 @@ class ZHAGateway: zha_device = self._devices.pop(device.ieee, None) entity_refs = self._device_registry.pop(device.ieee, None) if zha_device is not None: - device_info = async_get_device_info(self._hass, zha_device) + device_info = zha_device.async_get_info() zha_device.async_unsub_dispatcher() async_dispatcher_send( self._hass, "{}_{}".format(SIGNAL_REMOVE, str(zha_device.ieee)) @@ -221,7 +271,14 @@ class ZHAGateway: def get_group(self, group_id): """Return Group for given group id.""" - return self.application_controller.groups[group_id] + return self.groups.get(group_id) + + def async_get_group_by_name(self, group_name): + """Get ZHA group by name.""" + for group in self.groups.values(): + if group.name == group_name: + return group + return None def get_entity_reference(self, entity_id): """Return entity reference for given entity_id if found.""" @@ -244,6 +301,11 @@ class ZHAGateway: """Return devices.""" return self._devices + @property + def groups(self): + """Return groups.""" + return self._groups + @property def device_registry(self): """Return entities by ieee.""" @@ -290,6 +352,12 @@ class ZHAGateway: logging.getLogger(logger_name).removeHandler(self._log_relay_handler) self.debug_enabled = False + def _initialize_groups(self): + """Initialize ZHA groups.""" + for group_id in self.application_controller.groups: + group = self.application_controller.groups[group_id] + self._async_get_or_create_group(group) + @callback def _async_get_or_create_device(self, zigpy_device): """Get or create a ZHA device.""" @@ -297,7 +365,7 @@ class ZHAGateway: if zha_device is None: zha_device = ZHADevice(self._hass, zigpy_device, self) self._devices[zigpy_device.ieee] = zha_device - self.ha_device_registry.async_get_or_create( + device_registry_device = self.ha_device_registry.async_get_or_create( config_entry_id=self._config_entry.entry_id, connections={(CONNECTION_ZIGBEE, str(zha_device.ieee))}, identifiers={(DOMAIN, str(zha_device.ieee))}, @@ -305,10 +373,20 @@ class ZHAGateway: manufacturer=zha_device.manufacturer, model=zha_device.model, ) + zha_device.set_device_id(device_registry_device.id) entry = self.zha_storage.async_get_or_create(zha_device) zha_device.async_update_last_seen(entry.last_seen) return zha_device + @callback + def _async_get_or_create_group(self, zigpy_group): + """Get or create a ZHA group.""" + zha_group = self._groups.get(zigpy_group.group_id) + if zha_group is None: + zha_group = ZHAGroup(self._hass, self, zigpy_group) + self._groups[zigpy_group.group_id] = zha_group + return zha_group + @callback def async_device_became_available( self, sender, profile, cluster, src_ep, dst_ep, message @@ -356,9 +434,8 @@ class ZHAGateway: ) await self._async_device_joined(device, zha_device) - device_info = async_get_device_info( - self._hass, zha_device, self.ha_device_registry - ) + device_info = zha_device.async_get_info() + async_dispatcher_send( self._hass, ZHA_GW_MSG, @@ -432,6 +509,38 @@ class ZHAGateway: # will cause async_init to fire so don't explicitly call it zha_device.update_available(True) + async def async_create_zigpy_group(self, name, members): + """Create a new Zigpy Zigbee group.""" + # we start with one to fill any gaps from a user removing existing groups + group_id = 1 + while group_id in self.groups: + group_id += 1 + + # guard against group already existing + if self.async_get_group_by_name(name) is None: + self.application_controller.groups.add_group(group_id, name) + if members is not None: + tasks = [] + for ieee in members: + tasks.append(self.devices[ieee].async_add_to_group(group_id)) + await asyncio.gather(*tasks) + return self.groups.get(group_id) + + async def async_remove_zigpy_group(self, group_id): + """Remove a Zigbee group from Zigpy.""" + group = self.groups.get(group_id) + if group and group.members: + tasks = [] + for member in group.members: + tasks.append(member.async_remove_from_group(group_id)) + if tasks: + await asyncio.gather(*tasks) + else: + # we have members but none are tracked by ZHA for whatever reason + self.application_controller.groups.pop(group_id) + else: + self.application_controller.groups.pop(group_id) + async def shutdown(self): """Stop ZHA Controller Application.""" _LOGGER.debug("Shutting down ZHA ControllerApplication") diff --git a/homeassistant/components/zha/core/group.py b/homeassistant/components/zha/core/group.py new file mode 100644 index 00000000000..92ce1f75360 --- /dev/null +++ b/homeassistant/components/zha/core/group.py @@ -0,0 +1,95 @@ +""" +Group for Zigbee Home Automation. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/integrations/zha/ +""" +import asyncio +import logging + +from homeassistant.core import callback + +from .helpers import LogMixin + +_LOGGER = logging.getLogger(__name__) + + +class ZHAGroup(LogMixin): + """ZHA Zigbee group object.""" + + def __init__(self, hass, zha_gateway, zigpy_group): + """Initialize the group.""" + self.hass = hass + self._zigpy_group = zigpy_group + self._zha_gateway = zha_gateway + + @property + def name(self): + """Return group name.""" + return self._zigpy_group.name + + @property + def group_id(self): + """Return group name.""" + return self._zigpy_group.group_id + + @property + def endpoint(self): + """Return the endpoint for this group.""" + return self._zigpy_group.endpoint + + @property + def members(self): + """Return the ZHA devices that are members of this group.""" + return [ + self._zha_gateway.devices.get(member_ieee[0]) + for member_ieee in self._zigpy_group.members.keys() + if member_ieee[0] in self._zha_gateway.devices + ] + + async def async_add_members(self, member_ieee_addresses): + """Add members to this group.""" + if len(member_ieee_addresses) > 1: + tasks = [] + for ieee in member_ieee_addresses: + tasks.append( + self._zha_gateway.devices[ieee].async_add_to_group(self.group_id) + ) + await asyncio.gather(*tasks) + else: + await self._zha_gateway.devices[ + member_ieee_addresses[0] + ].async_add_to_group(self.group_id) + + async def async_remove_members(self, member_ieee_addresses): + """Remove members from this group.""" + if len(member_ieee_addresses) > 1: + tasks = [] + for ieee in member_ieee_addresses: + tasks.append( + self._zha_gateway.devices[ieee].async_remove_from_group( + self.group_id + ) + ) + await asyncio.gather(*tasks) + else: + await self._zha_gateway.devices[ + member_ieee_addresses[0] + ].async_remove_from_group(self.group_id) + + @callback + def async_get_info(self): + """Get ZHA group info.""" + group_info = {} + group_info["group_id"] = self.group_id + group_info["name"] = self.name + group_info["members"] = [ + zha_device.async_get_info() for zha_device in self.members + ] + return group_info + + def log(self, level, msg, *args): + """Log a message.""" + msg = f"[%s](%s): {msg}" + args = (self.name, self.group_id) + args + _LOGGER.log(level, msg, *args) diff --git a/homeassistant/components/zha/core/helpers.py b/homeassistant/components/zha/core/helpers.py index 981a03fe7b5..e3ff446ba98 100644 --- a/homeassistant/components/zha/core/helpers.py +++ b/homeassistant/components/zha/core/helpers.py @@ -11,14 +11,7 @@ import zigpy.types from homeassistant.core import callback -from .const import ( - ATTR_NAME, - CLUSTER_TYPE_IN, - CLUSTER_TYPE_OUT, - DATA_ZHA, - DATA_ZHA_GATEWAY, - DOMAIN, -) +from .const import CLUSTER_TYPE_IN, CLUSTER_TYPE_OUT, DATA_ZHA, DATA_ZHA_GATEWAY from .registries import BINDABLE_CLUSTERS _LOGGER = logging.getLogger(__name__) @@ -131,28 +124,3 @@ class LogMixin: def error(self, msg, *args): """Error level log.""" return self.log(logging.ERROR, msg, *args) - - -@callback -def async_get_device_info(hass, device, ha_device_registry=None): - """Get ZHA device.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - ret_device = {} - ret_device.update(device.device_info) - ret_device["entities"] = [ - { - "entity_id": entity_ref.reference_id, - ATTR_NAME: entity_ref.device_info[ATTR_NAME], - } - for entity_ref in zha_gateway.device_registry[device.ieee] - ] - - if ha_device_registry is not None: - reg_device = ha_device_registry.async_get_device( - {(DOMAIN, str(device.ieee))}, set() - ) - if reg_device is not None: - ret_device["user_given_name"] = reg_device.name_by_user - ret_device["device_reg_id"] = reg_device.id - ret_device["area_id"] = reg_device.area_id - return ret_device diff --git a/tests/components/zha/conftest.py b/tests/components/zha/conftest.py index 32e602c1431..18344172d29 100644 --- a/tests/components/zha/conftest.py +++ b/tests/components/zha/conftest.py @@ -50,9 +50,10 @@ async def zha_gateway_fixture(hass, config_entry): gateway.ha_device_registry = dev_reg gateway.application_controller = mock.MagicMock(spec_set=ControllerApplication) groups = zigpy.group.Groups(gateway.application_controller) - groups.listener_event = mock.MagicMock() + groups.add_listener(gateway) groups.add_group(FIXTURE_GRP_ID, FIXTURE_GRP_NAME, suppress_event=True) - gateway.application_controller.groups = groups + gateway.application_controller.configure_mock(groups=groups) + gateway._initialize_groups() return gateway