diff --git a/homeassistant/components/zha/core/const.py b/homeassistant/components/zha/core/const.py index 2eb567ab4c4..43b1634cba7 100644 --- a/homeassistant/components/zha/core/const.py +++ b/homeassistant/components/zha/core/const.py @@ -23,7 +23,6 @@ ATTR_COMMAND_TYPE = "command_type" ATTR_DEVICE_IEEE = "device_ieee" ATTR_DEVICE_TYPE = "device_type" ATTR_ENDPOINT_ID = "endpoint_id" -ATTR_ENTITY_DOMAIN = "entity_domain" ATTR_IEEE = "ieee" ATTR_LAST_SEEN = "last_seen" ATTR_LEVEL = "level" diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index 7202fd869fa..19a83c3b6bc 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -6,6 +6,7 @@ from typing import Callable, List, Tuple from homeassistant import const as ha_const from homeassistant.core import callback +from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.entity_registry import async_entries_for_device from homeassistant.helpers.typing import HomeAssistantType @@ -182,59 +183,48 @@ class GroupProbe: ) return - if group.entity_domain is None: - _LOGGER.debug( - "Group: %s:0x%04x has no user set entity domain - attempting entity domain discovery", - group.name, - group.group_id, - ) - group.entity_domain = GroupProbe.determine_default_entity_domain( - self._hass, group - ) + entity_domains = GroupProbe.determine_entity_domains(self._hass, group) - if group.entity_domain is None: + if not entity_domains: return - _LOGGER.debug( - "Group: %s:0x%04x has an entity domain of: %s after discovery", - group.name, - group.group_id, - group.entity_domain, - ) - zha_gateway = self._hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] - entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(group.entity_domain) - if entity_class is None: - return - - self._hass.data[zha_const.DATA_ZHA][group.entity_domain].append( - ( - entity_class, + for domain in entity_domains: + entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(domain) + if entity_class is None: + continue + self._hass.data[zha_const.DATA_ZHA][domain].append( ( - group.domain_entity_ids, - f"{group.entity_domain}_group_{group.group_id}", - group.group_id, - zha_gateway.coordinator_zha_device, - ), + entity_class, + ( + group.get_domain_entity_ids(domain), + f"{domain}_group_{group.group_id}", + group.group_id, + zha_gateway.coordinator_zha_device, + ), + ) ) - ) + async_dispatcher_send(self._hass, zha_const.SIGNAL_ADD_ENTITIES) @staticmethod - def determine_default_entity_domain( + def determine_entity_domains( hass: HomeAssistantType, group: zha_typing.ZhaGroupType - ): - """Determine the default entity domain for this group.""" + ) -> List[str]: + """Determine the entity domains for this group.""" + entity_domains: List[str] = [] if len(group.members) < 2: _LOGGER.debug( "Group: %s:0x%04x has less than 2 members so cannot default an entity domain", group.name, group.group_id, ) - return None + return entity_domains zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] all_domain_occurrences = [] for device in group.members: + if device.is_coordinator: + continue entities = async_entries_for_device( zha_gateway.ha_entity_registry, device.device_id ) @@ -245,15 +235,18 @@ class GroupProbe: if entity.domain in zha_regs.GROUP_ENTITY_DOMAINS ] ) + if not all_domain_occurrences: + return entity_domains + # get all domains we care about if there are more than 2 entities of this domain counts = Counter(all_domain_occurrences) - domain = counts.most_common(1)[0][0] + entity_domains = [domain[0] for domain in counts.items() if domain[1] >= 2] _LOGGER.debug( - "The default entity domain is: %s for group: %s:0x%04x", - domain, + "The entity domains are: %s for group: %s:0x%04x", + entity_domains, group.name, group.group_id, ) - return domain + return entity_domains PROBE = ProbeEndpoint() diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index 9d5bf609ed2..bc7ff42d25f 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -445,8 +445,6 @@ class ZHAGateway: if zha_group is None: zha_group = ZHAGroup(self._hass, self, zigpy_group) self._groups[zigpy_group.group_id] = zha_group - group_entry = self.zha_storage.async_get_or_create_group(zha_group) - zha_group.entity_domain = group_entry.entity_domain return zha_group @callback @@ -469,8 +467,6 @@ class ZHAGateway: """Update the devices in the store.""" for device in self.devices.values(): self.zha_storage.async_update_device(device) - for group in self.groups.values(): - self.zha_storage.async_update_group(group) await self.zha_storage.async_save() async def async_device_initialized(self, device: zha_typing.ZigpyDeviceType): @@ -559,9 +555,7 @@ class ZHAGateway: zha_group.group_id, ) discovery.GROUP_PROBE.discover_group_entities(zha_group) - if zha_group.entity_domain is not None: - self.zha_storage.async_update_group(zha_group) - async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES) + return zha_group async def async_remove_zigpy_group(self, group_id: int) -> None: @@ -577,7 +571,6 @@ class ZHAGateway: if tasks: await asyncio.gather(*tasks) self.application_controller.groups.pop(group_id) - self.zha_storage.async_delete_group(group) async def shutdown(self): """Stop ZHA Controller Application.""" diff --git a/homeassistant/components/zha/core/group.py b/homeassistant/components/zha/core/group.py index e6b2dee0625..4fc86012d1a 100644 --- a/homeassistant/components/zha/core/group.py +++ b/homeassistant/components/zha/core/group.py @@ -1,7 +1,7 @@ """Group for Zigbee Home Automation.""" import asyncio import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from zigpy.types.named import EUI64 @@ -28,7 +28,6 @@ class ZHAGroup(LogMixin): self.hass: HomeAssistantType = hass self._zigpy_group: ZigpyGroupType = zigpy_group self._zha_gateway: ZhaGatewayType = zha_gateway - self._entity_domain: str = None @property def name(self) -> str: @@ -45,16 +44,6 @@ class ZHAGroup(LogMixin): """Return the endpoint for this group.""" return self._zigpy_group.endpoint - @property - def entity_domain(self) -> Optional[str]: - """Return the domain that will be used for the entity representing this group.""" - return self._entity_domain - - @entity_domain.setter - def entity_domain(self, domain: Optional[str]) -> None: - """Set the domain that will be used for the entity representing this group.""" - self._entity_domain = domain - @property def members(self) -> List[ZhaDeviceType]: """Return the ZHA devices that are members of this group.""" @@ -106,22 +95,15 @@ class ZHAGroup(LogMixin): all_entity_ids.append(entity.entity_id) return all_entity_ids - @property - def domain_entity_ids(self) -> List[str]: + def get_domain_entity_ids(self, domain) -> List[str]: """Return entity ids from the entity domain for this group.""" - if self.entity_domain is None: - return domain_entity_ids: List[str] = [] for device in self.members: entities = async_entries_for_device( self._zha_gateway.ha_entity_registry, device.device_id ) domain_entity_ids.extend( - [ - entity.entity_id - for entity in entities - if entity.domain == self.entity_domain - ] + [entity.entity_id for entity in entities if entity.domain == domain] ) return domain_entity_ids @@ -130,7 +112,6 @@ class ZHAGroup(LogMixin): """Get ZHA group info.""" group_info: Dict[str, Any] = {} group_info["group_id"] = self.group_id - group_info["entity_domain"] = self.entity_domain group_info["name"] = self.name group_info["members"] = [ zha_device.async_get_info() for zha_device in self.members diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index b596eefb71a..29b71343245 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -32,7 +32,7 @@ from .const import CONTROLLER, ZHA_GW_RADIO, ZHA_GW_RADIO_DESCRIPTION, RadioType from .decorators import CALLABLE_T, DictRegistry, SetRegistry from .typing import ChannelType -GROUP_ENTITY_DOMAINS = [LIGHT, SWITCH] +GROUP_ENTITY_DOMAINS = [LIGHT, SWITCH, FAN] SMARTTHINGS_ACCELERATION_CLUSTER = 0xFC02 SMARTTHINGS_ARRIVAL_SENSOR_DEVICE_TYPE = 0x8000 diff --git a/homeassistant/components/zha/core/store.py b/homeassistant/components/zha/core/store.py index 0cd9e045cb6..00a4942c7b7 100644 --- a/homeassistant/components/zha/core/store.py +++ b/homeassistant/components/zha/core/store.py @@ -10,7 +10,7 @@ from homeassistant.core import callback from homeassistant.helpers.typing import HomeAssistantType from homeassistant.loader import bind_hass -from .typing import ZhaDeviceType, ZhaGroupType +from .typing import ZhaDeviceType _LOGGER = logging.getLogger(__name__) @@ -30,15 +30,6 @@ class ZhaDeviceEntry: last_seen = attr.ib(type=float, default=None) -@attr.s(slots=True, frozen=True) -class ZhaGroupEntry: - """Zha Group storage Entry.""" - - name = attr.ib(type=str, default=None) - group_id = attr.ib(type=int, default=None) - entity_domain = attr.ib(type=float, default=None) - - class ZhaStorage: """Class to hold a registry of zha devices.""" @@ -46,7 +37,6 @@ class ZhaStorage: """Initialize the zha device storage.""" self.hass: HomeAssistantType = hass self.devices: MutableMapping[str, ZhaDeviceEntry] = {} - self.groups: MutableMapping[str, ZhaGroupEntry] = {} self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) @callback @@ -59,17 +49,6 @@ class ZhaStorage: return self.async_update_device(device) - @callback - def async_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry: - """Create a new ZhaGroupEntry.""" - group_entry: ZhaGroupEntry = ZhaGroupEntry( - name=group.name, - group_id=str(group.group_id), - entity_domain=group.entity_domain, - ) - self.groups[str(group.group_id)] = group_entry - return self.async_update_group(group) - @callback def async_get_or_create_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry: """Create a new ZhaDeviceEntry.""" @@ -78,14 +57,6 @@ class ZhaStorage: return self.devices[ieee_str] return self.async_create_device(device) - @callback - def async_get_or_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry: - """Create a new ZhaGroupEntry.""" - group_id: str = str(group.group_id) - if group_id in self.groups: - return self.groups[group_id] - return self.async_create_group(group) - @callback def async_create_or_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry: """Create or update a ZhaDeviceEntry.""" @@ -93,13 +64,6 @@ class ZhaStorage: return self.async_update_device(device) return self.async_create_device(device) - @callback - def async_create_or_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry: - """Create or update a ZhaGroupEntry.""" - if str(group.group_id) in self.groups: - return self.async_update_group(group) - return self.async_create_group(group) - @callback def async_delete_device(self, device: ZhaDeviceType) -> None: """Delete ZhaDeviceEntry.""" @@ -108,14 +72,6 @@ class ZhaStorage: del self.devices[ieee_str] self.async_schedule_save() - @callback - def async_delete_group(self, group: ZhaGroupType) -> None: - """Delete ZhaGroupEntry.""" - group_id: str = str(group.group_id) - if group_id in self.groups: - del self.groups[group_id] - self.async_schedule_save() - @callback def async_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry: """Update name of ZhaDeviceEntry.""" @@ -129,25 +85,11 @@ class ZhaStorage: self.async_schedule_save() return new - @callback - def async_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry: - """Update name of ZhaGroupEntry.""" - group_id: str = str(group.group_id) - old = self.groups[group_id] - - changes = {} - changes["entity_domain"] = group.entity_domain - - new = self.groups[group_id] = attr.evolve(old, **changes) - self.async_schedule_save() - return new - async def async_load(self) -> None: """Load the registry of zha device entries.""" data = await self._store.async_load() devices: "OrderedDict[str, ZhaDeviceEntry]" = OrderedDict() - groups: "OrderedDict[str, ZhaGroupEntry]" = OrderedDict() if data is not None: for device in data["devices"]: @@ -157,18 +99,7 @@ class ZhaStorage: last_seen=device["last_seen"] if "last_seen" in device else None, ) - if "groups" in data: - for group in data["groups"]: - groups[group["group_id"]] = ZhaGroupEntry( - name=group["name"], - group_id=group["group_id"], - entity_domain=group["entity_domain"] - if "entity_domain" in group - else None, - ) - self.devices = devices - self.groups = groups @callback def async_schedule_save(self) -> None: @@ -189,14 +120,6 @@ class ZhaStorage: for entry in self.devices.values() ] - data["groups"] = [ - { - "name": entry.name, - "group_id": entry.group_id, - "entity_domain": entry.entity_domain, - } - for entry in self.groups.values() - ] return data diff --git a/homeassistant/components/zha/fan.py b/homeassistant/components/zha/fan.py index d04453cd675..027d0f8a1ee 100644 --- a/homeassistant/components/zha/fan.py +++ b/homeassistant/components/zha/fan.py @@ -1,6 +1,10 @@ """Fans on Zigbee Home Automation networks.""" import functools import logging +from typing import List, Optional + +from zigpy.exceptions import DeliveryError +import zigpy.zcl.clusters.hvac as hvac from homeassistant.components.fan import ( DOMAIN, @@ -11,8 +15,10 @@ from homeassistant.components.fan import ( SUPPORT_SET_SPEED, FanEntity, ) -from homeassistant.core import callback +from homeassistant.const import STATE_UNAVAILABLE +from homeassistant.core import CALLBACK_TYPE, State, callback from homeassistant.helpers.dispatcher import async_dispatcher_connect +from homeassistant.helpers.event import async_track_state_change from .core import discovery from .core.const import ( @@ -21,9 +27,10 @@ from .core.const import ( DATA_ZHA_DISPATCHERS, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, + SIGNAL_REMOVE_GROUP, ) from .core.registries import ZHA_ENTITIES -from .entity import ZhaEntity +from .entity import BaseZhaEntity, ZhaEntity _LOGGER = logging.getLogger(__name__) @@ -49,6 +56,7 @@ SPEED_LIST = [ VALUE_TO_SPEED = dict(enumerate(SPEED_LIST)) SPEED_TO_VALUE = {speed: i for i, speed in enumerate(SPEED_LIST)} STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN) +GROUP_MATCH = functools.partial(ZHA_ENTITIES.group_match, DOMAIN) async def async_setup_entry(hass, config_entry, async_add_entities): @@ -65,31 +73,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) -@STRICT_MATCH(channel_names=CHANNEL_FAN) -class ZhaFan(ZhaEntity, FanEntity): - """Representation of a ZHA fan.""" +class BaseFan(BaseZhaEntity, FanEntity): + """Base representation of a ZHA fan.""" - def __init__(self, unique_id, zha_device, channels, **kwargs): - """Init this sensor.""" - super().__init__(unique_id, zha_device, channels, **kwargs) - self._fan_channel = self.cluster_channels.get(CHANNEL_FAN) - - async def async_added_to_hass(self): - """Run when about to be added to hass.""" - await super().async_added_to_hass() - await self.async_accept_signal( - self._fan_channel, SIGNAL_ATTR_UPDATED, self.async_set_state - ) - - @callback - def async_restore_last_state(self, last_state): - """Restore previous state.""" - self._state = VALUE_TO_SPEED.get(last_state.state, last_state.state) - - @property - def supported_features(self) -> int: - """Flag supported features.""" - return SUPPORT_SET_SPEED + def __init__(self, *args, **kwargs): + """Initialize the fan.""" + super().__init__(*args, **kwargs) + self._state = None + self._fan_channel = None @property def speed_list(self) -> list: @@ -109,15 +100,9 @@ class ZhaFan(ZhaEntity, FanEntity): return self._state != SPEED_OFF @property - def device_state_attributes(self): - """Return state attributes.""" - return self.state_attributes - - @callback - def async_set_state(self, attr_id, attr_name, value): - """Handle state update from channel.""" - self._state = VALUE_TO_SPEED.get(value, self._state) - self.async_write_ha_state() + def supported_features(self) -> int: + """Flag supported features.""" + return SUPPORT_SET_SPEED async def async_turn_on(self, speed: str = None, **kwargs) -> None: """Turn the entity on.""" @@ -135,6 +120,34 @@ class ZhaFan(ZhaEntity, FanEntity): await self._fan_channel.async_set_speed(SPEED_TO_VALUE[speed]) self.async_set_state(0, "fan_mode", speed) + +@STRICT_MATCH(channel_names=CHANNEL_FAN) +class ZhaFan(ZhaEntity, BaseFan): + """Representation of a ZHA fan.""" + + def __init__(self, unique_id, zha_device, channels, **kwargs): + """Init this sensor.""" + super().__init__(unique_id, zha_device, channels, **kwargs) + self._fan_channel = self.cluster_channels.get(CHANNEL_FAN) + + async def async_added_to_hass(self): + """Run when about to be added to hass.""" + await super().async_added_to_hass() + await self.async_accept_signal( + self._fan_channel, SIGNAL_ATTR_UPDATED, self.async_set_state + ) + + @callback + def async_restore_last_state(self, last_state): + """Restore previous state.""" + self._state = VALUE_TO_SPEED.get(last_state.state, last_state.state) + + @callback + def async_set_state(self, attr_id, attr_name, value): + """Handle state update from channel.""" + self._state = VALUE_TO_SPEED.get(value, self._state) + self.async_write_ha_state() + async def async_update(self): """Attempt to retrieve on off state from the fan.""" await super().async_update() @@ -142,3 +155,73 @@ class ZhaFan(ZhaEntity, FanEntity): state = await self._fan_channel.get_attribute_value("fan_mode") if state is not None: self._state = VALUE_TO_SPEED.get(state, self._state) + + +@GROUP_MATCH() +class FanGroup(BaseFan): + """Representation of a fan group.""" + + def __init__( + self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs + ) -> None: + """Initialize a fan group.""" + super().__init__(unique_id, zha_device, **kwargs) + self._name: str = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}" + self._group_id: int = group_id + self._available: bool = False + self._entity_ids: List[str] = entity_ids + self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None + group = self.zha_device.gateway.get_group(self._group_id) + self._fan_channel = group.endpoint[hvac.Fan.cluster_id] + + # what should we do with this hack? + async def async_set_speed(value) -> None: + """Set the speed of the fan.""" + try: + await self._fan_channel.write_attributes({"fan_mode": value}) + except DeliveryError as ex: + self.error("Could not set speed: %s", ex) + return + + self._fan_channel.async_set_speed = async_set_speed + + async def async_added_to_hass(self) -> None: + """Register callbacks.""" + await super().async_added_to_hass() + await self.async_accept_signal( + None, + f"{SIGNAL_REMOVE_GROUP}_{self._group_id}", + self.async_remove, + signal_override=True, + ) + + @callback + def async_state_changed_listener( + entity_id: str, old_state: State, new_state: State + ): + """Handle child updates.""" + self.async_schedule_update_ha_state(True) + + self._async_unsub_state_changed = async_track_state_change( + self.hass, self._entity_ids, async_state_changed_listener + ) + await self.async_update() + + async def async_will_remove_from_hass(self) -> None: + """Handle removal from Home Assistant.""" + await super().async_will_remove_from_hass() + if self._async_unsub_state_changed is not None: + self._async_unsub_state_changed() + self._async_unsub_state_changed = None + + async def async_update(self): + """Attempt to retrieve on off state from the fan.""" + all_states = [self.hass.states.get(x) for x in self._entity_ids] + states: List[State] = list(filter(None, all_states)) + on_states: List[State] = [state for state in states if state.state != SPEED_OFF] + self._available = any(state.state != STATE_UNAVAILABLE for state in states) + # for now just use first non off state since its kind of arbitrary + if not on_states: + self._state = SPEED_OFF + else: + self._state = states[0].state diff --git a/tests/components/zha/test_fan.py b/tests/components/zha/test_fan.py index 5011a847a4e..399982df37a 100644 --- a/tests/components/zha/test_fan.py +++ b/tests/components/zha/test_fan.py @@ -2,10 +2,21 @@ from unittest.mock import call import pytest +import zigpy.profiles.zha as zha +import zigpy.zcl.clusters.general as general import zigpy.zcl.clusters.hvac as hvac from homeassistant.components import fan -from homeassistant.components.fan import ATTR_SPEED, DOMAIN, SERVICE_SET_SPEED +from homeassistant.components.fan import ( + ATTR_SPEED, + DOMAIN, + SERVICE_SET_SPEED, + SPEED_HIGH, + SPEED_MEDIUM, + SPEED_OFF, +) +from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN +from homeassistant.components.zha.core.discovery import GROUP_PROBE from homeassistant.const import ( ATTR_ENTITY_ID, SERVICE_TURN_OFF, @@ -17,11 +28,16 @@ from homeassistant.const import ( from .common import ( async_enable_traffic, + async_find_group_entity_id, async_test_rejoin, find_entity_id, + get_zha_gateway, send_attributes_report, ) +IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8" +IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8" + @pytest.fixture def zigpy_device(zigpy_device_mock): @@ -32,6 +48,66 @@ def zigpy_device(zigpy_device_mock): return zigpy_device_mock(endpoints) +@pytest.fixture +async def coordinator(hass, zigpy_device_mock, zha_device_joined): + """Test zha fan platform.""" + + zigpy_device = zigpy_device_mock( + { + 1: { + "in_clusters": [], + "out_clusters": [], + "device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT, + } + }, + ieee="00:15:8d:00:02:32:4f:32", + nwk=0x0000, + ) + zha_device = await zha_device_joined(zigpy_device) + zha_device.set_available(True) + return zha_device + + +@pytest.fixture +async def device_fan_1(hass, zigpy_device_mock, zha_device_joined): + """Test zha fan platform.""" + + zigpy_device = zigpy_device_mock( + { + 1: { + "in_clusters": [general.OnOff.cluster_id, hvac.Fan.cluster_id], + "out_clusters": [], + } + }, + ieee=IEEE_GROUPABLE_DEVICE, + ) + zha_device = await zha_device_joined(zigpy_device) + zha_device.set_available(True) + return zha_device + + +@pytest.fixture +async def device_fan_2(hass, zigpy_device_mock, zha_device_joined): + """Test zha fan platform.""" + + zigpy_device = zigpy_device_mock( + { + 1: { + "in_clusters": [ + general.OnOff.cluster_id, + hvac.Fan.cluster_id, + general.LevelControl.cluster_id, + ], + "out_clusters": [], + } + }, + ieee=IEEE_GROUPABLE_DEVICE2, + ) + zha_device = await zha_device_joined(zigpy_device) + zha_device.set_available(True) + return zha_device + + async def test_fan(hass, zha_device_joined_restored, zigpy_device): """Test zha fan platform.""" @@ -106,3 +182,87 @@ async def async_set_speed(hass, entity_id, speed=None): } await hass.services.async_call(DOMAIN, SERVICE_SET_SPEED, data, blocking=True) + + +async def async_test_zha_group_fan_entity( + hass, device_fan_1, device_fan_2, coordinator +): + """Test the fan entity for a ZHA group.""" + zha_gateway = get_zha_gateway(hass) + assert zha_gateway is not None + zha_gateway.coordinator_zha_device = coordinator + coordinator._zha_gateway = zha_gateway + device_fan_1._zha_gateway = zha_gateway + device_fan_2._zha_gateway = zha_gateway + member_ieee_addresses = [device_fan_1.ieee, device_fan_2.ieee] + + # test creating a group with 2 members + zha_group = await zha_gateway.async_create_zigpy_group( + "Test Group", member_ieee_addresses + ) + await hass.async_block_till_done() + + assert zha_group is not None + assert len(zha_group.members) == 2 + for member in zha_group.members: + assert member.ieee in member_ieee_addresses + + entity_domains = GROUP_PROBE.determine_entity_domains(zha_group) + assert len(entity_domains) == 2 + + assert LIGHT_DOMAIN in entity_domains + assert DOMAIN in entity_domains + + entity_id = async_find_group_entity_id(hass, DOMAIN, zha_group) + assert hass.states.get(entity_id) is not None + + group_fan_cluster = zha_group.endpoint[hvac.Fan.cluster_id] + dev1_fan_cluster = device_fan_1.endpoints[1].fan + dev2_fan_cluster = device_fan_2.endpoints[1].fan + + # test that the lights were created and that they are unavailable + assert hass.states.get(entity_id).state == STATE_UNAVAILABLE + + # allow traffic to flow through the gateway and device + await async_enable_traffic(hass, zha_group.members) + + # test that the fan group entity was created and is off + assert hass.states.get(entity_id).state == STATE_OFF + + # turn on from HA + group_fan_cluster.write_attributes.reset_mock() + await async_turn_on(hass, entity_id) + assert len(group_fan_cluster.write_attributes.mock_calls) == 1 + assert group_fan_cluster.write_attributes.call_args == call({"fan_mode": 2}) + assert hass.states.get(entity_id).state == SPEED_MEDIUM + + # turn off from HA + group_fan_cluster.write_attributes.reset_mock() + await async_turn_off(hass, entity_id) + assert len(group_fan_cluster.write_attributes.mock_calls) == 1 + assert group_fan_cluster.write_attributes.call_args == call({"fan_mode": 0}) + assert hass.states.get(entity_id).state == STATE_OFF + + # change speed from HA + group_fan_cluster.write_attributes.reset_mock() + await async_set_speed(hass, entity_id, speed=fan.SPEED_HIGH) + assert len(group_fan_cluster.write_attributes.mock_calls) == 1 + assert group_fan_cluster.write_attributes.call_args == call({"fan_mode": 3}) + assert hass.states.get(entity_id).state == SPEED_HIGH + + # test some of the group logic to make sure we key off states correctly + await dev1_fan_cluster.async_set_speed(SPEED_OFF) + await dev2_fan_cluster.async_set_speed(SPEED_OFF) + + # test that group fan is off + assert hass.states.get(entity_id).state == STATE_OFF + + await dev1_fan_cluster.async_set_speed(SPEED_MEDIUM) + + # test that group fan is speed medium + assert hass.states.get(entity_id).state == SPEED_MEDIUM + + await dev1_fan_cluster.async_set_speed(SPEED_OFF) + + # test that group fan is now off + assert hass.states.get(entity_id).state == STATE_OFF diff --git a/tests/components/zha/test_gateway.py b/tests/components/zha/test_gateway.py index 3bb98522814..80d96fa55bd 100644 --- a/tests/components/zha/test_gateway.py +++ b/tests/components/zha/test_gateway.py @@ -134,7 +134,6 @@ async def test_gateway_group_methods(hass, device_light_1, device_light_2, coord await hass.async_block_till_done() assert zha_group is not None - assert zha_group.entity_domain == LIGHT_DOMAIN assert len(zha_group.members) == 2 for member in zha_group.members: assert member.ieee in member_ieee_addresses @@ -162,7 +161,6 @@ async def test_gateway_group_methods(hass, device_light_1, device_light_2, coord await hass.async_block_till_done() assert zha_group is not None - assert zha_group.entity_domain is None assert len(zha_group.members) == 1 for member in zha_group.members: assert member.ieee in [device_light_1.ieee] diff --git a/tests/components/zha/test_light.py b/tests/components/zha/test_light.py index c6bafa45aea..f832b9e86e0 100644 --- a/tests/components/zha/test_light.py +++ b/tests/components/zha/test_light.py @@ -432,7 +432,6 @@ async def async_test_zha_group_light_entity( await hass.async_block_till_done() assert zha_group is not None - assert zha_group.entity_domain == DOMAIN assert len(zha_group.members) == 2 for member in zha_group.members: assert member.ieee in member_ieee_addresses diff --git a/tests/components/zha/test_switch.py b/tests/components/zha/test_switch.py index adaaa7c2a2f..ed5d228ab88 100644 --- a/tests/components/zha/test_switch.py +++ b/tests/components/zha/test_switch.py @@ -173,7 +173,6 @@ async def async_test_zha_group_switch_entity( await hass.async_block_till_done() assert zha_group is not None - assert zha_group.entity_domain == DOMAIN assert len(zha_group.members) == 2 for member in zha_group.members: assert member.ieee in member_ieee_addresses