Migrate internal ZHA data to a dataclasses (#100127)

* Cache device triggers on startup

* reorg zha init

* don't reuse gateway

* don't nuke yaml configuration

* review comments

* Add unit tests

* Do not cache device and entity registries

* [WIP] Wrap ZHA data in a dataclass

* [WIP] Get unit tests passing

* Use a helper function for getting the gateway object to fix annotations

* Remove `bridge_id`

* Fix typing issues with entity references in group websocket info

* Use `Platform` instead of `str` for entity platform matching

* Use `get_zha_gateway` in a few more places

* Fix flaky unit test

* Use `slots` for ZHA data

Co-authored-by: J. Nick Koston <nick@koston.org>

---------

Co-authored-by: David F. Mulcahey <david.mulcahey@icloud.com>
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
puddly 2023-09-11 21:39:33 +02:00 committed by GitHub
parent 5c206de906
commit cbb28b6943
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
44 changed files with 317 additions and 288 deletions

View File

@ -33,9 +33,6 @@ from .core.const import (
CONF_USB_PATH, CONF_USB_PATH,
CONF_ZIGPY, CONF_ZIGPY,
DATA_ZHA, DATA_ZHA,
DATA_ZHA_CONFIG,
DATA_ZHA_DEVICE_TRIGGER_CACHE,
DATA_ZHA_GATEWAY,
DOMAIN, DOMAIN,
PLATFORMS, PLATFORMS,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
@ -43,6 +40,7 @@ from .core.const import (
) )
from .core.device import get_device_automation_triggers from .core.device import get_device_automation_triggers
from .core.discovery import GROUP_PROBE from .core.discovery import GROUP_PROBE
from .core.helpers import ZHAData, get_zha_data
from .radio_manager import ZhaRadioManager from .radio_manager import ZhaRadioManager
DEVICE_CONFIG_SCHEMA_ENTRY = vol.Schema({vol.Optional(CONF_TYPE): cv.string}) DEVICE_CONFIG_SCHEMA_ENTRY = vol.Schema({vol.Optional(CONF_TYPE): cv.string})
@ -81,11 +79,9 @@ _LOGGER = logging.getLogger(__name__)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up ZHA from config.""" """Set up ZHA from config."""
hass.data[DATA_ZHA] = {} zha_data = ZHAData()
zha_data.yaml_config = config.get(DOMAIN, {})
if DOMAIN in config: hass.data[DATA_ZHA] = zha_data
conf = config[DOMAIN]
hass.data[DATA_ZHA][DATA_ZHA_CONFIG] = conf
return True return True
@ -120,14 +116,12 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
data[CONF_DEVICE][CONF_DEVICE_PATH] = cleaned_path data[CONF_DEVICE][CONF_DEVICE_PATH] = cleaned_path
hass.config_entries.async_update_entry(config_entry, data=data) hass.config_entries.async_update_entry(config_entry, data=data)
zha_data = hass.data.setdefault(DATA_ZHA, {}) zha_data = get_zha_data(hass)
config = zha_data.get(DATA_ZHA_CONFIG, {})
for platform in PLATFORMS: if zha_data.yaml_config.get(CONF_ENABLE_QUIRKS, True):
zha_data.setdefault(platform, []) setup_quirks(
custom_quirks_path=zha_data.yaml_config.get(CONF_CUSTOM_QUIRKS_PATH)
if config.get(CONF_ENABLE_QUIRKS, True): )
setup_quirks(custom_quirks_path=config.get(CONF_CUSTOM_QUIRKS_PATH))
# temporary code to remove the ZHA storage file from disk. # temporary code to remove the ZHA storage file from disk.
# this will be removed in 2022.10.0 # this will be removed in 2022.10.0
@ -139,8 +133,6 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
_LOGGER.debug("ZHA storage file does not exist or was already removed") _LOGGER.debug("ZHA storage file does not exist or was already removed")
# Load and cache device trigger information early # Load and cache device trigger information early
zha_data.setdefault(DATA_ZHA_DEVICE_TRIGGER_CACHE, {})
device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)
radio_mgr = ZhaRadioManager.from_config_entry(hass, config_entry) radio_mgr = ZhaRadioManager.from_config_entry(hass, config_entry)
@ -154,14 +146,14 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
if dev_entry is None: if dev_entry is None:
continue continue
zha_data[DATA_ZHA_DEVICE_TRIGGER_CACHE][dev_entry.id] = ( zha_data.device_trigger_cache[dev_entry.id] = (
str(dev.ieee), str(dev.ieee),
get_device_automation_triggers(dev), get_device_automation_triggers(dev),
) )
_LOGGER.debug("Trigger cache: %s", zha_data[DATA_ZHA_DEVICE_TRIGGER_CACHE]) _LOGGER.debug("Trigger cache: %s", zha_data.device_trigger_cache)
zha_gateway = ZHAGateway(hass, config, config_entry) zha_gateway = ZHAGateway(hass, zha_data.yaml_config, config_entry)
async def async_zha_shutdown(): async def async_zha_shutdown():
"""Handle shutdown tasks.""" """Handle shutdown tasks."""
@ -172,7 +164,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
# be in when we get here in failure cases # be in when we get here in failure cases
with contextlib.suppress(KeyError): with contextlib.suppress(KeyError):
for platform in PLATFORMS: for platform in PLATFORMS:
del hass.data[DATA_ZHA][platform] del zha_data.platforms[platform]
config_entry.async_on_unload(async_zha_shutdown) config_entry.async_on_unload(async_zha_shutdown)
@ -212,10 +204,8 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Unload ZHA config entry.""" """Unload ZHA config entry."""
try: zha_data = get_zha_data(hass)
del hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_data.gateway = None
except KeyError:
return False
GROUP_PROBE.cleanup() GROUP_PROBE.cleanup()
websocket_api.async_unload_api(hass) websocket_api.async_unload_api(hass)
@ -241,7 +231,7 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
CONF_DEVICE: {CONF_DEVICE_PATH: config_entry.data[CONF_USB_PATH]}, CONF_DEVICE: {CONF_DEVICE_PATH: config_entry.data[CONF_USB_PATH]},
} }
baudrate = hass.data[DATA_ZHA].get(DATA_ZHA_CONFIG, {}).get(CONF_BAUDRATE) baudrate = get_zha_data(hass).yaml_config.get(CONF_BAUDRATE)
if data[CONF_RADIO_TYPE] != RadioType.deconz and baudrate in BAUD_RATES: if data[CONF_RADIO_TYPE] != RadioType.deconz and baudrate in BAUD_RATES:
data[CONF_DEVICE][CONF_BAUDRATE] = baudrate data[CONF_DEVICE][CONF_BAUDRATE] = baudrate

View File

@ -35,11 +35,10 @@ from .core.const import (
CONF_ALARM_ARM_REQUIRES_CODE, CONF_ALARM_ARM_REQUIRES_CODE,
CONF_ALARM_FAILED_TRIES, CONF_ALARM_FAILED_TRIES,
CONF_ALARM_MASTER_CODE, CONF_ALARM_MASTER_CODE,
DATA_ZHA,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
ZHA_ALARM_OPTIONS, ZHA_ALARM_OPTIONS,
) )
from .core.helpers import async_get_zha_config_value from .core.helpers import async_get_zha_config_value, get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
@ -65,7 +64,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation alarm control panel from config entry.""" """Set up the Zigbee Home Automation alarm control panel from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.ALARM_CONTROL_PANEL] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.ALARM_CONTROL_PANEL]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -9,33 +9,22 @@ from zigpy.config import CONF_DEVICE, CONF_DEVICE_PATH
from zigpy.types import Channels from zigpy.types import Channels
from zigpy.util import pick_optimal_channel from zigpy.util import pick_optimal_channel
from .core.const import ( from .core.const import CONF_RADIO_TYPE, DOMAIN, RadioType
CONF_RADIO_TYPE,
DATA_ZHA,
DATA_ZHA_CONFIG,
DATA_ZHA_GATEWAY,
DOMAIN,
RadioType,
)
from .core.gateway import ZHAGateway from .core.gateway import ZHAGateway
from .core.helpers import get_zha_data, get_zha_gateway
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
def _get_gateway(hass: HomeAssistant) -> ZHAGateway:
"""Get a reference to the ZHA gateway device."""
return hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
def _get_config_entry(hass: HomeAssistant) -> ConfigEntry: def _get_config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Find the singleton ZHA config entry, if one exists.""" """Find the singleton ZHA config entry, if one exists."""
# If ZHA is already running, use its config entry # If ZHA is already running, use its config entry
try: try:
zha_gateway = _get_gateway(hass) zha_gateway = get_zha_gateway(hass)
except KeyError: except ValueError:
pass pass
else: else:
return zha_gateway.config_entry return zha_gateway.config_entry
@ -51,8 +40,7 @@ def _get_config_entry(hass: HomeAssistant) -> ConfigEntry:
def async_get_active_network_settings(hass: HomeAssistant) -> NetworkBackup: def async_get_active_network_settings(hass: HomeAssistant) -> NetworkBackup:
"""Get the network settings for the currently active ZHA network.""" """Get the network settings for the currently active ZHA network."""
zha_gateway: ZHAGateway = _get_gateway(hass) app = get_zha_gateway(hass).application_controller
app = zha_gateway.application_controller
return NetworkBackup( return NetworkBackup(
node_info=app.state.node_info, node_info=app.state.node_info,
@ -67,7 +55,7 @@ async def async_get_last_network_settings(
if config_entry is None: if config_entry is None:
config_entry = _get_config_entry(hass) config_entry = _get_config_entry(hass)
config = hass.data.get(DATA_ZHA, {}).get(DATA_ZHA_CONFIG, {}) config = get_zha_data(hass).yaml_config
zha_gateway = ZHAGateway(hass, config, config_entry) zha_gateway = ZHAGateway(hass, config, config_entry)
app_controller_cls, app_config = zha_gateway.get_application_controller_data() app_controller_cls, app_config = zha_gateway.get_application_controller_data()
@ -91,7 +79,7 @@ async def async_get_network_settings(
try: try:
return async_get_active_network_settings(hass) return async_get_active_network_settings(hass)
except KeyError: except ValueError:
return await async_get_last_network_settings(hass, config_entry) return await async_get_last_network_settings(hass, config_entry)
@ -120,8 +108,7 @@ async def async_change_channel(
) -> None: ) -> None:
"""Migrate the ZHA network to a new channel.""" """Migrate the ZHA network to a new channel."""
zha_gateway: ZHAGateway = _get_gateway(hass) app = get_zha_gateway(hass).application_controller
app = zha_gateway.application_controller
if new_channel == "auto": if new_channel == "auto":
channel_energy = await app.energy_scan( channel_energy = await app.energy_scan(

View File

@ -3,8 +3,7 @@ import logging
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .core import ZHAGateway from .core.helpers import get_zha_gateway
from .core.const import DATA_ZHA, DATA_ZHA_GATEWAY
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -13,7 +12,7 @@ async def async_pre_backup(hass: HomeAssistant) -> None:
"""Perform operations before a backup starts.""" """Perform operations before a backup starts."""
_LOGGER.debug("Performing coordinator backup") _LOGGER.debug("Performing coordinator backup")
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
await zha_gateway.application_controller.backups.create_backup(load_devices=True) await zha_gateway.application_controller.backups.create_backup(load_devices=True)

View File

@ -26,10 +26,10 @@ from .core.const import (
CLUSTER_HANDLER_OCCUPANCY, CLUSTER_HANDLER_OCCUPANCY,
CLUSTER_HANDLER_ON_OFF, CLUSTER_HANDLER_ON_OFF,
CLUSTER_HANDLER_ZONE, CLUSTER_HANDLER_ZONE,
DATA_ZHA,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
) )
from .core.helpers import get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
@ -65,7 +65,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation binary sensor from config entry.""" """Set up the Zigbee Home Automation binary sensor from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.BINARY_SENSOR] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.BINARY_SENSOR]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -14,7 +14,8 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .core import discovery from .core import discovery
from .core.const import CLUSTER_HANDLER_IDENTIFY, DATA_ZHA, SIGNAL_ADD_ENTITIES from .core.const import CLUSTER_HANDLER_IDENTIFY, SIGNAL_ADD_ENTITIES
from .core.helpers import get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
@ -38,7 +39,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation button from config entry.""" """Set up the Zigbee Home Automation button from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.BUTTON] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.BUTTON]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -45,13 +45,13 @@ from .core import discovery
from .core.const import ( from .core.const import (
CLUSTER_HANDLER_FAN, CLUSTER_HANDLER_FAN,
CLUSTER_HANDLER_THERMOSTAT, CLUSTER_HANDLER_THERMOSTAT,
DATA_ZHA,
PRESET_COMPLEX, PRESET_COMPLEX,
PRESET_SCHEDULE, PRESET_SCHEDULE,
PRESET_TEMP_MANUAL, PRESET_TEMP_MANUAL,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
) )
from .core.helpers import get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
@ -115,7 +115,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation sensor from config entry.""" """Set up the Zigbee Home Automation sensor from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.CLIMATE] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.CLIMATE]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,

View File

@ -184,7 +184,6 @@ CUSTOM_CONFIGURATION = "custom_configuration"
DATA_DEVICE_CONFIG = "zha_device_config" DATA_DEVICE_CONFIG = "zha_device_config"
DATA_ZHA = "zha" DATA_ZHA = "zha"
DATA_ZHA_CONFIG = "config" DATA_ZHA_CONFIG = "config"
DATA_ZHA_BRIDGE_ID = "zha_bridge_id"
DATA_ZHA_CORE_EVENTS = "zha_core_events" DATA_ZHA_CORE_EVENTS = "zha_core_events"
DATA_ZHA_DEVICE_TRIGGER_CACHE = "zha_device_trigger_cache" DATA_ZHA_DEVICE_TRIGGER_CACHE = "zha_device_trigger_cache"
DATA_ZHA_GATEWAY = "zha_gateway" DATA_ZHA_GATEWAY = "zha_gateway"

View File

@ -25,6 +25,7 @@ from homeassistant.backports.functools import cached_property
from homeassistant.const import ATTR_COMMAND, ATTR_DEVICE_ID, ATTR_NAME from homeassistant.const import ATTR_COMMAND, ATTR_DEVICE_ID, ATTR_NAME
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
async_dispatcher_send, async_dispatcher_send,
@ -420,7 +421,9 @@ class ZHADevice(LogMixin):
"""Update device sw version.""" """Update device sw version."""
if self.device_id is None: if self.device_id is None:
return return
self._zha_gateway.ha_device_registry.async_update_device(
device_registry = dr.async_get(self.hass)
device_registry.async_update_device(
self.device_id, sw_version=f"0x{sw_version:08x}" self.device_id, sw_version=f"0x{sw_version:08x}"
) )
@ -658,7 +661,8 @@ class ZHADevice(LogMixin):
) )
device_info[ATTR_ENDPOINT_NAMES] = names device_info[ATTR_ENDPOINT_NAMES] = names
reg_device = self.gateway.ha_device_registry.async_get(self.device_id) device_registry = dr.async_get(self.hass)
reg_device = device_registry.async_get(self.device_id)
if reg_device is not None: if reg_device is not None:
device_info["user_given_name"] = reg_device.name_by_user device_info["user_given_name"] = reg_device.name_by_user
device_info["device_reg_id"] = reg_device.id device_info["device_reg_id"] = reg_device.id

View File

@ -4,10 +4,11 @@ from __future__ import annotations
from collections import Counter from collections import Counter
from collections.abc import Callable from collections.abc import Callable
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, cast
from homeassistant.const import CONF_TYPE, Platform from homeassistant.const import CONF_TYPE, Platform
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
async_dispatcher_send, async_dispatcher_send,
@ -49,12 +50,12 @@ from .cluster_handlers import ( # noqa: F401
security, security,
smartenergy, smartenergy,
) )
from .helpers import get_zha_data, get_zha_gateway
if TYPE_CHECKING: if TYPE_CHECKING:
from ..entity import ZhaEntity from ..entity import ZhaEntity
from .device import ZHADevice from .device import ZHADevice
from .endpoint import Endpoint from .endpoint import Endpoint
from .gateway import ZHAGateway
from .group import ZHAGroup from .group import ZHAGroup
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -113,6 +114,8 @@ class ProbeEndpoint:
platform = zha_regs.DEVICE_CLASS[ep_profile_id].get(ep_device_type) platform = zha_regs.DEVICE_CLASS[ep_profile_id].get(ep_device_type)
if platform and platform in zha_const.PLATFORMS: if platform and platform in zha_const.PLATFORMS:
platform = cast(Platform, platform)
cluster_handlers = endpoint.unclaimed_cluster_handlers() cluster_handlers = endpoint.unclaimed_cluster_handlers()
platform_entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity( platform_entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity(
platform, platform,
@ -263,9 +266,7 @@ class ProbeEndpoint:
def initialize(self, hass: HomeAssistant) -> None: def initialize(self, hass: HomeAssistant) -> None:
"""Update device overrides config.""" """Update device overrides config."""
zha_config: ConfigType = hass.data[zha_const.DATA_ZHA].get( zha_config = get_zha_data(hass).yaml_config
zha_const.DATA_ZHA_CONFIG, {}
)
if overrides := zha_config.get(zha_const.CONF_DEVICE_CONFIG): if overrides := zha_config.get(zha_const.CONF_DEVICE_CONFIG):
self._device_configs.update(overrides) self._device_configs.update(overrides)
@ -297,9 +298,7 @@ class GroupProbe:
@callback @callback
def _reprobe_group(self, group_id: int) -> None: def _reprobe_group(self, group_id: int) -> None:
"""Reprobe a group for entities after its members change.""" """Reprobe a group for entities after its members change."""
zha_gateway: ZHAGateway = self._hass.data[zha_const.DATA_ZHA][ zha_gateway = get_zha_gateway(self._hass)
zha_const.DATA_ZHA_GATEWAY
]
if (zha_group := zha_gateway.groups.get(group_id)) is None: if (zha_group := zha_gateway.groups.get(group_id)) is None:
return return
self.discover_group_entities(zha_group) self.discover_group_entities(zha_group)
@ -321,14 +320,14 @@ class GroupProbe:
if not entity_domains: if not entity_domains:
return return
zha_gateway: ZHAGateway = self._hass.data[zha_const.DATA_ZHA][ zha_data = get_zha_data(self._hass)
zha_const.DATA_ZHA_GATEWAY zha_gateway = get_zha_gateway(self._hass)
]
for domain in entity_domains: for domain in entity_domains:
entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(domain) entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(domain)
if entity_class is None: if entity_class is None:
continue continue
self._hass.data[zha_const.DATA_ZHA][domain].append( zha_data.platforms[domain].append(
( (
entity_class, entity_class,
( (
@ -342,24 +341,26 @@ class GroupProbe:
async_dispatcher_send(self._hass, zha_const.SIGNAL_ADD_ENTITIES) async_dispatcher_send(self._hass, zha_const.SIGNAL_ADD_ENTITIES)
@staticmethod @staticmethod
def determine_entity_domains(hass: HomeAssistant, group: ZHAGroup) -> list[str]: def determine_entity_domains(
hass: HomeAssistant, group: ZHAGroup
) -> list[Platform]:
"""Determine the entity domains for this group.""" """Determine the entity domains for this group."""
entity_domains: list[str] = [] entity_registry = er.async_get(hass)
zha_gateway: ZHAGateway = hass.data[zha_const.DATA_ZHA][
zha_const.DATA_ZHA_GATEWAY entity_domains: list[Platform] = []
] all_domain_occurrences: list[Platform] = []
all_domain_occurrences = []
for member in group.members: for member in group.members:
if member.device.is_coordinator: if member.device.is_coordinator:
continue continue
entities = async_entries_for_device( entities = async_entries_for_device(
zha_gateway.ha_entity_registry, entity_registry,
member.device.device_id, member.device.device_id,
include_disabled_entities=True, include_disabled_entities=True,
) )
all_domain_occurrences.extend( all_domain_occurrences.extend(
[ [
entity.domain cast(Platform, entity.domain)
for entity in entities for entity in entities
if entity.domain in zha_regs.GROUP_ENTITY_DOMAINS if entity.domain in zha_regs.GROUP_ENTITY_DOMAINS
] ]

View File

@ -16,6 +16,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
from . import const, discovery, registries from . import const, discovery, registries
from .cluster_handlers import ClusterHandler from .cluster_handlers import ClusterHandler
from .cluster_handlers.general import MultistateInput from .cluster_handlers.general import MultistateInput
from .helpers import get_zha_data
if TYPE_CHECKING: if TYPE_CHECKING:
from .cluster_handlers import ClientClusterHandler from .cluster_handlers import ClientClusterHandler
@ -195,7 +196,7 @@ class Endpoint:
def async_new_entity( def async_new_entity(
self, self,
platform: Platform | str, platform: Platform,
entity_class: CALLABLE_T, entity_class: CALLABLE_T,
unique_id: str, unique_id: str,
cluster_handlers: list[ClusterHandler], cluster_handlers: list[ClusterHandler],
@ -206,7 +207,8 @@ class Endpoint:
if self.device.status == DeviceStatus.INITIALIZED: if self.device.status == DeviceStatus.INITIALIZED:
return return
self.device.hass.data[const.DATA_ZHA][platform].append( zha_data = get_zha_data(self.device.hass)
zha_data.platforms[platform].append(
(entity_class, (unique_id, self.device, cluster_handlers)) (entity_class, (unique_id, self.device, cluster_handlers))
) )

View File

@ -46,9 +46,6 @@ from .const import (
CONF_RADIO_TYPE, CONF_RADIO_TYPE,
CONF_USE_THREAD, CONF_USE_THREAD,
CONF_ZIGPY, CONF_ZIGPY,
DATA_ZHA,
DATA_ZHA_BRIDGE_ID,
DATA_ZHA_GATEWAY,
DEBUG_COMP_BELLOWS, DEBUG_COMP_BELLOWS,
DEBUG_COMP_ZHA, DEBUG_COMP_ZHA,
DEBUG_COMP_ZIGPY, DEBUG_COMP_ZIGPY,
@ -87,6 +84,7 @@ from .const import (
) )
from .device import DeviceStatus, ZHADevice from .device import DeviceStatus, ZHADevice
from .group import GroupMember, ZHAGroup from .group import GroupMember, ZHAGroup
from .helpers import get_zha_data
from .registries import GROUP_ENTITY_DOMAINS from .registries import GROUP_ENTITY_DOMAINS
if TYPE_CHECKING: if TYPE_CHECKING:
@ -123,8 +121,6 @@ class ZHAGateway:
"""Gateway that handles events that happen on the ZHA Zigbee network.""" """Gateway that handles events that happen on the ZHA Zigbee network."""
# -- Set in async_initialize -- # -- Set in async_initialize --
ha_device_registry: dr.DeviceRegistry
ha_entity_registry: er.EntityRegistry
application_controller: ControllerApplication application_controller: ControllerApplication
radio_description: str radio_description: str
@ -132,7 +128,7 @@ class ZHAGateway:
self, hass: HomeAssistant, config: ConfigType, config_entry: ConfigEntry self, hass: HomeAssistant, config: ConfigType, config_entry: ConfigEntry
) -> None: ) -> None:
"""Initialize the gateway.""" """Initialize the gateway."""
self._hass = hass self.hass = hass
self._config = config self._config = config
self._devices: dict[EUI64, ZHADevice] = {} self._devices: dict[EUI64, ZHADevice] = {}
self._groups: dict[int, ZHAGroup] = {} self._groups: dict[int, ZHAGroup] = {}
@ -159,7 +155,7 @@ class ZHAGateway:
app_config = self._config.get(CONF_ZIGPY, {}) app_config = self._config.get(CONF_ZIGPY, {})
database = self._config.get( database = self._config.get(
CONF_DATABASE, CONF_DATABASE,
self._hass.config.path(DEFAULT_DATABASE_NAME), self.hass.config.path(DEFAULT_DATABASE_NAME),
) )
app_config[CONF_DATABASE] = database app_config[CONF_DATABASE] = database
app_config[CONF_DEVICE] = self.config_entry.data[CONF_DEVICE] app_config[CONF_DEVICE] = self.config_entry.data[CONF_DEVICE]
@ -191,11 +187,8 @@ class ZHAGateway:
async def async_initialize(self) -> None: async def async_initialize(self) -> None:
"""Initialize controller and connect radio.""" """Initialize controller and connect radio."""
discovery.PROBE.initialize(self._hass) discovery.PROBE.initialize(self.hass)
discovery.GROUP_PROBE.initialize(self._hass) discovery.GROUP_PROBE.initialize(self.hass)
self.ha_device_registry = dr.async_get(self._hass)
self.ha_entity_registry = er.async_get(self._hass)
app_controller_cls, app_config = self.get_application_controller_data() app_controller_cls, app_config = self.get_application_controller_data()
self.application_controller = await app_controller_cls.new( self.application_controller = await app_controller_cls.new(
@ -225,8 +218,8 @@ class ZHAGateway:
else: else:
break break
self._hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] = self zha_data = get_zha_data(self.hass)
self._hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID] = str(self.coordinator_ieee) zha_data.gateway = self
self.coordinator_zha_device = self._async_get_or_create_device( self.coordinator_zha_device = self._async_get_or_create_device(
self._find_coordinator_device(), restored=True self._find_coordinator_device(), restored=True
@ -301,7 +294,7 @@ class ZHAGateway:
# background the fetching of state for mains powered devices # background the fetching of state for mains powered devices
self.config_entry.async_create_background_task( self.config_entry.async_create_background_task(
self._hass, fetch_updated_state(), "zha.gateway-fetch_updated_state" self.hass, fetch_updated_state(), "zha.gateway-fetch_updated_state"
) )
def device_joined(self, device: zigpy.device.Device) -> None: def device_joined(self, device: zigpy.device.Device) -> None:
@ -311,7 +304,7 @@ class ZHAGateway:
address address
""" """
async_dispatcher_send( async_dispatcher_send(
self._hass, self.hass,
ZHA_GW_MSG, ZHA_GW_MSG,
{ {
ATTR_TYPE: ZHA_GW_MSG_DEVICE_JOINED, ATTR_TYPE: ZHA_GW_MSG_DEVICE_JOINED,
@ -327,7 +320,7 @@ class ZHAGateway:
"""Handle a device initialization without quirks loaded.""" """Handle a device initialization without quirks loaded."""
manuf = device.manufacturer manuf = device.manufacturer
async_dispatcher_send( async_dispatcher_send(
self._hass, self.hass,
ZHA_GW_MSG, ZHA_GW_MSG,
{ {
ATTR_TYPE: ZHA_GW_MSG_RAW_INIT, ATTR_TYPE: ZHA_GW_MSG_RAW_INIT,
@ -344,7 +337,7 @@ class ZHAGateway:
def device_initialized(self, device: zigpy.device.Device) -> None: def device_initialized(self, device: zigpy.device.Device) -> None:
"""Handle device joined and basic information discovered.""" """Handle device joined and basic information discovered."""
self._hass.async_create_task(self.async_device_initialized(device)) self.hass.async_create_task(self.async_device_initialized(device))
def device_left(self, device: zigpy.device.Device) -> None: def device_left(self, device: zigpy.device.Device) -> None:
"""Handle device leaving the network.""" """Handle device leaving the network."""
@ -359,7 +352,7 @@ class ZHAGateway:
zha_group.info("group_member_removed - endpoint: %s", endpoint) zha_group.info("group_member_removed - endpoint: %s", endpoint)
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_REMOVED) self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_REMOVED)
async_dispatcher_send( async_dispatcher_send(
self._hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_0x{zigpy_group.group_id:04x}" self.hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_0x{zigpy_group.group_id:04x}"
) )
def group_member_added( def group_member_added(
@ -371,7 +364,7 @@ class ZHAGateway:
zha_group.info("group_member_added - endpoint: %s", endpoint) zha_group.info("group_member_added - endpoint: %s", endpoint)
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_ADDED) self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_ADDED)
async_dispatcher_send( async_dispatcher_send(
self._hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_0x{zigpy_group.group_id:04x}" self.hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_0x{zigpy_group.group_id:04x}"
) )
if len(zha_group.members) == 2: if len(zha_group.members) == 2:
# we need to do this because there wasn't already # we need to do this because there wasn't already
@ -399,7 +392,7 @@ class ZHAGateway:
zha_group = self._groups.get(zigpy_group.group_id) zha_group = self._groups.get(zigpy_group.group_id)
if zha_group is not None: if zha_group is not None:
async_dispatcher_send( async_dispatcher_send(
self._hass, self.hass,
ZHA_GW_MSG, ZHA_GW_MSG,
{ {
ATTR_TYPE: gateway_message_type, ATTR_TYPE: gateway_message_type,
@ -416,9 +409,11 @@ class ZHAGateway:
remove_tasks.append(entity_ref.remove_future) remove_tasks.append(entity_ref.remove_future)
if remove_tasks: if remove_tasks:
await asyncio.wait(remove_tasks) await asyncio.wait(remove_tasks)
reg_device = self.ha_device_registry.async_get(device.device_id)
device_registry = dr.async_get(self.hass)
reg_device = device_registry.async_get(device.device_id)
if reg_device is not None: if reg_device is not None:
self.ha_device_registry.async_remove_device(reg_device.id) device_registry.async_remove_device(reg_device.id)
def device_removed(self, device: zigpy.device.Device) -> None: def device_removed(self, device: zigpy.device.Device) -> None:
"""Handle device being removed from the network.""" """Handle device being removed from the network."""
@ -427,14 +422,14 @@ class ZHAGateway:
if zha_device is not None: if zha_device is not None:
device_info = zha_device.zha_device_info device_info = zha_device.zha_device_info
zha_device.async_cleanup_handles() zha_device.async_cleanup_handles()
async_dispatcher_send(self._hass, f"{SIGNAL_REMOVE}_{str(zha_device.ieee)}") async_dispatcher_send(self.hass, f"{SIGNAL_REMOVE}_{str(zha_device.ieee)}")
self._hass.async_create_task( self.hass.async_create_task(
self._async_remove_device(zha_device, entity_refs), self._async_remove_device(zha_device, entity_refs),
"ZHAGateway._async_remove_device", "ZHAGateway._async_remove_device",
) )
if device_info is not None: if device_info is not None:
async_dispatcher_send( async_dispatcher_send(
self._hass, self.hass,
ZHA_GW_MSG, ZHA_GW_MSG,
{ {
ATTR_TYPE: ZHA_GW_MSG_DEVICE_REMOVED, ATTR_TYPE: ZHA_GW_MSG_DEVICE_REMOVED,
@ -488,9 +483,10 @@ class ZHAGateway:
] ]
# then we get all group entity entries tied to the coordinator # then we get all group entity entries tied to the coordinator
entity_registry = er.async_get(self.hass)
assert self.coordinator_zha_device assert self.coordinator_zha_device
all_group_entity_entries = er.async_entries_for_device( all_group_entity_entries = er.async_entries_for_device(
self.ha_entity_registry, entity_registry,
self.coordinator_zha_device.device_id, self.coordinator_zha_device.device_id,
include_disabled_entities=True, include_disabled_entities=True,
) )
@ -508,7 +504,7 @@ class ZHAGateway:
_LOGGER.debug( _LOGGER.debug(
"cleaning up entity registry entry for entity: %s", entry.entity_id "cleaning up entity registry entry for entity: %s", entry.entity_id
) )
self.ha_entity_registry.async_remove(entry.entity_id) entity_registry.async_remove(entry.entity_id)
@property @property
def coordinator_ieee(self) -> EUI64: def coordinator_ieee(self) -> EUI64:
@ -582,9 +578,11 @@ class ZHAGateway:
) -> ZHADevice: ) -> ZHADevice:
"""Get or create a ZHA device.""" """Get or create a ZHA device."""
if (zha_device := self._devices.get(zigpy_device.ieee)) is None: if (zha_device := self._devices.get(zigpy_device.ieee)) is None:
zha_device = ZHADevice.new(self._hass, zigpy_device, self, restored) zha_device = ZHADevice.new(self.hass, zigpy_device, self, restored)
self._devices[zigpy_device.ieee] = zha_device self._devices[zigpy_device.ieee] = zha_device
device_registry_device = self.ha_device_registry.async_get_or_create(
device_registry = dr.async_get(self.hass)
device_registry_device = device_registry.async_get_or_create(
config_entry_id=self.config_entry.entry_id, config_entry_id=self.config_entry.entry_id,
connections={(dr.CONNECTION_ZIGBEE, str(zha_device.ieee))}, connections={(dr.CONNECTION_ZIGBEE, str(zha_device.ieee))},
identifiers={(DOMAIN, str(zha_device.ieee))}, identifiers={(DOMAIN, str(zha_device.ieee))},
@ -600,7 +598,7 @@ class ZHAGateway:
"""Get or create a ZHA group.""" """Get or create a ZHA group."""
zha_group = self._groups.get(zigpy_group.group_id) zha_group = self._groups.get(zigpy_group.group_id)
if zha_group is None: if zha_group is None:
zha_group = ZHAGroup(self._hass, self, zigpy_group) zha_group = ZHAGroup(self.hass, self, zigpy_group)
self._groups[zigpy_group.group_id] = zha_group self._groups[zigpy_group.group_id] = zha_group
return zha_group return zha_group
@ -645,7 +643,7 @@ class ZHAGateway:
device_info = zha_device.zha_device_info device_info = zha_device.zha_device_info
device_info[DEVICE_PAIRING_STATUS] = DevicePairingStatus.INITIALIZED.name device_info[DEVICE_PAIRING_STATUS] = DevicePairingStatus.INITIALIZED.name
async_dispatcher_send( async_dispatcher_send(
self._hass, self.hass,
ZHA_GW_MSG, ZHA_GW_MSG,
{ {
ATTR_TYPE: ZHA_GW_MSG_DEVICE_FULL_INIT, ATTR_TYPE: ZHA_GW_MSG_DEVICE_FULL_INIT,
@ -659,7 +657,7 @@ class ZHAGateway:
await zha_device.async_configure() await zha_device.async_configure()
device_info[DEVICE_PAIRING_STATUS] = DevicePairingStatus.CONFIGURED.name device_info[DEVICE_PAIRING_STATUS] = DevicePairingStatus.CONFIGURED.name
async_dispatcher_send( async_dispatcher_send(
self._hass, self.hass,
ZHA_GW_MSG, ZHA_GW_MSG,
{ {
ATTR_TYPE: ZHA_GW_MSG_DEVICE_FULL_INIT, ATTR_TYPE: ZHA_GW_MSG_DEVICE_FULL_INIT,
@ -667,7 +665,7 @@ class ZHAGateway:
}, },
) )
await zha_device.async_initialize(from_cache=False) await zha_device.async_initialize(from_cache=False)
async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES) async_dispatcher_send(self.hass, SIGNAL_ADD_ENTITIES)
async def _async_device_rejoined(self, zha_device: ZHADevice) -> None: async def _async_device_rejoined(self, zha_device: ZHADevice) -> None:
_LOGGER.debug( _LOGGER.debug(
@ -681,7 +679,7 @@ class ZHAGateway:
device_info = zha_device.device_info device_info = zha_device.device_info
device_info[DEVICE_PAIRING_STATUS] = DevicePairingStatus.CONFIGURED.name device_info[DEVICE_PAIRING_STATUS] = DevicePairingStatus.CONFIGURED.name
async_dispatcher_send( async_dispatcher_send(
self._hass, self.hass,
ZHA_GW_MSG, ZHA_GW_MSG,
{ {
ATTR_TYPE: ZHA_GW_MSG_DEVICE_FULL_INIT, ATTR_TYPE: ZHA_GW_MSG_DEVICE_FULL_INIT,

View File

@ -11,6 +11,7 @@ import zigpy.group
from zigpy.types.named import EUI64 from zigpy.types.named import EUI64
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_registry import async_entries_for_device from homeassistant.helpers.entity_registry import async_entries_for_device
from .helpers import LogMixin from .helpers import LogMixin
@ -32,8 +33,8 @@ class GroupMember(NamedTuple):
class GroupEntityReference(NamedTuple): class GroupEntityReference(NamedTuple):
"""Reference to a group entity.""" """Reference to a group entity."""
name: str name: str | None
original_name: str original_name: str | None
entity_id: int entity_id: int
@ -80,20 +81,30 @@ class ZHAGroupMember(LogMixin):
@property @property
def associated_entities(self) -> list[dict[str, Any]]: def associated_entities(self) -> list[dict[str, Any]]:
"""Return the list of entities that were derived from this endpoint.""" """Return the list of entities that were derived from this endpoint."""
ha_entity_registry = self.device.gateway.ha_entity_registry entity_registry = er.async_get(self._zha_device.hass)
zha_device_registry = self.device.gateway.device_registry zha_device_registry = self.device.gateway.device_registry
return [
entity_info = []
for entity_ref in zha_device_registry.get(self.device.ieee):
entity = entity_registry.async_get(entity_ref.reference_id)
handler = list(entity_ref.cluster_handlers.values())[0]
if (
entity is None
or handler.cluster.endpoint.endpoint_id != self.endpoint_id
):
continue
entity_info.append(
GroupEntityReference( GroupEntityReference(
ha_entity_registry.async_get(entity_ref.reference_id).name, name=entity.name,
ha_entity_registry.async_get(entity_ref.reference_id).original_name, original_name=entity.original_name,
entity_ref.reference_id, entity_id=entity_ref.reference_id,
)._asdict() )._asdict()
for entity_ref in zha_device_registry.get(self.device.ieee) )
if list(entity_ref.cluster_handlers.values())[
0 return entity_info
].cluster.endpoint.endpoint_id
== self.endpoint_id
]
async def async_remove_from_group(self) -> None: async def async_remove_from_group(self) -> None:
"""Remove the device endpoint from the provided zigbee group.""" """Remove the device endpoint from the provided zigbee group."""
@ -204,12 +215,14 @@ class ZHAGroup(LogMixin):
def get_domain_entity_ids(self, domain: str) -> list[str]: def get_domain_entity_ids(self, domain: str) -> list[str]:
"""Return entity ids from the entity domain for this group.""" """Return entity ids from the entity domain for this group."""
entity_registry = er.async_get(self.hass)
domain_entity_ids: list[str] = [] domain_entity_ids: list[str] = []
for member in self.members: for member in self.members:
if member.device.is_coordinator: if member.device.is_coordinator:
continue continue
entities = async_entries_for_device( entities = async_entries_for_device(
self._zha_gateway.ha_entity_registry, entity_registry,
member.device.device_id, member.device.device_id,
include_disabled_entities=True, include_disabled_entities=True,
) )

View File

@ -7,7 +7,9 @@ from __future__ import annotations
import asyncio import asyncio
import binascii import binascii
import collections
from collections.abc import Callable, Iterator from collections.abc import Callable, Iterator
import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
import enum import enum
import functools import functools
@ -26,16 +28,12 @@ from zigpy.zcl.foundation import CommandSchema
import zigpy.zdo.types as zdo_types import zigpy.zdo.types as zdo_types
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, State, callback from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import CLUSTER_TYPE_IN, CLUSTER_TYPE_OUT, CUSTOM_CONFIGURATION, DATA_ZHA
CLUSTER_TYPE_IN,
CLUSTER_TYPE_OUT,
CUSTOM_CONFIGURATION,
DATA_ZHA,
DATA_ZHA_GATEWAY,
)
from .registries import BINDABLE_CLUSTERS from .registries import BINDABLE_CLUSTERS
if TYPE_CHECKING: if TYPE_CHECKING:
@ -221,7 +219,7 @@ def async_get_zha_config_value(
def async_cluster_exists(hass, cluster_id, skip_coordinator=True): def async_cluster_exists(hass, cluster_id, skip_coordinator=True):
"""Determine if a device containing the specified in cluster is paired.""" """Determine if a device containing the specified in cluster is paired."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
zha_devices = zha_gateway.devices.values() zha_devices = zha_gateway.devices.values()
for zha_device in zha_devices: for zha_device in zha_devices:
if skip_coordinator and zha_device.is_coordinator: if skip_coordinator and zha_device.is_coordinator:
@ -244,7 +242,7 @@ def async_get_zha_device(hass: HomeAssistant, device_id: str) -> ZHADevice:
if not registry_device: if not registry_device:
_LOGGER.error("Device id `%s` not found in registry", device_id) _LOGGER.error("Device id `%s` not found in registry", device_id)
raise KeyError(f"Device id `{device_id}` not found in registry.") raise KeyError(f"Device id `{device_id}` not found in registry.")
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
try: try:
ieee_address = list(registry_device.identifiers)[0][1] ieee_address = list(registry_device.identifiers)[0][1]
ieee = zigpy.types.EUI64.convert(ieee_address) ieee = zigpy.types.EUI64.convert(ieee_address)
@ -421,3 +419,30 @@ def qr_to_install_code(qr_code: str) -> tuple[zigpy.types.EUI64, bytes]:
return ieee, install_code return ieee, install_code
raise vol.Invalid(f"couldn't convert qr code: {qr_code}") raise vol.Invalid(f"couldn't convert qr code: {qr_code}")
@dataclasses.dataclass(kw_only=True, slots=True)
class ZHAData:
"""ZHA component data stored in `hass.data`."""
yaml_config: ConfigType = dataclasses.field(default_factory=dict)
platforms: collections.defaultdict[Platform, list] = dataclasses.field(
default_factory=lambda: collections.defaultdict(list)
)
gateway: ZHAGateway | None = dataclasses.field(default=None)
device_trigger_cache: dict[str, tuple[str, dict]] = dataclasses.field(
default_factory=dict
)
def get_zha_data(hass: HomeAssistant) -> ZHAData:
"""Get the global ZHA data object."""
return hass.data.get(DATA_ZHA, ZHAData())
def get_zha_gateway(hass: HomeAssistant) -> ZHAGateway:
"""Get the ZHA gateway object."""
if (zha_gateway := get_zha_data(hass).gateway) is None:
raise ValueError("No gateway object exists")
return zha_gateway

View File

@ -269,15 +269,15 @@ class ZHAEntityRegistry:
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize Registry instance.""" """Initialize Registry instance."""
self._strict_registry: dict[ self._strict_registry: dict[
str, dict[MatchRule, type[ZhaEntity]] Platform, dict[MatchRule, type[ZhaEntity]]
] = collections.defaultdict(dict) ] = collections.defaultdict(dict)
self._multi_entity_registry: dict[ self._multi_entity_registry: dict[
str, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]] Platform, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]]
] = collections.defaultdict( ] = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(list)) lambda: collections.defaultdict(lambda: collections.defaultdict(list))
) )
self._config_diagnostic_entity_registry: dict[ self._config_diagnostic_entity_registry: dict[
str, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]] Platform, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]]
] = collections.defaultdict( ] = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(list)) lambda: collections.defaultdict(lambda: collections.defaultdict(list))
) )
@ -288,7 +288,7 @@ class ZHAEntityRegistry:
def get_entity( def get_entity(
self, self,
component: str, component: Platform,
manufacturer: str, manufacturer: str,
model: str, model: str,
cluster_handlers: list[ClusterHandler], cluster_handlers: list[ClusterHandler],
@ -310,10 +310,12 @@ class ZHAEntityRegistry:
model: str, model: str,
cluster_handlers: list[ClusterHandler], cluster_handlers: list[ClusterHandler],
quirk_class: str, quirk_class: str,
) -> tuple[dict[str, list[EntityClassAndClusterHandlers]], list[ClusterHandler]]: ) -> tuple[
dict[Platform, list[EntityClassAndClusterHandlers]], list[ClusterHandler]
]:
"""Match ZHA cluster handlers to potentially multiple ZHA Entity classes.""" """Match ZHA cluster handlers to potentially multiple ZHA Entity classes."""
result: dict[ result: dict[
str, list[EntityClassAndClusterHandlers] Platform, list[EntityClassAndClusterHandlers]
] = collections.defaultdict(list) ] = collections.defaultdict(list)
all_claimed: set[ClusterHandler] = set() all_claimed: set[ClusterHandler] = set()
for component, stop_match_groups in self._multi_entity_registry.items(): for component, stop_match_groups in self._multi_entity_registry.items():
@ -341,10 +343,12 @@ class ZHAEntityRegistry:
model: str, model: str,
cluster_handlers: list[ClusterHandler], cluster_handlers: list[ClusterHandler],
quirk_class: str, quirk_class: str,
) -> tuple[dict[str, list[EntityClassAndClusterHandlers]], list[ClusterHandler]]: ) -> tuple[
dict[Platform, list[EntityClassAndClusterHandlers]], list[ClusterHandler]
]:
"""Match ZHA cluster handlers to potentially multiple ZHA Entity classes.""" """Match ZHA cluster handlers to potentially multiple ZHA Entity classes."""
result: dict[ result: dict[
str, list[EntityClassAndClusterHandlers] Platform, list[EntityClassAndClusterHandlers]
] = collections.defaultdict(list) ] = collections.defaultdict(list)
all_claimed: set[ClusterHandler] = set() all_claimed: set[ClusterHandler] = set()
for ( for (
@ -375,7 +379,7 @@ class ZHAEntityRegistry:
def strict_match( def strict_match(
self, self,
component: str, component: Platform,
cluster_handler_names: set[str] | str | None = None, cluster_handler_names: set[str] | str | None = None,
generic_ids: set[str] | str | None = None, generic_ids: set[str] | str | None = None,
manufacturers: Callable | set[str] | str | None = None, manufacturers: Callable | set[str] | str | None = None,
@ -406,7 +410,7 @@ class ZHAEntityRegistry:
def multipass_match( def multipass_match(
self, self,
component: str, component: Platform,
cluster_handler_names: set[str] | str | None = None, cluster_handler_names: set[str] | str | None = None,
generic_ids: set[str] | str | None = None, generic_ids: set[str] | str | None = None,
manufacturers: Callable | set[str] | str | None = None, manufacturers: Callable | set[str] | str | None = None,
@ -441,7 +445,7 @@ class ZHAEntityRegistry:
def config_diagnostic_match( def config_diagnostic_match(
self, self,
component: str, component: Platform,
cluster_handler_names: set[str] | str | None = None, cluster_handler_names: set[str] | str | None = None,
generic_ids: set[str] | str | None = None, generic_ids: set[str] | str | None = None,
manufacturers: Callable | set[str] | str | None = None, manufacturers: Callable | set[str] | str | None = None,
@ -475,7 +479,7 @@ class ZHAEntityRegistry:
return decorator return decorator
def group_match( def group_match(
self, component: str self, component: Platform
) -> Callable[[_ZhaGroupEntityT], _ZhaGroupEntityT]: ) -> Callable[[_ZhaGroupEntityT], _ZhaGroupEntityT]:
"""Decorate a group match rule.""" """Decorate a group match rule."""

View File

@ -33,11 +33,11 @@ from .core.const import (
CLUSTER_HANDLER_LEVEL, CLUSTER_HANDLER_LEVEL,
CLUSTER_HANDLER_ON_OFF, CLUSTER_HANDLER_ON_OFF,
CLUSTER_HANDLER_SHADE, CLUSTER_HANDLER_SHADE,
DATA_ZHA,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
SIGNAL_SET_LEVEL, SIGNAL_SET_LEVEL,
) )
from .core.helpers import get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
@ -56,7 +56,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation cover from config entry.""" """Set up the Zigbee Home Automation cover from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.COVER] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.COVER]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -15,10 +15,10 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .core import discovery from .core import discovery
from .core.const import ( from .core.const import (
CLUSTER_HANDLER_POWER_CONFIGURATION, CLUSTER_HANDLER_POWER_CONFIGURATION,
DATA_ZHA,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
) )
from .core.helpers import get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
from .sensor import Battery from .sensor import Battery
@ -32,7 +32,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation device tracker from config entry.""" """Set up the Zigbee Home Automation device tracker from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.DEVICE_TRACKER] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.DEVICE_TRACKER]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -14,8 +14,8 @@ from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import DOMAIN as ZHA_DOMAIN from . import DOMAIN as ZHA_DOMAIN
from .core.const import DATA_ZHA, DATA_ZHA_DEVICE_TRIGGER_CACHE, ZHA_EVENT from .core.const import ZHA_EVENT
from .core.helpers import async_get_zha_device from .core.helpers import async_get_zha_device, get_zha_data
CONF_SUBTYPE = "subtype" CONF_SUBTYPE = "subtype"
DEVICE = "device" DEVICE = "device"
@ -32,13 +32,13 @@ def _get_device_trigger_data(hass: HomeAssistant, device_id: str) -> tuple[str,
# First, try checking to see if the device itself is accessible # First, try checking to see if the device itself is accessible
try: try:
zha_device = async_get_zha_device(hass, device_id) zha_device = async_get_zha_device(hass, device_id)
except KeyError: except ValueError:
pass pass
else: else:
return str(zha_device.ieee), zha_device.device_automation_triggers return str(zha_device.ieee), zha_device.device_automation_triggers
# If not, check the trigger cache but allow any `KeyError`s to propagate # If not, check the trigger cache but allow any `KeyError`s to propagate
return hass.data[DATA_ZHA][DATA_ZHA_DEVICE_TRIGGER_CACHE][device_id] return get_zha_data(hass).device_trigger_cache[device_id]
async def async_validate_trigger_config( async def async_validate_trigger_config(

View File

@ -25,14 +25,10 @@ from .core.const import (
ATTR_PROFILE_ID, ATTR_PROFILE_ID,
ATTR_VALUE, ATTR_VALUE,
CONF_ALARM_MASTER_CODE, CONF_ALARM_MASTER_CODE,
DATA_ZHA,
DATA_ZHA_CONFIG,
DATA_ZHA_GATEWAY,
UNKNOWN, UNKNOWN,
) )
from .core.device import ZHADevice from .core.device import ZHADevice
from .core.gateway import ZHAGateway from .core.helpers import async_get_zha_device, get_zha_data, get_zha_gateway
from .core.helpers import async_get_zha_device
KEYS_TO_REDACT = { KEYS_TO_REDACT = {
ATTR_IEEE, ATTR_IEEE,
@ -66,18 +62,18 @@ async def async_get_config_entry_diagnostics(
hass: HomeAssistant, config_entry: ConfigEntry hass: HomeAssistant, config_entry: ConfigEntry
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
config: dict = hass.data[DATA_ZHA].get(DATA_ZHA_CONFIG, {}) zha_data = get_zha_data(hass)
gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] app = get_zha_gateway(hass).application_controller
energy_scan = await gateway.application_controller.energy_scan( energy_scan = await app.energy_scan(
channels=Channels.ALL_CHANNELS, duration_exp=4, count=1 channels=Channels.ALL_CHANNELS, duration_exp=4, count=1
) )
return async_redact_data( return async_redact_data(
{ {
"config": config, "config": zha_data.yaml_config,
"config_entry": config_entry.as_dict(), "config_entry": config_entry.as_dict(),
"application_state": shallow_asdict(gateway.application_controller.state), "application_state": shallow_asdict(app.state),
"energy_scan": { "energy_scan": {
channel: 100 * energy / 255 for channel, energy in energy_scan.items() channel: 100 * energy / 255 for channel, energy in energy_scan.items()
}, },

View File

@ -26,14 +26,12 @@ from homeassistant.helpers.typing import EventType
from .core.const import ( from .core.const import (
ATTR_MANUFACTURER, ATTR_MANUFACTURER,
ATTR_MODEL, ATTR_MODEL,
DATA_ZHA,
DATA_ZHA_BRIDGE_ID,
DOMAIN, DOMAIN,
SIGNAL_GROUP_ENTITY_REMOVED, SIGNAL_GROUP_ENTITY_REMOVED,
SIGNAL_GROUP_MEMBERSHIP_CHANGE, SIGNAL_GROUP_MEMBERSHIP_CHANGE,
SIGNAL_REMOVE, SIGNAL_REMOVE,
) )
from .core.helpers import LogMixin from .core.helpers import LogMixin, get_zha_gateway
if TYPE_CHECKING: if TYPE_CHECKING:
from .core.cluster_handlers import ClusterHandler from .core.cluster_handlers import ClusterHandler
@ -83,13 +81,16 @@ class BaseZhaEntity(LogMixin, entity.Entity):
"""Return a device description for device registry.""" """Return a device description for device registry."""
zha_device_info = self._zha_device.device_info zha_device_info = self._zha_device.device_info
ieee = zha_device_info["ieee"] ieee = zha_device_info["ieee"]
zha_gateway = get_zha_gateway(self.hass)
return DeviceInfo( return DeviceInfo(
connections={(CONNECTION_ZIGBEE, ieee)}, connections={(CONNECTION_ZIGBEE, ieee)},
identifiers={(DOMAIN, ieee)}, identifiers={(DOMAIN, ieee)},
manufacturer=zha_device_info[ATTR_MANUFACTURER], manufacturer=zha_device_info[ATTR_MANUFACTURER],
model=zha_device_info[ATTR_MODEL], model=zha_device_info[ATTR_MODEL],
name=zha_device_info[ATTR_NAME], name=zha_device_info[ATTR_NAME],
via_device=(DOMAIN, self.hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID]), via_device=(DOMAIN, zha_gateway.coordinator_ieee),
) )
@callback @callback

View File

@ -28,12 +28,8 @@ from homeassistant.util.percentage import (
from .core import discovery from .core import discovery
from .core.cluster_handlers import wrap_zigpy_exceptions from .core.cluster_handlers import wrap_zigpy_exceptions
from .core.const import ( from .core.const import CLUSTER_HANDLER_FAN, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED
CLUSTER_HANDLER_FAN, from .core.helpers import get_zha_data
DATA_ZHA,
SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED,
)
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity, ZhaGroupEntity from .entity import ZhaEntity, ZhaGroupEntity
@ -65,7 +61,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation fan from config entry.""" """Set up the Zigbee Home Automation fan from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.FAN] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.FAN]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -47,13 +47,12 @@ from .core.const import (
CONF_ENABLE_ENHANCED_LIGHT_TRANSITION, CONF_ENABLE_ENHANCED_LIGHT_TRANSITION,
CONF_ENABLE_LIGHT_TRANSITIONING_FLAG, CONF_ENABLE_LIGHT_TRANSITIONING_FLAG,
CONF_GROUP_MEMBERS_ASSUME_STATE, CONF_GROUP_MEMBERS_ASSUME_STATE,
DATA_ZHA,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
SIGNAL_SET_LEVEL, SIGNAL_SET_LEVEL,
ZHA_OPTIONS, ZHA_OPTIONS,
) )
from .core.helpers import LogMixin, async_get_zha_config_value from .core.helpers import LogMixin, async_get_zha_config_value, get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity, ZhaGroupEntity from .entity import ZhaEntity, ZhaGroupEntity
@ -97,7 +96,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation light from config entry.""" """Set up the Zigbee Home Automation light from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.LIGHT] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.LIGHT]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -20,10 +20,10 @@ from homeassistant.helpers.typing import StateType
from .core import discovery from .core import discovery
from .core.const import ( from .core.const import (
CLUSTER_HANDLER_DOORLOCK, CLUSTER_HANDLER_DOORLOCK,
DATA_ZHA,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
) )
from .core.helpers import get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
@ -45,7 +45,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation Door Lock from config entry.""" """Set up the Zigbee Home Automation Door Lock from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.LOCK] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.LOCK]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -20,10 +20,10 @@ from .core.const import (
CLUSTER_HANDLER_COLOR, CLUSTER_HANDLER_COLOR,
CLUSTER_HANDLER_INOVELLI, CLUSTER_HANDLER_INOVELLI,
CLUSTER_HANDLER_LEVEL, CLUSTER_HANDLER_LEVEL,
DATA_ZHA,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
) )
from .core.helpers import get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
@ -258,7 +258,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation Analog Output from config entry.""" """Set up the Zigbee Home Automation Analog Output from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.NUMBER] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.NUMBER]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -26,12 +26,11 @@ from .core.const import (
CONF_DATABASE, CONF_DATABASE,
CONF_RADIO_TYPE, CONF_RADIO_TYPE,
CONF_ZIGPY, CONF_ZIGPY,
DATA_ZHA,
DATA_ZHA_CONFIG,
DEFAULT_DATABASE_NAME, DEFAULT_DATABASE_NAME,
EZSP_OVERWRITE_EUI64, EZSP_OVERWRITE_EUI64,
RadioType, RadioType,
) )
from .core.helpers import get_zha_data
# Only the common radio types will be autoprobed, ordered by new device popularity. # Only the common radio types will be autoprobed, ordered by new device popularity.
# XBee takes too long to probe since it scans through all possible bauds and likely has # XBee takes too long to probe since it scans through all possible bauds and likely has
@ -145,7 +144,7 @@ class ZhaRadioManager:
"""Connect to the radio with the current config and then clean up.""" """Connect to the radio with the current config and then clean up."""
assert self.radio_type is not None assert self.radio_type is not None
config = self.hass.data.get(DATA_ZHA, {}).get(DATA_ZHA_CONFIG, {}) config = get_zha_data(self.hass).yaml_config
app_config = config.get(CONF_ZIGPY, {}).copy() app_config = config.get(CONF_ZIGPY, {}).copy()
database_path = config.get( database_path = config.get(

View File

@ -23,11 +23,11 @@ from .core.const import (
CLUSTER_HANDLER_IAS_WD, CLUSTER_HANDLER_IAS_WD,
CLUSTER_HANDLER_INOVELLI, CLUSTER_HANDLER_INOVELLI,
CLUSTER_HANDLER_ON_OFF, CLUSTER_HANDLER_ON_OFF,
DATA_ZHA,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
Strobe, Strobe,
) )
from .core.helpers import get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
@ -48,7 +48,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation siren from config entry.""" """Set up the Zigbee Home Automation siren from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.SELECT] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.SELECT]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -57,10 +57,10 @@ from .core.const import (
CLUSTER_HANDLER_SOIL_MOISTURE, CLUSTER_HANDLER_SOIL_MOISTURE,
CLUSTER_HANDLER_TEMPERATURE, CLUSTER_HANDLER_TEMPERATURE,
CLUSTER_HANDLER_THERMOSTAT, CLUSTER_HANDLER_THERMOSTAT,
DATA_ZHA,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
) )
from .core.helpers import get_zha_data
from .core.registries import SMARTTHINGS_HUMIDITY_CLUSTER, ZHA_ENTITIES from .core.registries import SMARTTHINGS_HUMIDITY_CLUSTER, ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
@ -99,7 +99,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation sensor from config entry.""" """Set up the Zigbee Home Automation sensor from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.SENSOR] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.SENSOR]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -25,7 +25,6 @@ from .core import discovery
from .core.cluster_handlers.security import IasWd from .core.cluster_handlers.security import IasWd
from .core.const import ( from .core.const import (
CLUSTER_HANDLER_IAS_WD, CLUSTER_HANDLER_IAS_WD,
DATA_ZHA,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
WARNING_DEVICE_MODE_BURGLAR, WARNING_DEVICE_MODE_BURGLAR,
WARNING_DEVICE_MODE_EMERGENCY, WARNING_DEVICE_MODE_EMERGENCY,
@ -39,6 +38,7 @@ from .core.const import (
WARNING_DEVICE_STROBE_NO, WARNING_DEVICE_STROBE_NO,
Strobe, Strobe,
) )
from .core.helpers import get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
@ -56,7 +56,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation siren from config entry.""" """Set up the Zigbee Home Automation siren from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.SIREN] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.SIREN]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -20,10 +20,10 @@ from .core.const import (
CLUSTER_HANDLER_BASIC, CLUSTER_HANDLER_BASIC,
CLUSTER_HANDLER_INOVELLI, CLUSTER_HANDLER_INOVELLI,
CLUSTER_HANDLER_ON_OFF, CLUSTER_HANDLER_ON_OFF,
DATA_ZHA,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
) )
from .core.helpers import get_zha_data
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity, ZhaGroupEntity from .entity import ZhaEntity, ZhaGroupEntity
@ -46,7 +46,8 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the Zigbee Home Automation switch from config entry.""" """Set up the Zigbee Home Automation switch from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.SWITCH] zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.SWITCH]
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,

View File

@ -16,6 +16,7 @@ import zigpy.zdo.types as zdo_types
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.const import ATTR_COMMAND, ATTR_ID, ATTR_NAME from homeassistant.const import ATTR_COMMAND, ATTR_ID, ATTR_NAME
from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers import entity_registry as er
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.service import async_register_admin_service
@ -52,8 +53,6 @@ from .core.const import (
CLUSTER_TYPE_IN, CLUSTER_TYPE_IN,
CLUSTER_TYPE_OUT, CLUSTER_TYPE_OUT,
CUSTOM_CONFIGURATION, CUSTOM_CONFIGURATION,
DATA_ZHA,
DATA_ZHA_GATEWAY,
DOMAIN, DOMAIN,
EZSP_OVERWRITE_EUI64, EZSP_OVERWRITE_EUI64,
GROUP_ID, GROUP_ID,
@ -77,6 +76,7 @@ from .core.helpers import (
cluster_command_schema_to_vol_schema, cluster_command_schema_to_vol_schema,
convert_install_code, convert_install_code,
get_matched_clusters, get_matched_clusters,
get_zha_gateway,
qr_to_install_code, qr_to_install_code,
) )
@ -301,7 +301,7 @@ async def websocket_permit_devices(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Permit ZHA zigbee devices.""" """Permit ZHA zigbee devices."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
duration: int = msg[ATTR_DURATION] duration: int = msg[ATTR_DURATION]
ieee: EUI64 | None = msg.get(ATTR_IEEE) ieee: EUI64 | None = msg.get(ATTR_IEEE)
@ -348,7 +348,7 @@ async def websocket_get_devices(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Get ZHA devices.""" """Get ZHA devices."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
devices = [device.zha_device_info for device in zha_gateway.devices.values()] devices = [device.zha_device_info for device in zha_gateway.devices.values()]
connection.send_result(msg[ID], devices) connection.send_result(msg[ID], devices)
@ -357,7 +357,8 @@ async def websocket_get_devices(
def _get_entity_name( def _get_entity_name(
zha_gateway: ZHAGateway, entity_ref: EntityReference zha_gateway: ZHAGateway, entity_ref: EntityReference
) -> str | None: ) -> str | None:
entry = zha_gateway.ha_entity_registry.async_get(entity_ref.reference_id) entity_registry = er.async_get(zha_gateway.hass)
entry = entity_registry.async_get(entity_ref.reference_id)
return entry.name if entry else None return entry.name if entry else None
@ -365,7 +366,8 @@ def _get_entity_name(
def _get_entity_original_name( def _get_entity_original_name(
zha_gateway: ZHAGateway, entity_ref: EntityReference zha_gateway: ZHAGateway, entity_ref: EntityReference
) -> str | None: ) -> str | None:
entry = zha_gateway.ha_entity_registry.async_get(entity_ref.reference_id) entity_registry = er.async_get(zha_gateway.hass)
entry = entity_registry.async_get(entity_ref.reference_id)
return entry.original_name if entry else None return entry.original_name if entry else None
@ -376,7 +378,7 @@ async def websocket_get_groupable_devices(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Get ZHA devices that can be grouped.""" """Get ZHA devices that can be grouped."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
devices = [device for device in zha_gateway.devices.values() if device.is_groupable] devices = [device for device in zha_gateway.devices.values() if device.is_groupable]
groupable_devices = [] groupable_devices = []
@ -414,7 +416,7 @@ async def websocket_get_groups(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Get ZHA groups.""" """Get ZHA groups."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
groups = [group.group_info for group in zha_gateway.groups.values()] groups = [group.group_info for group in zha_gateway.groups.values()]
connection.send_result(msg[ID], groups) connection.send_result(msg[ID], groups)
@ -431,7 +433,7 @@ async def websocket_get_device(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Get ZHA devices.""" """Get ZHA devices."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
ieee: EUI64 = msg[ATTR_IEEE] ieee: EUI64 = msg[ATTR_IEEE]
if not (zha_device := zha_gateway.devices.get(ieee)): if not (zha_device := zha_gateway.devices.get(ieee)):
@ -458,7 +460,7 @@ async def websocket_get_group(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Get ZHA group.""" """Get ZHA group."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
group_id: int = msg[GROUP_ID] group_id: int = msg[GROUP_ID]
if not (zha_group := zha_gateway.groups.get(group_id)): if not (zha_group := zha_gateway.groups.get(group_id)):
@ -487,7 +489,7 @@ async def websocket_add_group(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Add a new ZHA group.""" """Add a new ZHA group."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
group_name: str = msg[GROUP_NAME] group_name: str = msg[GROUP_NAME]
group_id: int | None = msg.get(GROUP_ID) group_id: int | None = msg.get(GROUP_ID)
members: list[GroupMember] | None = msg.get(ATTR_MEMBERS) members: list[GroupMember] | None = msg.get(ATTR_MEMBERS)
@ -508,7 +510,7 @@ async def websocket_remove_groups(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Remove the specified ZHA groups.""" """Remove the specified ZHA groups."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
group_ids: list[int] = msg[GROUP_IDS] group_ids: list[int] = msg[GROUP_IDS]
if len(group_ids) > 1: if len(group_ids) > 1:
@ -535,7 +537,7 @@ async def websocket_add_group_members(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Add members to a ZHA group.""" """Add members to a ZHA group."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
group_id: int = msg[GROUP_ID] group_id: int = msg[GROUP_ID]
members: list[GroupMember] = msg[ATTR_MEMBERS] members: list[GroupMember] = msg[ATTR_MEMBERS]
@ -565,7 +567,7 @@ async def websocket_remove_group_members(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Remove members from a ZHA group.""" """Remove members from a ZHA group."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
group_id: int = msg[GROUP_ID] group_id: int = msg[GROUP_ID]
members: list[GroupMember] = msg[ATTR_MEMBERS] members: list[GroupMember] = msg[ATTR_MEMBERS]
@ -594,7 +596,7 @@ async def websocket_reconfigure_node(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Reconfigure a ZHA nodes entities by its ieee address.""" """Reconfigure a ZHA nodes entities by its ieee address."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
ieee: EUI64 = msg[ATTR_IEEE] ieee: EUI64 = msg[ATTR_IEEE]
device: ZHADevice | None = zha_gateway.get_device(ieee) device: ZHADevice | None = zha_gateway.get_device(ieee)
@ -629,7 +631,7 @@ async def websocket_update_topology(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Update the ZHA network topology.""" """Update the ZHA network topology."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
hass.async_create_task(zha_gateway.application_controller.topology.scan()) hass.async_create_task(zha_gateway.application_controller.topology.scan())
@ -645,7 +647,7 @@ async def websocket_device_clusters(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Return a list of device clusters.""" """Return a list of device clusters."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
ieee: EUI64 = msg[ATTR_IEEE] ieee: EUI64 = msg[ATTR_IEEE]
zha_device = zha_gateway.get_device(ieee) zha_device = zha_gateway.get_device(ieee)
response_clusters = [] response_clusters = []
@ -689,7 +691,7 @@ async def websocket_device_cluster_attributes(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Return a list of cluster attributes.""" """Return a list of cluster attributes."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
ieee: EUI64 = msg[ATTR_IEEE] ieee: EUI64 = msg[ATTR_IEEE]
endpoint_id: int = msg[ATTR_ENDPOINT_ID] endpoint_id: int = msg[ATTR_ENDPOINT_ID]
cluster_id: int = msg[ATTR_CLUSTER_ID] cluster_id: int = msg[ATTR_CLUSTER_ID]
@ -736,7 +738,7 @@ async def websocket_device_cluster_commands(
"""Return a list of cluster commands.""" """Return a list of cluster commands."""
import voluptuous_serialize # pylint: disable=import-outside-toplevel import voluptuous_serialize # pylint: disable=import-outside-toplevel
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
ieee: EUI64 = msg[ATTR_IEEE] ieee: EUI64 = msg[ATTR_IEEE]
endpoint_id: int = msg[ATTR_ENDPOINT_ID] endpoint_id: int = msg[ATTR_ENDPOINT_ID]
cluster_id: int = msg[ATTR_CLUSTER_ID] cluster_id: int = msg[ATTR_CLUSTER_ID]
@ -806,7 +808,7 @@ async def websocket_read_zigbee_cluster_attributes(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Read zigbee attribute for cluster on ZHA entity.""" """Read zigbee attribute for cluster on ZHA entity."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
ieee: EUI64 = msg[ATTR_IEEE] ieee: EUI64 = msg[ATTR_IEEE]
endpoint_id: int = msg[ATTR_ENDPOINT_ID] endpoint_id: int = msg[ATTR_ENDPOINT_ID]
cluster_id: int = msg[ATTR_CLUSTER_ID] cluster_id: int = msg[ATTR_CLUSTER_ID]
@ -860,7 +862,7 @@ async def websocket_get_bindable_devices(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Directly bind devices.""" """Directly bind devices."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
source_ieee: EUI64 = msg[ATTR_IEEE] source_ieee: EUI64 = msg[ATTR_IEEE]
source_device = zha_gateway.get_device(source_ieee) source_device = zha_gateway.get_device(source_ieee)
@ -894,7 +896,7 @@ async def websocket_bind_devices(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Directly bind devices.""" """Directly bind devices."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE]
target_ieee: EUI64 = msg[ATTR_TARGET_IEEE] target_ieee: EUI64 = msg[ATTR_TARGET_IEEE]
await async_binding_operation( await async_binding_operation(
@ -923,7 +925,7 @@ async def websocket_unbind_devices(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Remove a direct binding between devices.""" """Remove a direct binding between devices."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE]
target_ieee: EUI64 = msg[ATTR_TARGET_IEEE] target_ieee: EUI64 = msg[ATTR_TARGET_IEEE]
await async_binding_operation( await async_binding_operation(
@ -953,7 +955,7 @@ async def websocket_bind_group(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Directly bind a device to a group.""" """Directly bind a device to a group."""
zha_gateway: ZHAGateway = get_gateway(hass) zha_gateway = get_zha_gateway(hass)
source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE]
group_id: int = msg[GROUP_ID] group_id: int = msg[GROUP_ID]
bindings: list[ClusterBinding] = msg[BINDINGS] bindings: list[ClusterBinding] = msg[BINDINGS]
@ -977,7 +979,7 @@ async def websocket_unbind_group(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Unbind a device from a group.""" """Unbind a device from a group."""
zha_gateway: ZHAGateway = get_gateway(hass) zha_gateway = get_zha_gateway(hass)
source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE]
group_id: int = msg[GROUP_ID] group_id: int = msg[GROUP_ID]
bindings: list[ClusterBinding] = msg[BINDINGS] bindings: list[ClusterBinding] = msg[BINDINGS]
@ -987,11 +989,6 @@ async def websocket_unbind_group(
connection.send_result(msg[ID]) connection.send_result(msg[ID])
def get_gateway(hass: HomeAssistant) -> ZHAGateway:
"""Return Gateway, mainly as fixture for mocking during testing."""
return hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
async def async_binding_operation( async def async_binding_operation(
zha_gateway: ZHAGateway, zha_gateway: ZHAGateway,
source_ieee: EUI64, source_ieee: EUI64,
@ -1047,7 +1044,7 @@ async def websocket_get_configuration(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Get ZHA configuration.""" """Get ZHA configuration."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
import voluptuous_serialize # pylint: disable=import-outside-toplevel import voluptuous_serialize # pylint: disable=import-outside-toplevel
def custom_serializer(schema: Any) -> Any: def custom_serializer(schema: Any) -> Any:
@ -1094,7 +1091,7 @@ async def websocket_update_zha_configuration(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Update the ZHA configuration.""" """Update the ZHA configuration."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
options = zha_gateway.config_entry.options options = zha_gateway.config_entry.options
data_to_save = {**options, **{CUSTOM_CONFIGURATION: msg["data"]}} data_to_save = {**options, **{CUSTOM_CONFIGURATION: msg["data"]}}
@ -1141,7 +1138,7 @@ async def websocket_get_network_settings(
) -> None: ) -> None:
"""Get ZHA network settings.""" """Get ZHA network settings."""
backup = async_get_active_network_settings(hass) backup = async_get_active_network_settings(hass)
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
connection.send_result( connection.send_result(
msg[ID], msg[ID],
{ {
@ -1159,7 +1156,7 @@ async def websocket_list_network_backups(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Get ZHA network settings.""" """Get ZHA network settings."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
application_controller = zha_gateway.application_controller application_controller = zha_gateway.application_controller
# Serialize known backups # Serialize known backups
@ -1175,7 +1172,7 @@ async def websocket_create_network_backup(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Create a ZHA network backup.""" """Create a ZHA network backup."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
application_controller = zha_gateway.application_controller application_controller = zha_gateway.application_controller
# This can take 5-30s # This can take 5-30s
@ -1202,7 +1199,7 @@ async def websocket_restore_network_backup(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None: ) -> None:
"""Restore a ZHA network backup.""" """Restore a ZHA network backup."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
application_controller = zha_gateway.application_controller application_controller = zha_gateway.application_controller
backup = msg["backup"] backup = msg["backup"]
@ -1240,7 +1237,7 @@ async def websocket_change_channel(
@callback @callback
def async_load_api(hass: HomeAssistant) -> None: def async_load_api(hass: HomeAssistant) -> None:
"""Set up the web socket API.""" """Set up the web socket API."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
application_controller = zha_gateway.application_controller application_controller = zha_gateway.application_controller
async def permit(service: ServiceCall) -> None: async def permit(service: ServiceCall) -> None:
@ -1278,7 +1275,7 @@ def async_load_api(hass: HomeAssistant) -> None:
async def remove(service: ServiceCall) -> None: async def remove(service: ServiceCall) -> None:
"""Remove a node from the network.""" """Remove a node from the network."""
zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
ieee: EUI64 = service.data[ATTR_IEEE] ieee: EUI64 = service.data[ATTR_IEEE]
zha_device: ZHADevice | None = zha_gateway.get_device(ieee) zha_device: ZHADevice | None = zha_gateway.get_device(ieee)
if zha_device is not None and zha_device.is_active_coordinator: if zha_device is not None and zha_device.is_active_coordinator:

View File

@ -9,7 +9,10 @@ import zigpy.zcl
import zigpy.zcl.foundation as zcl_f import zigpy.zcl.foundation as zcl_f
import homeassistant.components.zha.core.const as zha_const import homeassistant.components.zha.core.const as zha_const
from homeassistant.components.zha.core.helpers import async_get_zha_config_value from homeassistant.components.zha.core.helpers import (
async_get_zha_config_value,
get_zha_gateway,
)
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -85,11 +88,6 @@ def update_attribute_cache(cluster):
cluster.handle_message(hdr, msg) cluster.handle_message(hdr, msg)
def get_zha_gateway(hass):
"""Return ZHA gateway from hass.data."""
return hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
def make_attribute(attrid, value, status=0): def make_attribute(attrid, value, status=0):
"""Make an attribute.""" """Make an attribute."""
attr = zcl_f.Attribute() attr = zcl_f.Attribute()

View File

@ -22,9 +22,10 @@ import zigpy.zdo.types as zdo_t
import homeassistant.components.zha.core.const as zha_const import homeassistant.components.zha.core.const as zha_const
import homeassistant.components.zha.core.device as zha_core_device import homeassistant.components.zha.core.device as zha_core_device
from homeassistant.components.zha.core.helpers import get_zha_gateway
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import common from .common import patch_cluster as common_patch_cluster
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.components.light.conftest import mock_light_profiles # noqa: F401 from tests.components.light.conftest import mock_light_profiles # noqa: F401
@ -277,7 +278,7 @@ def zigpy_device_mock(zigpy_app_controller):
for cluster in itertools.chain( for cluster in itertools.chain(
endpoint.in_clusters.values(), endpoint.out_clusters.values() endpoint.in_clusters.values(), endpoint.out_clusters.values()
): ):
common.patch_cluster(cluster) common_patch_cluster(cluster)
if attributes is not None: if attributes is not None:
for ep_id, clusters in attributes.items(): for ep_id, clusters in attributes.items():
@ -305,7 +306,7 @@ def zha_device_joined(hass, setup_zha):
if setup_zha: if setup_zha:
await setup_zha_fixture() await setup_zha_fixture()
zha_gateway = common.get_zha_gateway(hass) zha_gateway = get_zha_gateway(hass)
zha_gateway.application_controller.devices[zigpy_dev.ieee] = zigpy_dev zha_gateway.application_controller.devices[zigpy_dev.ieee] = zigpy_dev
await zha_gateway.async_device_initialized(zigpy_dev) await zha_gateway.async_device_initialized(zigpy_dev)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -329,7 +330,7 @@ def zha_device_restored(hass, zigpy_app_controller, setup_zha):
if setup_zha: if setup_zha:
await setup_zha_fixture() await setup_zha_fixture()
zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] zha_gateway = get_zha_gateway(hass)
return zha_gateway.get_device(zigpy_dev.ieee) return zha_gateway.get_device(zigpy_dev.ieee)
return _zha_device return _zha_device

View File

@ -11,6 +11,7 @@ import zigpy.state
from homeassistant.components import zha from homeassistant.components import zha
from homeassistant.components.zha import api from homeassistant.components.zha import api
from homeassistant.components.zha.core.const import RadioType from homeassistant.components.zha.core.const import RadioType
from homeassistant.components.zha.core.helpers import get_zha_gateway
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
if TYPE_CHECKING: if TYPE_CHECKING:
@ -40,7 +41,7 @@ async def test_async_get_network_settings_inactive(
"""Test reading settings with an inactive ZHA installation.""" """Test reading settings with an inactive ZHA installation."""
await setup_zha() await setup_zha()
gateway = api._get_gateway(hass) gateway = get_zha_gateway(hass)
await zha.async_unload_entry(hass, gateway.config_entry) await zha.async_unload_entry(hass, gateway.config_entry)
backup = zigpy.backups.NetworkBackup() backup = zigpy.backups.NetworkBackup()
@ -70,7 +71,7 @@ async def test_async_get_network_settings_missing(
"""Test reading settings with an inactive ZHA installation, no valid channel.""" """Test reading settings with an inactive ZHA installation, no valid channel."""
await setup_zha() await setup_zha()
gateway = api._get_gateway(hass) gateway = get_zha_gateway(hass)
await gateway.config_entry.async_unload(hass) await gateway.config_entry.async_unload(hass)
# Network settings were never loaded for whatever reason # Network settings were never loaded for whatever reason

View File

@ -20,11 +20,12 @@ import homeassistant.components.zha.core.cluster_handlers as cluster_handlers
import homeassistant.components.zha.core.const as zha_const import homeassistant.components.zha.core.const as zha_const
from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.device import ZHADevice
from homeassistant.components.zha.core.endpoint import Endpoint from homeassistant.components.zha.core.endpoint import Endpoint
from homeassistant.components.zha.core.helpers import get_zha_gateway
import homeassistant.components.zha.core.registries as registries import homeassistant.components.zha.core.registries as registries
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from .common import get_zha_gateway, make_zcl_header from .common import make_zcl_header
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE
from tests.common import async_capture_events from tests.common import async_capture_events

View File

@ -108,21 +108,19 @@ async def test_get_actions(hass: HomeAssistant, device_ias) -> None:
ieee_address = str(device_ias[0].ieee) ieee_address = str(device_ias[0].ieee)
ha_device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)
reg_device = ha_device_registry.async_get_device( reg_device = device_registry.async_get_device(identifiers={(DOMAIN, ieee_address)})
identifiers={(DOMAIN, ieee_address)} entity_registry = er.async_get(hass)
) siren_level_select = entity_registry.async_get(
ha_entity_registry = er.async_get(hass)
siren_level_select = ha_entity_registry.async_get(
"select.fakemanufacturer_fakemodel_default_siren_level" "select.fakemanufacturer_fakemodel_default_siren_level"
) )
siren_tone_select = ha_entity_registry.async_get( siren_tone_select = entity_registry.async_get(
"select.fakemanufacturer_fakemodel_default_siren_tone" "select.fakemanufacturer_fakemodel_default_siren_tone"
) )
strobe_level_select = ha_entity_registry.async_get( strobe_level_select = entity_registry.async_get(
"select.fakemanufacturer_fakemodel_default_strobe_level" "select.fakemanufacturer_fakemodel_default_strobe_level"
) )
strobe_select = ha_entity_registry.async_get( strobe_select = entity_registry.async_get(
"select.fakemanufacturer_fakemodel_default_strobe" "select.fakemanufacturer_fakemodel_default_strobe"
) )
@ -171,13 +169,13 @@ async def test_get_inovelli_actions(hass: HomeAssistant, device_inovelli) -> Non
"""Test we get the expected actions from a ZHA device.""" """Test we get the expected actions from a ZHA device."""
inovelli_ieee_address = str(device_inovelli[0].ieee) inovelli_ieee_address = str(device_inovelli[0].ieee)
ha_device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)
inovelli_reg_device = ha_device_registry.async_get_device( inovelli_reg_device = device_registry.async_get_device(
identifiers={(DOMAIN, inovelli_ieee_address)} identifiers={(DOMAIN, inovelli_ieee_address)}
) )
ha_entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)
inovelli_button = ha_entity_registry.async_get("button.inovelli_vzm31_sn_identify") inovelli_button = entity_registry.async_get("button.inovelli_vzm31_sn_identify")
inovelli_light = ha_entity_registry.async_get("light.inovelli_vzm31_sn_light") inovelli_light = entity_registry.async_get("light.inovelli_vzm31_sn_light")
actions = await async_get_device_automations( actions = await async_get_device_automations(
hass, DeviceAutomationType.ACTION, inovelli_reg_device.id hass, DeviceAutomationType.ACTION, inovelli_reg_device.id
@ -262,11 +260,9 @@ async def test_action(hass: HomeAssistant, device_ias, device_inovelli) -> None:
ieee_address = str(zha_device.ieee) ieee_address = str(zha_device.ieee)
inovelli_ieee_address = str(inovelli_zha_device.ieee) inovelli_ieee_address = str(inovelli_zha_device.ieee)
ha_device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)
reg_device = ha_device_registry.async_get_device( reg_device = device_registry.async_get_device(identifiers={(DOMAIN, ieee_address)})
identifiers={(DOMAIN, ieee_address)} inovelli_reg_device = device_registry.async_get_device(
)
inovelli_reg_device = ha_device_registry.async_get_device(
identifiers={(DOMAIN, inovelli_ieee_address)} identifiers={(DOMAIN, inovelli_ieee_address)}
) )

View File

@ -477,6 +477,7 @@ async def test_validate_trigger_config_unloaded_bad_info(
# Reload ZHA to persist the device info in the cache # Reload ZHA to persist the device info in the cache
await hass.config_entries.async_setup(config_entry.entry_id) await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
await hass.config_entries.async_unload(config_entry.entry_id) await hass.config_entries.async_unload(config_entry.entry_id)
ha_device_registry = dr.async_get(hass) ha_device_registry = dr.async_get(hass)

View File

@ -6,8 +6,8 @@ import zigpy.profiles.zha as zha
import zigpy.zcl.clusters.security as security import zigpy.zcl.clusters.security as security
from homeassistant.components.diagnostics import REDACTED from homeassistant.components.diagnostics import REDACTED
from homeassistant.components.zha.core.const import DATA_ZHA, DATA_ZHA_GATEWAY
from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.device import ZHADevice
from homeassistant.components.zha.core.helpers import get_zha_gateway
from homeassistant.components.zha.diagnostics import KEYS_TO_REDACT from homeassistant.components.zha.diagnostics import KEYS_TO_REDACT
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -65,7 +65,7 @@ async def test_diagnostics_for_config_entry(
"""Test diagnostics for config entry.""" """Test diagnostics for config entry."""
await zha_device_joined(zigpy_device) await zha_device_joined(zigpy_device)
gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] gateway = get_zha_gateway(hass)
scan = {c: c for c in range(11, 26 + 1)} scan = {c: c for c in range(11, 26 + 1)}
with patch.object(gateway.application_controller, "energy_scan", return_value=scan): with patch.object(gateway.application_controller, "energy_scan", return_value=scan):

View File

@ -20,12 +20,12 @@ import homeassistant.components.zha.core.const as zha_const
from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.device import ZHADevice
import homeassistant.components.zha.core.discovery as disc import homeassistant.components.zha.core.discovery as disc
from homeassistant.components.zha.core.endpoint import Endpoint from homeassistant.components.zha.core.endpoint import Endpoint
from homeassistant.components.zha.core.helpers import get_zha_gateway
import homeassistant.components.zha.core.registries as zha_regs import homeassistant.components.zha.core.registries as zha_regs
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
import homeassistant.helpers.entity_registry as er import homeassistant.helpers.entity_registry as er
from .common import get_zha_gateway
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE
from .zha_devices_list import ( from .zha_devices_list import (
DEV_SIG_ATTRIBUTES, DEV_SIG_ATTRIBUTES,

View File

@ -21,6 +21,7 @@ from homeassistant.components.fan import (
from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.device import ZHADevice
from homeassistant.components.zha.core.discovery import GROUP_PROBE from homeassistant.components.zha.core.discovery import GROUP_PROBE
from homeassistant.components.zha.core.group import GroupMember from homeassistant.components.zha.core.group import GroupMember
from homeassistant.components.zha.core.helpers import get_zha_gateway
from homeassistant.components.zha.fan import ( from homeassistant.components.zha.fan import (
PRESET_MODE_AUTO, PRESET_MODE_AUTO,
PRESET_MODE_ON, PRESET_MODE_ON,
@ -45,7 +46,6 @@ from .common import (
async_test_rejoin, async_test_rejoin,
async_wait_for_updates, async_wait_for_updates,
find_entity_id, find_entity_id,
get_zha_gateway,
send_attributes_report, send_attributes_report,
) )
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE

View File

@ -11,11 +11,12 @@ import zigpy.zcl.clusters.lighting as lighting
from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.device import ZHADevice
from homeassistant.components.zha.core.group import GroupMember from homeassistant.components.zha.core.group import GroupMember
from homeassistant.components.zha.core.helpers import get_zha_gateway
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from .common import async_find_group_entity_id, get_zha_gateway from .common import async_find_group_entity_id
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE
IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8" IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"

View File

@ -20,9 +20,11 @@ from homeassistant.components.zha.core.const import (
ZHA_OPTIONS, ZHA_OPTIONS,
) )
from homeassistant.components.zha.core.group import GroupMember from homeassistant.components.zha.core.group import GroupMember
from homeassistant.components.zha.core.helpers import get_zha_gateway
from homeassistant.components.zha.light import FLASH_EFFECTS from homeassistant.components.zha.light import FLASH_EFFECTS
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, Platform from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .common import ( from .common import (
@ -32,7 +34,6 @@ from .common import (
async_test_rejoin, async_test_rejoin,
async_wait_for_updates, async_wait_for_updates,
find_entity_id, find_entity_id,
get_zha_gateway,
patch_zha_config, patch_zha_config,
send_attributes_report, send_attributes_report,
update_attribute_cache, update_attribute_cache,
@ -1781,7 +1782,8 @@ async def test_zha_group_light_entity(
assert device_3_entity_id not in zha_group.member_entity_ids assert device_3_entity_id not in zha_group.member_entity_ids
# make sure the entity registry entry is still there # make sure the entity registry entry is still there
assert zha_gateway.ha_entity_registry.async_get(group_entity_id) is not None entity_registry = er.async_get(hass)
assert entity_registry.async_get(group_entity_id) is not None
# add a member back and ensure that the group entity was created again # add a member back and ensure that the group entity was created again
await zha_group.async_add_members([GroupMember(device_light_3.ieee, 1)]) await zha_group.async_add_members([GroupMember(device_light_3.ieee, 1)])
@ -1811,10 +1813,10 @@ async def test_zha_group_light_entity(
assert len(zha_group.members) == 3 assert len(zha_group.members) == 3
# remove the group and ensure that there is no entity and that the entity registry is cleaned up # remove the group and ensure that there is no entity and that the entity registry is cleaned up
assert zha_gateway.ha_entity_registry.async_get(group_entity_id) is not None assert entity_registry.async_get(group_entity_id) is not None
await zha_gateway.async_remove_zigpy_group(zha_group.group_id) await zha_gateway.async_remove_zigpy_group(zha_group.group_id)
assert hass.states.get(group_entity_id) is None assert hass.states.get(group_entity_id) is None
assert zha_gateway.ha_entity_registry.async_get(group_entity_id) is None assert entity_registry.async_get(group_entity_id) is None
@patch( @patch(
@ -1914,7 +1916,8 @@ async def test_group_member_assume_state(
assert hass.states.get(group_entity_id).state == STATE_OFF assert hass.states.get(group_entity_id).state == STATE_OFF
# remove the group and ensure that there is no entity and that the entity registry is cleaned up # remove the group and ensure that there is no entity and that the entity registry is cleaned up
assert zha_gateway.ha_entity_registry.async_get(group_entity_id) is not None entity_registry = er.async_get(hass)
assert entity_registry.async_get(group_entity_id) is not None
await zha_gateway.async_remove_zigpy_group(zha_group.group_id) await zha_gateway.async_remove_zigpy_group(zha_group.group_id)
assert hass.states.get(group_entity_id) is None assert hass.states.get(group_entity_id) is None
assert zha_gateway.ha_entity_registry.async_get(group_entity_id) is None assert entity_registry.async_get(group_entity_id) is None

View File

@ -9,7 +9,8 @@ import zigpy.backups
import zigpy.state import zigpy.state
from homeassistant.components import zha from homeassistant.components import zha
from homeassistant.components.zha import api, silabs_multiprotocol from homeassistant.components.zha import silabs_multiprotocol
from homeassistant.components.zha.core.helpers import get_zha_gateway
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
if TYPE_CHECKING: if TYPE_CHECKING:
@ -36,7 +37,7 @@ async def test_async_get_channel_missing(
"""Test reading channel with an inactive ZHA installation, no valid channel.""" """Test reading channel with an inactive ZHA installation, no valid channel."""
await setup_zha() await setup_zha()
gateway = api._get_gateway(hass) gateway = get_zha_gateway(hass)
await zha.async_unload_entry(hass, gateway.config_entry) await zha.async_unload_entry(hass, gateway.config_entry)
# Network settings were never loaded for whatever reason # Network settings were never loaded for whatever reason

View File

@ -19,6 +19,7 @@ import zigpy.zcl.foundation as zcl_f
from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN
from homeassistant.components.zha.core.group import GroupMember from homeassistant.components.zha.core.group import GroupMember
from homeassistant.components.zha.core.helpers import get_zha_gateway
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, Platform from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -30,7 +31,6 @@ from .common import (
async_test_rejoin, async_test_rejoin,
async_wait_for_updates, async_wait_for_updates,
find_entity_id, find_entity_id,
get_zha_gateway,
send_attributes_report, send_attributes_report,
) )
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE

View File

@ -940,6 +940,7 @@ async def test_websocket_bind_unbind_devices(
@pytest.mark.parametrize("command_type", ["bind", "unbind"]) @pytest.mark.parametrize("command_type", ["bind", "unbind"])
async def test_websocket_bind_unbind_group( async def test_websocket_bind_unbind_group(
command_type: str, command_type: str,
hass: HomeAssistant,
app_controller: ControllerApplication, app_controller: ControllerApplication,
zha_client, zha_client,
) -> None: ) -> None:
@ -947,8 +948,9 @@ async def test_websocket_bind_unbind_group(
test_group_id = 0x0001 test_group_id = 0x0001
gateway_mock = MagicMock() gateway_mock = MagicMock()
with patch( with patch(
"homeassistant.components.zha.websocket_api.get_gateway", "homeassistant.components.zha.websocket_api.get_zha_gateway",
return_value=gateway_mock, return_value=gateway_mock,
): ):
device_mock = MagicMock() device_mock = MagicMock()