Add entities for ZHA fan groups (#33291)

* start of fan groups

* update fan classes

* update group entity domains

* add set speed

* update discovery for multiple entities for groups

* add fan group entity tests

* cleanup const

* cleanup entity_domain usage

* remove bad super call

* remove bad update line

* fix set speed on fan group

* change comparison

* pythonic list

* discovery guards

* Update homeassistant/components/zha/core/discovery.py

Co-Authored-By: Alexei Chetroi <lexoid@gmail.com>

Co-authored-by: Alexei Chetroi <lexoid@gmail.com>
This commit is contained in:
David F. Mulcahey 2020-03-26 22:19:48 -04:00 committed by GitHub
parent c89975adf6
commit 4f767dd3ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 316 additions and 188 deletions

View File

@ -23,7 +23,6 @@ ATTR_COMMAND_TYPE = "command_type"
ATTR_DEVICE_IEEE = "device_ieee" ATTR_DEVICE_IEEE = "device_ieee"
ATTR_DEVICE_TYPE = "device_type" ATTR_DEVICE_TYPE = "device_type"
ATTR_ENDPOINT_ID = "endpoint_id" ATTR_ENDPOINT_ID = "endpoint_id"
ATTR_ENTITY_DOMAIN = "entity_domain"
ATTR_IEEE = "ieee" ATTR_IEEE = "ieee"
ATTR_LAST_SEEN = "last_seen" ATTR_LAST_SEEN = "last_seen"
ATTR_LEVEL = "level" ATTR_LEVEL = "level"

View File

@ -6,6 +6,7 @@ from typing import Callable, List, Tuple
from homeassistant import const as ha_const from homeassistant import const as ha_const
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity_registry import async_entries_for_device from homeassistant.helpers.entity_registry import async_entries_for_device
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
@ -182,59 +183,48 @@ class GroupProbe:
) )
return return
if group.entity_domain is None: entity_domains = GroupProbe.determine_entity_domains(self._hass, group)
_LOGGER.debug(
"Group: %s:0x%04x has no user set entity domain - attempting entity domain discovery",
group.name,
group.group_id,
)
group.entity_domain = GroupProbe.determine_default_entity_domain(
self._hass, group
)
if group.entity_domain is None: if not entity_domains:
return return
_LOGGER.debug(
"Group: %s:0x%04x has an entity domain of: %s after discovery",
group.name,
group.group_id,
group.entity_domain,
)
zha_gateway = self._hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] zha_gateway = self._hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(group.entity_domain) for domain in entity_domains:
entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(domain)
if entity_class is None: if entity_class is None:
return continue
self._hass.data[zha_const.DATA_ZHA][domain].append(
self._hass.data[zha_const.DATA_ZHA][group.entity_domain].append(
( (
entity_class, entity_class,
( (
group.domain_entity_ids, group.get_domain_entity_ids(domain),
f"{group.entity_domain}_group_{group.group_id}", f"{domain}_group_{group.group_id}",
group.group_id, group.group_id,
zha_gateway.coordinator_zha_device, zha_gateway.coordinator_zha_device,
), ),
) )
) )
async_dispatcher_send(self._hass, zha_const.SIGNAL_ADD_ENTITIES)
@staticmethod @staticmethod
def determine_default_entity_domain( def determine_entity_domains(
hass: HomeAssistantType, group: zha_typing.ZhaGroupType hass: HomeAssistantType, group: zha_typing.ZhaGroupType
): ) -> List[str]:
"""Determine the default entity domain for this group.""" """Determine the entity domains for this group."""
entity_domains: List[str] = []
if len(group.members) < 2: if len(group.members) < 2:
_LOGGER.debug( _LOGGER.debug(
"Group: %s:0x%04x has less than 2 members so cannot default an entity domain", "Group: %s:0x%04x has less than 2 members so cannot default an entity domain",
group.name, group.name,
group.group_id, group.group_id,
) )
return None return entity_domains
zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
all_domain_occurrences = [] all_domain_occurrences = []
for device in group.members: for device in group.members:
if device.is_coordinator:
continue
entities = async_entries_for_device( entities = async_entries_for_device(
zha_gateway.ha_entity_registry, device.device_id zha_gateway.ha_entity_registry, device.device_id
) )
@ -245,15 +235,18 @@ class GroupProbe:
if entity.domain in zha_regs.GROUP_ENTITY_DOMAINS if entity.domain in zha_regs.GROUP_ENTITY_DOMAINS
] ]
) )
if not all_domain_occurrences:
return entity_domains
# get all domains we care about if there are more than 2 entities of this domain
counts = Counter(all_domain_occurrences) counts = Counter(all_domain_occurrences)
domain = counts.most_common(1)[0][0] entity_domains = [domain[0] for domain in counts.items() if domain[1] >= 2]
_LOGGER.debug( _LOGGER.debug(
"The default entity domain is: %s for group: %s:0x%04x", "The entity domains are: %s for group: %s:0x%04x",
domain, entity_domains,
group.name, group.name,
group.group_id, group.group_id,
) )
return domain return entity_domains
PROBE = ProbeEndpoint() PROBE = ProbeEndpoint()

View File

@ -445,8 +445,6 @@ class ZHAGateway:
if zha_group is None: if zha_group is None:
zha_group = ZHAGroup(self._hass, self, zigpy_group) zha_group = ZHAGroup(self._hass, self, zigpy_group)
self._groups[zigpy_group.group_id] = zha_group self._groups[zigpy_group.group_id] = zha_group
group_entry = self.zha_storage.async_get_or_create_group(zha_group)
zha_group.entity_domain = group_entry.entity_domain
return zha_group return zha_group
@callback @callback
@ -469,8 +467,6 @@ class ZHAGateway:
"""Update the devices in the store.""" """Update the devices in the store."""
for device in self.devices.values(): for device in self.devices.values():
self.zha_storage.async_update_device(device) self.zha_storage.async_update_device(device)
for group in self.groups.values():
self.zha_storage.async_update_group(group)
await self.zha_storage.async_save() await self.zha_storage.async_save()
async def async_device_initialized(self, device: zha_typing.ZigpyDeviceType): async def async_device_initialized(self, device: zha_typing.ZigpyDeviceType):
@ -559,9 +555,7 @@ class ZHAGateway:
zha_group.group_id, zha_group.group_id,
) )
discovery.GROUP_PROBE.discover_group_entities(zha_group) discovery.GROUP_PROBE.discover_group_entities(zha_group)
if zha_group.entity_domain is not None:
self.zha_storage.async_update_group(zha_group)
async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES)
return zha_group return zha_group
async def async_remove_zigpy_group(self, group_id: int) -> None: async def async_remove_zigpy_group(self, group_id: int) -> None:
@ -577,7 +571,6 @@ class ZHAGateway:
if tasks: if tasks:
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
self.application_controller.groups.pop(group_id) self.application_controller.groups.pop(group_id)
self.zha_storage.async_delete_group(group)
async def shutdown(self): async def shutdown(self):
"""Stop ZHA Controller Application.""" """Stop ZHA Controller Application."""

View File

@ -1,7 +1,7 @@
"""Group for Zigbee Home Automation.""" """Group for Zigbee Home Automation."""
import asyncio import asyncio
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List
from zigpy.types.named import EUI64 from zigpy.types.named import EUI64
@ -28,7 +28,6 @@ class ZHAGroup(LogMixin):
self.hass: HomeAssistantType = hass self.hass: HomeAssistantType = hass
self._zigpy_group: ZigpyGroupType = zigpy_group self._zigpy_group: ZigpyGroupType = zigpy_group
self._zha_gateway: ZhaGatewayType = zha_gateway self._zha_gateway: ZhaGatewayType = zha_gateway
self._entity_domain: str = None
@property @property
def name(self) -> str: def name(self) -> str:
@ -45,16 +44,6 @@ class ZHAGroup(LogMixin):
"""Return the endpoint for this group.""" """Return the endpoint for this group."""
return self._zigpy_group.endpoint return self._zigpy_group.endpoint
@property
def entity_domain(self) -> Optional[str]:
"""Return the domain that will be used for the entity representing this group."""
return self._entity_domain
@entity_domain.setter
def entity_domain(self, domain: Optional[str]) -> None:
"""Set the domain that will be used for the entity representing this group."""
self._entity_domain = domain
@property @property
def members(self) -> List[ZhaDeviceType]: def members(self) -> List[ZhaDeviceType]:
"""Return the ZHA devices that are members of this group.""" """Return the ZHA devices that are members of this group."""
@ -106,22 +95,15 @@ class ZHAGroup(LogMixin):
all_entity_ids.append(entity.entity_id) all_entity_ids.append(entity.entity_id)
return all_entity_ids return all_entity_ids
@property def get_domain_entity_ids(self, domain) -> List[str]:
def domain_entity_ids(self) -> List[str]:
"""Return entity ids from the entity domain for this group.""" """Return entity ids from the entity domain for this group."""
if self.entity_domain is None:
return
domain_entity_ids: List[str] = [] domain_entity_ids: List[str] = []
for device in self.members: for device in self.members:
entities = async_entries_for_device( entities = async_entries_for_device(
self._zha_gateway.ha_entity_registry, device.device_id self._zha_gateway.ha_entity_registry, device.device_id
) )
domain_entity_ids.extend( domain_entity_ids.extend(
[ [entity.entity_id for entity in entities if entity.domain == domain]
entity.entity_id
for entity in entities
if entity.domain == self.entity_domain
]
) )
return domain_entity_ids return domain_entity_ids
@ -130,7 +112,6 @@ class ZHAGroup(LogMixin):
"""Get ZHA group info.""" """Get ZHA group info."""
group_info: Dict[str, Any] = {} group_info: Dict[str, Any] = {}
group_info["group_id"] = self.group_id group_info["group_id"] = self.group_id
group_info["entity_domain"] = self.entity_domain
group_info["name"] = self.name group_info["name"] = self.name
group_info["members"] = [ group_info["members"] = [
zha_device.async_get_info() for zha_device in self.members zha_device.async_get_info() for zha_device in self.members

View File

@ -32,7 +32,7 @@ from .const import CONTROLLER, ZHA_GW_RADIO, ZHA_GW_RADIO_DESCRIPTION, RadioType
from .decorators import CALLABLE_T, DictRegistry, SetRegistry from .decorators import CALLABLE_T, DictRegistry, SetRegistry
from .typing import ChannelType from .typing import ChannelType
GROUP_ENTITY_DOMAINS = [LIGHT, SWITCH] GROUP_ENTITY_DOMAINS = [LIGHT, SWITCH, FAN]
SMARTTHINGS_ACCELERATION_CLUSTER = 0xFC02 SMARTTHINGS_ACCELERATION_CLUSTER = 0xFC02
SMARTTHINGS_ARRIVAL_SENSOR_DEVICE_TYPE = 0x8000 SMARTTHINGS_ARRIVAL_SENSOR_DEVICE_TYPE = 0x8000

View File

@ -10,7 +10,7 @@ from homeassistant.core import callback
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from .typing import ZhaDeviceType, ZhaGroupType from .typing import ZhaDeviceType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -30,15 +30,6 @@ class ZhaDeviceEntry:
last_seen = attr.ib(type=float, default=None) last_seen = attr.ib(type=float, default=None)
@attr.s(slots=True, frozen=True)
class ZhaGroupEntry:
"""Zha Group storage Entry."""
name = attr.ib(type=str, default=None)
group_id = attr.ib(type=int, default=None)
entity_domain = attr.ib(type=float, default=None)
class ZhaStorage: class ZhaStorage:
"""Class to hold a registry of zha devices.""" """Class to hold a registry of zha devices."""
@ -46,7 +37,6 @@ class ZhaStorage:
"""Initialize the zha device storage.""" """Initialize the zha device storage."""
self.hass: HomeAssistantType = hass self.hass: HomeAssistantType = hass
self.devices: MutableMapping[str, ZhaDeviceEntry] = {} self.devices: MutableMapping[str, ZhaDeviceEntry] = {}
self.groups: MutableMapping[str, ZhaGroupEntry] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
@callback @callback
@ -59,17 +49,6 @@ class ZhaStorage:
return self.async_update_device(device) return self.async_update_device(device)
@callback
def async_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
"""Create a new ZhaGroupEntry."""
group_entry: ZhaGroupEntry = ZhaGroupEntry(
name=group.name,
group_id=str(group.group_id),
entity_domain=group.entity_domain,
)
self.groups[str(group.group_id)] = group_entry
return self.async_update_group(group)
@callback @callback
def async_get_or_create_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry: def async_get_or_create_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
"""Create a new ZhaDeviceEntry.""" """Create a new ZhaDeviceEntry."""
@ -78,14 +57,6 @@ class ZhaStorage:
return self.devices[ieee_str] return self.devices[ieee_str]
return self.async_create_device(device) return self.async_create_device(device)
@callback
def async_get_or_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
"""Create a new ZhaGroupEntry."""
group_id: str = str(group.group_id)
if group_id in self.groups:
return self.groups[group_id]
return self.async_create_group(group)
@callback @callback
def async_create_or_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry: def async_create_or_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
"""Create or update a ZhaDeviceEntry.""" """Create or update a ZhaDeviceEntry."""
@ -93,13 +64,6 @@ class ZhaStorage:
return self.async_update_device(device) return self.async_update_device(device)
return self.async_create_device(device) return self.async_create_device(device)
@callback
def async_create_or_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
"""Create or update a ZhaGroupEntry."""
if str(group.group_id) in self.groups:
return self.async_update_group(group)
return self.async_create_group(group)
@callback @callback
def async_delete_device(self, device: ZhaDeviceType) -> None: def async_delete_device(self, device: ZhaDeviceType) -> None:
"""Delete ZhaDeviceEntry.""" """Delete ZhaDeviceEntry."""
@ -108,14 +72,6 @@ class ZhaStorage:
del self.devices[ieee_str] del self.devices[ieee_str]
self.async_schedule_save() self.async_schedule_save()
@callback
def async_delete_group(self, group: ZhaGroupType) -> None:
"""Delete ZhaGroupEntry."""
group_id: str = str(group.group_id)
if group_id in self.groups:
del self.groups[group_id]
self.async_schedule_save()
@callback @callback
def async_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry: def async_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
"""Update name of ZhaDeviceEntry.""" """Update name of ZhaDeviceEntry."""
@ -129,25 +85,11 @@ class ZhaStorage:
self.async_schedule_save() self.async_schedule_save()
return new return new
@callback
def async_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
"""Update name of ZhaGroupEntry."""
group_id: str = str(group.group_id)
old = self.groups[group_id]
changes = {}
changes["entity_domain"] = group.entity_domain
new = self.groups[group_id] = attr.evolve(old, **changes)
self.async_schedule_save()
return new
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load the registry of zha device entries.""" """Load the registry of zha device entries."""
data = await self._store.async_load() data = await self._store.async_load()
devices: "OrderedDict[str, ZhaDeviceEntry]" = OrderedDict() devices: "OrderedDict[str, ZhaDeviceEntry]" = OrderedDict()
groups: "OrderedDict[str, ZhaGroupEntry]" = OrderedDict()
if data is not None: if data is not None:
for device in data["devices"]: for device in data["devices"]:
@ -157,18 +99,7 @@ class ZhaStorage:
last_seen=device["last_seen"] if "last_seen" in device else None, last_seen=device["last_seen"] if "last_seen" in device else None,
) )
if "groups" in data:
for group in data["groups"]:
groups[group["group_id"]] = ZhaGroupEntry(
name=group["name"],
group_id=group["group_id"],
entity_domain=group["entity_domain"]
if "entity_domain" in group
else None,
)
self.devices = devices self.devices = devices
self.groups = groups
@callback @callback
def async_schedule_save(self) -> None: def async_schedule_save(self) -> None:
@ -189,14 +120,6 @@ class ZhaStorage:
for entry in self.devices.values() for entry in self.devices.values()
] ]
data["groups"] = [
{
"name": entry.name,
"group_id": entry.group_id,
"entity_domain": entry.entity_domain,
}
for entry in self.groups.values()
]
return data return data

View File

@ -1,6 +1,10 @@
"""Fans on Zigbee Home Automation networks.""" """Fans on Zigbee Home Automation networks."""
import functools import functools
import logging import logging
from typing import List, Optional
from zigpy.exceptions import DeliveryError
import zigpy.zcl.clusters.hvac as hvac
from homeassistant.components.fan import ( from homeassistant.components.fan import (
DOMAIN, DOMAIN,
@ -11,8 +15,10 @@ from homeassistant.components.fan import (
SUPPORT_SET_SPEED, SUPPORT_SET_SPEED,
FanEntity, FanEntity,
) )
from homeassistant.core import callback from homeassistant.const import STATE_UNAVAILABLE
from homeassistant.core import CALLBACK_TYPE, State, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import async_track_state_change
from .core import discovery from .core import discovery
from .core.const import ( from .core.const import (
@ -21,9 +27,10 @@ from .core.const import (
DATA_ZHA_DISPATCHERS, DATA_ZHA_DISPATCHERS,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
SIGNAL_REMOVE_GROUP,
) )
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import BaseZhaEntity, ZhaEntity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -49,6 +56,7 @@ SPEED_LIST = [
VALUE_TO_SPEED = dict(enumerate(SPEED_LIST)) VALUE_TO_SPEED = dict(enumerate(SPEED_LIST))
SPEED_TO_VALUE = {speed: i for i, speed in enumerate(SPEED_LIST)} SPEED_TO_VALUE = {speed: i for i, speed in enumerate(SPEED_LIST)}
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN) STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN)
GROUP_MATCH = functools.partial(ZHA_ENTITIES.group_match, DOMAIN)
async def async_setup_entry(hass, config_entry, async_add_entities): async def async_setup_entry(hass, config_entry, async_add_entities):
@ -65,31 +73,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub)
@STRICT_MATCH(channel_names=CHANNEL_FAN) class BaseFan(BaseZhaEntity, FanEntity):
class ZhaFan(ZhaEntity, FanEntity): """Base representation of a ZHA fan."""
"""Representation of a ZHA fan."""
def __init__(self, unique_id, zha_device, channels, **kwargs): def __init__(self, *args, **kwargs):
"""Init this sensor.""" """Initialize the fan."""
super().__init__(unique_id, zha_device, channels, **kwargs) super().__init__(*args, **kwargs)
self._fan_channel = self.cluster_channels.get(CHANNEL_FAN) self._state = None
self._fan_channel = None
async def async_added_to_hass(self):
"""Run when about to be added to hass."""
await super().async_added_to_hass()
await self.async_accept_signal(
self._fan_channel, SIGNAL_ATTR_UPDATED, self.async_set_state
)
@callback
def async_restore_last_state(self, last_state):
"""Restore previous state."""
self._state = VALUE_TO_SPEED.get(last_state.state, last_state.state)
@property
def supported_features(self) -> int:
"""Flag supported features."""
return SUPPORT_SET_SPEED
@property @property
def speed_list(self) -> list: def speed_list(self) -> list:
@ -109,15 +100,9 @@ class ZhaFan(ZhaEntity, FanEntity):
return self._state != SPEED_OFF return self._state != SPEED_OFF
@property @property
def device_state_attributes(self): def supported_features(self) -> int:
"""Return state attributes.""" """Flag supported features."""
return self.state_attributes return SUPPORT_SET_SPEED
@callback
def async_set_state(self, attr_id, attr_name, value):
"""Handle state update from channel."""
self._state = VALUE_TO_SPEED.get(value, self._state)
self.async_write_ha_state()
async def async_turn_on(self, speed: str = None, **kwargs) -> None: async def async_turn_on(self, speed: str = None, **kwargs) -> None:
"""Turn the entity on.""" """Turn the entity on."""
@ -135,6 +120,34 @@ class ZhaFan(ZhaEntity, FanEntity):
await self._fan_channel.async_set_speed(SPEED_TO_VALUE[speed]) await self._fan_channel.async_set_speed(SPEED_TO_VALUE[speed])
self.async_set_state(0, "fan_mode", speed) self.async_set_state(0, "fan_mode", speed)
@STRICT_MATCH(channel_names=CHANNEL_FAN)
class ZhaFan(ZhaEntity, BaseFan):
"""Representation of a ZHA fan."""
def __init__(self, unique_id, zha_device, channels, **kwargs):
"""Init this sensor."""
super().__init__(unique_id, zha_device, channels, **kwargs)
self._fan_channel = self.cluster_channels.get(CHANNEL_FAN)
async def async_added_to_hass(self):
"""Run when about to be added to hass."""
await super().async_added_to_hass()
await self.async_accept_signal(
self._fan_channel, SIGNAL_ATTR_UPDATED, self.async_set_state
)
@callback
def async_restore_last_state(self, last_state):
"""Restore previous state."""
self._state = VALUE_TO_SPEED.get(last_state.state, last_state.state)
@callback
def async_set_state(self, attr_id, attr_name, value):
"""Handle state update from channel."""
self._state = VALUE_TO_SPEED.get(value, self._state)
self.async_write_ha_state()
async def async_update(self): async def async_update(self):
"""Attempt to retrieve on off state from the fan.""" """Attempt to retrieve on off state from the fan."""
await super().async_update() await super().async_update()
@ -142,3 +155,73 @@ class ZhaFan(ZhaEntity, FanEntity):
state = await self._fan_channel.get_attribute_value("fan_mode") state = await self._fan_channel.get_attribute_value("fan_mode")
if state is not None: if state is not None:
self._state = VALUE_TO_SPEED.get(state, self._state) self._state = VALUE_TO_SPEED.get(state, self._state)
@GROUP_MATCH()
class FanGroup(BaseFan):
"""Representation of a fan group."""
def __init__(
self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs
) -> None:
"""Initialize a fan group."""
super().__init__(unique_id, zha_device, **kwargs)
self._name: str = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}"
self._group_id: int = group_id
self._available: bool = False
self._entity_ids: List[str] = entity_ids
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
group = self.zha_device.gateway.get_group(self._group_id)
self._fan_channel = group.endpoint[hvac.Fan.cluster_id]
# what should we do with this hack?
async def async_set_speed(value) -> None:
"""Set the speed of the fan."""
try:
await self._fan_channel.write_attributes({"fan_mode": value})
except DeliveryError as ex:
self.error("Could not set speed: %s", ex)
return
self._fan_channel.async_set_speed = async_set_speed
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
await super().async_added_to_hass()
await self.async_accept_signal(
None,
f"{SIGNAL_REMOVE_GROUP}_{self._group_id}",
self.async_remove,
signal_override=True,
)
@callback
def async_state_changed_listener(
entity_id: str, old_state: State, new_state: State
):
"""Handle child updates."""
self.async_schedule_update_ha_state(True)
self._async_unsub_state_changed = async_track_state_change(
self.hass, self._entity_ids, async_state_changed_listener
)
await self.async_update()
async def async_will_remove_from_hass(self) -> None:
"""Handle removal from Home Assistant."""
await super().async_will_remove_from_hass()
if self._async_unsub_state_changed is not None:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
async def async_update(self):
"""Attempt to retrieve on off state from the fan."""
all_states = [self.hass.states.get(x) for x in self._entity_ids]
states: List[State] = list(filter(None, all_states))
on_states: List[State] = [state for state in states if state.state != SPEED_OFF]
self._available = any(state.state != STATE_UNAVAILABLE for state in states)
# for now just use first non off state since its kind of arbitrary
if not on_states:
self._state = SPEED_OFF
else:
self._state = states[0].state

View File

@ -2,10 +2,21 @@
from unittest.mock import call from unittest.mock import call
import pytest import pytest
import zigpy.profiles.zha as zha
import zigpy.zcl.clusters.general as general
import zigpy.zcl.clusters.hvac as hvac import zigpy.zcl.clusters.hvac as hvac
from homeassistant.components import fan from homeassistant.components import fan
from homeassistant.components.fan import ATTR_SPEED, DOMAIN, SERVICE_SET_SPEED from homeassistant.components.fan import (
ATTR_SPEED,
DOMAIN,
SERVICE_SET_SPEED,
SPEED_HIGH,
SPEED_MEDIUM,
SPEED_OFF,
)
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.components.zha.core.discovery import GROUP_PROBE
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
SERVICE_TURN_OFF, SERVICE_TURN_OFF,
@ -17,11 +28,16 @@ from homeassistant.const import (
from .common import ( from .common import (
async_enable_traffic, async_enable_traffic,
async_find_group_entity_id,
async_test_rejoin, async_test_rejoin,
find_entity_id, find_entity_id,
get_zha_gateway,
send_attributes_report, send_attributes_report,
) )
IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8"
@pytest.fixture @pytest.fixture
def zigpy_device(zigpy_device_mock): def zigpy_device(zigpy_device_mock):
@ -32,6 +48,66 @@ def zigpy_device(zigpy_device_mock):
return zigpy_device_mock(endpoints) return zigpy_device_mock(endpoints)
@pytest.fixture
async def coordinator(hass, zigpy_device_mock, zha_device_joined):
"""Test zha fan platform."""
zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [],
"out_clusters": [],
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
}
},
ieee="00:15:8d:00:02:32:4f:32",
nwk=0x0000,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device
@pytest.fixture
async def device_fan_1(hass, zigpy_device_mock, zha_device_joined):
"""Test zha fan platform."""
zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [general.OnOff.cluster_id, hvac.Fan.cluster_id],
"out_clusters": [],
}
},
ieee=IEEE_GROUPABLE_DEVICE,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device
@pytest.fixture
async def device_fan_2(hass, zigpy_device_mock, zha_device_joined):
"""Test zha fan platform."""
zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [
general.OnOff.cluster_id,
hvac.Fan.cluster_id,
general.LevelControl.cluster_id,
],
"out_clusters": [],
}
},
ieee=IEEE_GROUPABLE_DEVICE2,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device
async def test_fan(hass, zha_device_joined_restored, zigpy_device): async def test_fan(hass, zha_device_joined_restored, zigpy_device):
"""Test zha fan platform.""" """Test zha fan platform."""
@ -106,3 +182,87 @@ async def async_set_speed(hass, entity_id, speed=None):
} }
await hass.services.async_call(DOMAIN, SERVICE_SET_SPEED, data, blocking=True) await hass.services.async_call(DOMAIN, SERVICE_SET_SPEED, data, blocking=True)
async def async_test_zha_group_fan_entity(
hass, device_fan_1, device_fan_2, coordinator
):
"""Test the fan entity for a ZHA group."""
zha_gateway = get_zha_gateway(hass)
assert zha_gateway is not None
zha_gateway.coordinator_zha_device = coordinator
coordinator._zha_gateway = zha_gateway
device_fan_1._zha_gateway = zha_gateway
device_fan_2._zha_gateway = zha_gateway
member_ieee_addresses = [device_fan_1.ieee, device_fan_2.ieee]
# test creating a group with 2 members
zha_group = await zha_gateway.async_create_zigpy_group(
"Test Group", member_ieee_addresses
)
await hass.async_block_till_done()
assert zha_group is not None
assert len(zha_group.members) == 2
for member in zha_group.members:
assert member.ieee in member_ieee_addresses
entity_domains = GROUP_PROBE.determine_entity_domains(zha_group)
assert len(entity_domains) == 2
assert LIGHT_DOMAIN in entity_domains
assert DOMAIN in entity_domains
entity_id = async_find_group_entity_id(hass, DOMAIN, zha_group)
assert hass.states.get(entity_id) is not None
group_fan_cluster = zha_group.endpoint[hvac.Fan.cluster_id]
dev1_fan_cluster = device_fan_1.endpoints[1].fan
dev2_fan_cluster = device_fan_2.endpoints[1].fan
# test that the lights were created and that they are unavailable
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
# allow traffic to flow through the gateway and device
await async_enable_traffic(hass, zha_group.members)
# test that the fan group entity was created and is off
assert hass.states.get(entity_id).state == STATE_OFF
# turn on from HA
group_fan_cluster.write_attributes.reset_mock()
await async_turn_on(hass, entity_id)
assert len(group_fan_cluster.write_attributes.mock_calls) == 1
assert group_fan_cluster.write_attributes.call_args == call({"fan_mode": 2})
assert hass.states.get(entity_id).state == SPEED_MEDIUM
# turn off from HA
group_fan_cluster.write_attributes.reset_mock()
await async_turn_off(hass, entity_id)
assert len(group_fan_cluster.write_attributes.mock_calls) == 1
assert group_fan_cluster.write_attributes.call_args == call({"fan_mode": 0})
assert hass.states.get(entity_id).state == STATE_OFF
# change speed from HA
group_fan_cluster.write_attributes.reset_mock()
await async_set_speed(hass, entity_id, speed=fan.SPEED_HIGH)
assert len(group_fan_cluster.write_attributes.mock_calls) == 1
assert group_fan_cluster.write_attributes.call_args == call({"fan_mode": 3})
assert hass.states.get(entity_id).state == SPEED_HIGH
# test some of the group logic to make sure we key off states correctly
await dev1_fan_cluster.async_set_speed(SPEED_OFF)
await dev2_fan_cluster.async_set_speed(SPEED_OFF)
# test that group fan is off
assert hass.states.get(entity_id).state == STATE_OFF
await dev1_fan_cluster.async_set_speed(SPEED_MEDIUM)
# test that group fan is speed medium
assert hass.states.get(entity_id).state == SPEED_MEDIUM
await dev1_fan_cluster.async_set_speed(SPEED_OFF)
# test that group fan is now off
assert hass.states.get(entity_id).state == STATE_OFF

View File

@ -134,7 +134,6 @@ async def test_gateway_group_methods(hass, device_light_1, device_light_2, coord
await hass.async_block_till_done() await hass.async_block_till_done()
assert zha_group is not None assert zha_group is not None
assert zha_group.entity_domain == LIGHT_DOMAIN
assert len(zha_group.members) == 2 assert len(zha_group.members) == 2
for member in zha_group.members: for member in zha_group.members:
assert member.ieee in member_ieee_addresses assert member.ieee in member_ieee_addresses
@ -162,7 +161,6 @@ async def test_gateway_group_methods(hass, device_light_1, device_light_2, coord
await hass.async_block_till_done() await hass.async_block_till_done()
assert zha_group is not None assert zha_group is not None
assert zha_group.entity_domain is None
assert len(zha_group.members) == 1 assert len(zha_group.members) == 1
for member in zha_group.members: for member in zha_group.members:
assert member.ieee in [device_light_1.ieee] assert member.ieee in [device_light_1.ieee]

View File

@ -432,7 +432,6 @@ async def async_test_zha_group_light_entity(
await hass.async_block_till_done() await hass.async_block_till_done()
assert zha_group is not None assert zha_group is not None
assert zha_group.entity_domain == DOMAIN
assert len(zha_group.members) == 2 assert len(zha_group.members) == 2
for member in zha_group.members: for member in zha_group.members:
assert member.ieee in member_ieee_addresses assert member.ieee in member_ieee_addresses

View File

@ -173,7 +173,6 @@ async def async_test_zha_group_switch_entity(
await hass.async_block_till_done() await hass.async_block_till_done()
assert zha_group is not None assert zha_group is not None
assert zha_group.entity_domain == DOMAIN
assert len(zha_group.members) == 2 assert len(zha_group.members) == 2
for member in zha_group.members: for member in zha_group.members:
assert member.ieee in member_ieee_addresses assert member.ieee in member_ieee_addresses