diff --git a/homeassistant/components/zha/__init__.py b/homeassistant/components/zha/__init__.py index 662ddd080e0..bd181d82a33 100644 --- a/homeassistant/components/zha/__init__.py +++ b/homeassistant/components/zha/__init__.py @@ -33,9 +33,6 @@ from .core.const import ( CONF_USB_PATH, CONF_ZIGPY, DATA_ZHA, - DATA_ZHA_CONFIG, - DATA_ZHA_DEVICE_TRIGGER_CACHE, - DATA_ZHA_GATEWAY, DOMAIN, PLATFORMS, SIGNAL_ADD_ENTITIES, @@ -43,6 +40,7 @@ from .core.const import ( ) from .core.device import get_device_automation_triggers from .core.discovery import GROUP_PROBE +from .core.helpers import ZHAData, get_zha_data from .radio_manager import ZhaRadioManager DEVICE_CONFIG_SCHEMA_ENTRY = vol.Schema({vol.Optional(CONF_TYPE): cv.string}) @@ -81,11 +79,9 @@ _LOGGER = logging.getLogger(__name__) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up ZHA from config.""" - hass.data[DATA_ZHA] = {} - - if DOMAIN in config: - conf = config[DOMAIN] - hass.data[DATA_ZHA][DATA_ZHA_CONFIG] = conf + zha_data = ZHAData() + zha_data.yaml_config = config.get(DOMAIN, {}) + hass.data[DATA_ZHA] = zha_data return True @@ -120,14 +116,12 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b data[CONF_DEVICE][CONF_DEVICE_PATH] = cleaned_path hass.config_entries.async_update_entry(config_entry, data=data) - zha_data = hass.data.setdefault(DATA_ZHA, {}) - config = zha_data.get(DATA_ZHA_CONFIG, {}) + zha_data = get_zha_data(hass) - for platform in PLATFORMS: - zha_data.setdefault(platform, []) - - if config.get(CONF_ENABLE_QUIRKS, True): - setup_quirks(custom_quirks_path=config.get(CONF_CUSTOM_QUIRKS_PATH)) + if zha_data.yaml_config.get(CONF_ENABLE_QUIRKS, True): + setup_quirks( + custom_quirks_path=zha_data.yaml_config.get(CONF_CUSTOM_QUIRKS_PATH) + ) # temporary code to remove the ZHA storage file from disk. # this will be removed in 2022.10.0 @@ -139,8 +133,6 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b _LOGGER.debug("ZHA storage file does not exist or was already removed") # Load and cache device trigger information early - zha_data.setdefault(DATA_ZHA_DEVICE_TRIGGER_CACHE, {}) - device_registry = dr.async_get(hass) radio_mgr = ZhaRadioManager.from_config_entry(hass, config_entry) @@ -154,14 +146,14 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b if dev_entry is None: continue - zha_data[DATA_ZHA_DEVICE_TRIGGER_CACHE][dev_entry.id] = ( + zha_data.device_trigger_cache[dev_entry.id] = ( str(dev.ieee), get_device_automation_triggers(dev), ) - _LOGGER.debug("Trigger cache: %s", zha_data[DATA_ZHA_DEVICE_TRIGGER_CACHE]) + _LOGGER.debug("Trigger cache: %s", zha_data.device_trigger_cache) - zha_gateway = ZHAGateway(hass, config, config_entry) + zha_gateway = ZHAGateway(hass, zha_data.yaml_config, config_entry) async def async_zha_shutdown(): """Handle shutdown tasks.""" @@ -172,7 +164,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b # be in when we get here in failure cases with contextlib.suppress(KeyError): for platform in PLATFORMS: - del hass.data[DATA_ZHA][platform] + del zha_data.platforms[platform] config_entry.async_on_unload(async_zha_shutdown) @@ -212,10 +204,8 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: """Unload ZHA config entry.""" - try: - del hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - except KeyError: - return False + zha_data = get_zha_data(hass) + zha_data.gateway = None GROUP_PROBE.cleanup() websocket_api.async_unload_api(hass) @@ -241,7 +231,7 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> CONF_DEVICE: {CONF_DEVICE_PATH: config_entry.data[CONF_USB_PATH]}, } - baudrate = hass.data[DATA_ZHA].get(DATA_ZHA_CONFIG, {}).get(CONF_BAUDRATE) + baudrate = get_zha_data(hass).yaml_config.get(CONF_BAUDRATE) if data[CONF_RADIO_TYPE] != RadioType.deconz and baudrate in BAUD_RATES: data[CONF_DEVICE][CONF_BAUDRATE] = baudrate diff --git a/homeassistant/components/zha/alarm_control_panel.py b/homeassistant/components/zha/alarm_control_panel.py index b6794e909d8..21cacfa5dd4 100644 --- a/homeassistant/components/zha/alarm_control_panel.py +++ b/homeassistant/components/zha/alarm_control_panel.py @@ -35,11 +35,10 @@ from .core.const import ( CONF_ALARM_ARM_REQUIRES_CODE, CONF_ALARM_FAILED_TRIES, CONF_ALARM_MASTER_CODE, - DATA_ZHA, SIGNAL_ADD_ENTITIES, ZHA_ALARM_OPTIONS, ) -from .core.helpers import async_get_zha_config_value +from .core.helpers import async_get_zha_config_value, get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -65,7 +64,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation alarm control panel from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.ALARM_CONTROL_PANEL] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.ALARM_CONTROL_PANEL] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/api.py b/homeassistant/components/zha/api.py index 3d44103e225..f63fb9d09de 100644 --- a/homeassistant/components/zha/api.py +++ b/homeassistant/components/zha/api.py @@ -9,33 +9,22 @@ from zigpy.config import CONF_DEVICE, CONF_DEVICE_PATH from zigpy.types import Channels from zigpy.util import pick_optimal_channel -from .core.const import ( - CONF_RADIO_TYPE, - DATA_ZHA, - DATA_ZHA_CONFIG, - DATA_ZHA_GATEWAY, - DOMAIN, - RadioType, -) +from .core.const import CONF_RADIO_TYPE, DOMAIN, RadioType from .core.gateway import ZHAGateway +from .core.helpers import get_zha_data, get_zha_gateway if TYPE_CHECKING: from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -def _get_gateway(hass: HomeAssistant) -> ZHAGateway: - """Get a reference to the ZHA gateway device.""" - return hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - - def _get_config_entry(hass: HomeAssistant) -> ConfigEntry: """Find the singleton ZHA config entry, if one exists.""" # If ZHA is already running, use its config entry try: - zha_gateway = _get_gateway(hass) - except KeyError: + zha_gateway = get_zha_gateway(hass) + except ValueError: pass else: return zha_gateway.config_entry @@ -51,8 +40,7 @@ def _get_config_entry(hass: HomeAssistant) -> ConfigEntry: def async_get_active_network_settings(hass: HomeAssistant) -> NetworkBackup: """Get the network settings for the currently active ZHA network.""" - zha_gateway: ZHAGateway = _get_gateway(hass) - app = zha_gateway.application_controller + app = get_zha_gateway(hass).application_controller return NetworkBackup( node_info=app.state.node_info, @@ -67,7 +55,7 @@ async def async_get_last_network_settings( if config_entry is None: config_entry = _get_config_entry(hass) - config = hass.data.get(DATA_ZHA, {}).get(DATA_ZHA_CONFIG, {}) + config = get_zha_data(hass).yaml_config zha_gateway = ZHAGateway(hass, config, config_entry) app_controller_cls, app_config = zha_gateway.get_application_controller_data() @@ -91,7 +79,7 @@ async def async_get_network_settings( try: return async_get_active_network_settings(hass) - except KeyError: + except ValueError: return await async_get_last_network_settings(hass, config_entry) @@ -120,8 +108,7 @@ async def async_change_channel( ) -> None: """Migrate the ZHA network to a new channel.""" - zha_gateway: ZHAGateway = _get_gateway(hass) - app = zha_gateway.application_controller + app = get_zha_gateway(hass).application_controller if new_channel == "auto": channel_energy = await app.energy_scan( diff --git a/homeassistant/components/zha/backup.py b/homeassistant/components/zha/backup.py index 89d5294e1c4..e125a8085f6 100644 --- a/homeassistant/components/zha/backup.py +++ b/homeassistant/components/zha/backup.py @@ -3,8 +3,7 @@ import logging from homeassistant.core import HomeAssistant -from .core import ZHAGateway -from .core.const import DATA_ZHA, DATA_ZHA_GATEWAY +from .core.helpers import get_zha_gateway _LOGGER = logging.getLogger(__name__) @@ -13,7 +12,7 @@ async def async_pre_backup(hass: HomeAssistant) -> None: """Perform operations before a backup starts.""" _LOGGER.debug("Performing coordinator backup") - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) await zha_gateway.application_controller.backups.create_backup(load_devices=True) diff --git a/homeassistant/components/zha/binary_sensor.py b/homeassistant/components/zha/binary_sensor.py index 50cfb783370..c32bd5eeb67 100644 --- a/homeassistant/components/zha/binary_sensor.py +++ b/homeassistant/components/zha/binary_sensor.py @@ -26,10 +26,10 @@ from .core.const import ( CLUSTER_HANDLER_OCCUPANCY, CLUSTER_HANDLER_ON_OFF, CLUSTER_HANDLER_ZONE, - DATA_ZHA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, ) +from .core.helpers import get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -65,7 +65,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation binary sensor from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.BINARY_SENSOR] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.BINARY_SENSOR] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/button.py b/homeassistant/components/zha/button.py index 7a4132115b8..4114a3dea7c 100644 --- a/homeassistant/components/zha/button.py +++ b/homeassistant/components/zha/button.py @@ -14,7 +14,8 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback from .core import discovery -from .core.const import CLUSTER_HANDLER_IDENTIFY, DATA_ZHA, SIGNAL_ADD_ENTITIES +from .core.const import CLUSTER_HANDLER_IDENTIFY, SIGNAL_ADD_ENTITIES +from .core.helpers import get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -38,7 +39,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation button from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.BUTTON] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.BUTTON] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/climate.py b/homeassistant/components/zha/climate.py index cf868ef8b7b..5cbe2684ab4 100644 --- a/homeassistant/components/zha/climate.py +++ b/homeassistant/components/zha/climate.py @@ -45,13 +45,13 @@ from .core import discovery from .core.const import ( CLUSTER_HANDLER_FAN, CLUSTER_HANDLER_THERMOSTAT, - DATA_ZHA, PRESET_COMPLEX, PRESET_SCHEDULE, PRESET_TEMP_MANUAL, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, ) +from .core.helpers import get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -115,7 +115,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation sensor from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.CLIMATE] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.CLIMATE] unsub = async_dispatcher_connect( hass, SIGNAL_ADD_ENTITIES, diff --git a/homeassistant/components/zha/core/const.py b/homeassistant/components/zha/core/const.py index 9569fc49659..b37fa7ffe6d 100644 --- a/homeassistant/components/zha/core/const.py +++ b/homeassistant/components/zha/core/const.py @@ -184,7 +184,6 @@ CUSTOM_CONFIGURATION = "custom_configuration" DATA_DEVICE_CONFIG = "zha_device_config" DATA_ZHA = "zha" DATA_ZHA_CONFIG = "config" -DATA_ZHA_BRIDGE_ID = "zha_bridge_id" DATA_ZHA_CORE_EVENTS = "zha_core_events" DATA_ZHA_DEVICE_TRIGGER_CACHE = "zha_device_trigger_cache" DATA_ZHA_GATEWAY = "zha_gateway" diff --git a/homeassistant/components/zha/core/device.py b/homeassistant/components/zha/core/device.py index 60bf78e516c..8f5b087f068 100644 --- a/homeassistant/components/zha/core/device.py +++ b/homeassistant/components/zha/core/device.py @@ -25,6 +25,7 @@ from homeassistant.backports.functools import cached_property from homeassistant.const import ATTR_COMMAND, ATTR_DEVICE_ID, ATTR_NAME from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import device_registry as dr from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send, @@ -420,7 +421,9 @@ class ZHADevice(LogMixin): """Update device sw version.""" if self.device_id is None: return - self._zha_gateway.ha_device_registry.async_update_device( + + device_registry = dr.async_get(self.hass) + device_registry.async_update_device( self.device_id, sw_version=f"0x{sw_version:08x}" ) @@ -658,7 +661,8 @@ class ZHADevice(LogMixin): ) device_info[ATTR_ENDPOINT_NAMES] = names - reg_device = self.gateway.ha_device_registry.async_get(self.device_id) + device_registry = dr.async_get(self.hass) + reg_device = device_registry.async_get(self.device_id) if reg_device is not None: device_info["user_given_name"] = reg_device.name_by_user device_info["device_reg_id"] = reg_device.id diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index 92b68bdb159..a56e7044d3a 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -4,10 +4,11 @@ from __future__ import annotations from collections import Counter from collections.abc import Callable import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from homeassistant.const import CONF_TYPE, Platform from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send, @@ -49,12 +50,12 @@ from .cluster_handlers import ( # noqa: F401 security, smartenergy, ) +from .helpers import get_zha_data, get_zha_gateway if TYPE_CHECKING: from ..entity import ZhaEntity from .device import ZHADevice from .endpoint import Endpoint - from .gateway import ZHAGateway from .group import ZHAGroup _LOGGER = logging.getLogger(__name__) @@ -113,6 +114,8 @@ class ProbeEndpoint: platform = zha_regs.DEVICE_CLASS[ep_profile_id].get(ep_device_type) if platform and platform in zha_const.PLATFORMS: + platform = cast(Platform, platform) + cluster_handlers = endpoint.unclaimed_cluster_handlers() platform_entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity( platform, @@ -263,9 +266,7 @@ class ProbeEndpoint: def initialize(self, hass: HomeAssistant) -> None: """Update device overrides config.""" - zha_config: ConfigType = hass.data[zha_const.DATA_ZHA].get( - zha_const.DATA_ZHA_CONFIG, {} - ) + zha_config = get_zha_data(hass).yaml_config if overrides := zha_config.get(zha_const.CONF_DEVICE_CONFIG): self._device_configs.update(overrides) @@ -297,9 +298,7 @@ class GroupProbe: @callback def _reprobe_group(self, group_id: int) -> None: """Reprobe a group for entities after its members change.""" - zha_gateway: ZHAGateway = self._hass.data[zha_const.DATA_ZHA][ - zha_const.DATA_ZHA_GATEWAY - ] + zha_gateway = get_zha_gateway(self._hass) if (zha_group := zha_gateway.groups.get(group_id)) is None: return self.discover_group_entities(zha_group) @@ -321,14 +320,14 @@ class GroupProbe: if not entity_domains: return - zha_gateway: ZHAGateway = self._hass.data[zha_const.DATA_ZHA][ - zha_const.DATA_ZHA_GATEWAY - ] + zha_data = get_zha_data(self._hass) + zha_gateway = get_zha_gateway(self._hass) + for domain in entity_domains: entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(domain) if entity_class is None: continue - self._hass.data[zha_const.DATA_ZHA][domain].append( + zha_data.platforms[domain].append( ( entity_class, ( @@ -342,24 +341,26 @@ class GroupProbe: async_dispatcher_send(self._hass, zha_const.SIGNAL_ADD_ENTITIES) @staticmethod - def determine_entity_domains(hass: HomeAssistant, group: ZHAGroup) -> list[str]: + def determine_entity_domains( + hass: HomeAssistant, group: ZHAGroup + ) -> list[Platform]: """Determine the entity domains for this group.""" - entity_domains: list[str] = [] - zha_gateway: ZHAGateway = hass.data[zha_const.DATA_ZHA][ - zha_const.DATA_ZHA_GATEWAY - ] - all_domain_occurrences = [] + entity_registry = er.async_get(hass) + + entity_domains: list[Platform] = [] + all_domain_occurrences: list[Platform] = [] + for member in group.members: if member.device.is_coordinator: continue entities = async_entries_for_device( - zha_gateway.ha_entity_registry, + entity_registry, member.device.device_id, include_disabled_entities=True, ) all_domain_occurrences.extend( [ - entity.domain + cast(Platform, entity.domain) for entity in entities if entity.domain in zha_regs.GROUP_ENTITY_DOMAINS ] diff --git a/homeassistant/components/zha/core/endpoint.py b/homeassistant/components/zha/core/endpoint.py index bdef5ac46af..c87ee60d6b3 100644 --- a/homeassistant/components/zha/core/endpoint.py +++ b/homeassistant/components/zha/core/endpoint.py @@ -16,6 +16,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send from . import const, discovery, registries from .cluster_handlers import ClusterHandler from .cluster_handlers.general import MultistateInput +from .helpers import get_zha_data if TYPE_CHECKING: from .cluster_handlers import ClientClusterHandler @@ -195,7 +196,7 @@ class Endpoint: def async_new_entity( self, - platform: Platform | str, + platform: Platform, entity_class: CALLABLE_T, unique_id: str, cluster_handlers: list[ClusterHandler], @@ -206,7 +207,8 @@ class Endpoint: if self.device.status == DeviceStatus.INITIALIZED: return - self.device.hass.data[const.DATA_ZHA][platform].append( + zha_data = get_zha_data(self.device.hass) + zha_data.platforms[platform].append( (entity_class, (unique_id, self.device, cluster_handlers)) ) diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index 5cc2cd9a4b9..5fe84005d7a 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -46,9 +46,6 @@ from .const import ( CONF_RADIO_TYPE, CONF_USE_THREAD, CONF_ZIGPY, - DATA_ZHA, - DATA_ZHA_BRIDGE_ID, - DATA_ZHA_GATEWAY, DEBUG_COMP_BELLOWS, DEBUG_COMP_ZHA, DEBUG_COMP_ZIGPY, @@ -87,6 +84,7 @@ from .const import ( ) from .device import DeviceStatus, ZHADevice from .group import GroupMember, ZHAGroup +from .helpers import get_zha_data from .registries import GROUP_ENTITY_DOMAINS if TYPE_CHECKING: @@ -123,8 +121,6 @@ class ZHAGateway: """Gateway that handles events that happen on the ZHA Zigbee network.""" # -- Set in async_initialize -- - ha_device_registry: dr.DeviceRegistry - ha_entity_registry: er.EntityRegistry application_controller: ControllerApplication radio_description: str @@ -132,7 +128,7 @@ class ZHAGateway: self, hass: HomeAssistant, config: ConfigType, config_entry: ConfigEntry ) -> None: """Initialize the gateway.""" - self._hass = hass + self.hass = hass self._config = config self._devices: dict[EUI64, ZHADevice] = {} self._groups: dict[int, ZHAGroup] = {} @@ -159,7 +155,7 @@ class ZHAGateway: app_config = self._config.get(CONF_ZIGPY, {}) database = self._config.get( CONF_DATABASE, - self._hass.config.path(DEFAULT_DATABASE_NAME), + self.hass.config.path(DEFAULT_DATABASE_NAME), ) app_config[CONF_DATABASE] = database app_config[CONF_DEVICE] = self.config_entry.data[CONF_DEVICE] @@ -191,11 +187,8 @@ class ZHAGateway: async def async_initialize(self) -> None: """Initialize controller and connect radio.""" - discovery.PROBE.initialize(self._hass) - discovery.GROUP_PROBE.initialize(self._hass) - - self.ha_device_registry = dr.async_get(self._hass) - self.ha_entity_registry = er.async_get(self._hass) + discovery.PROBE.initialize(self.hass) + discovery.GROUP_PROBE.initialize(self.hass) app_controller_cls, app_config = self.get_application_controller_data() self.application_controller = await app_controller_cls.new( @@ -225,8 +218,8 @@ class ZHAGateway: else: break - self._hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] = self - self._hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID] = str(self.coordinator_ieee) + zha_data = get_zha_data(self.hass) + zha_data.gateway = self self.coordinator_zha_device = self._async_get_or_create_device( self._find_coordinator_device(), restored=True @@ -301,7 +294,7 @@ class ZHAGateway: # background the fetching of state for mains powered devices self.config_entry.async_create_background_task( - self._hass, fetch_updated_state(), "zha.gateway-fetch_updated_state" + self.hass, fetch_updated_state(), "zha.gateway-fetch_updated_state" ) def device_joined(self, device: zigpy.device.Device) -> None: @@ -311,7 +304,7 @@ class ZHAGateway: address """ async_dispatcher_send( - self._hass, + self.hass, ZHA_GW_MSG, { ATTR_TYPE: ZHA_GW_MSG_DEVICE_JOINED, @@ -327,7 +320,7 @@ class ZHAGateway: """Handle a device initialization without quirks loaded.""" manuf = device.manufacturer async_dispatcher_send( - self._hass, + self.hass, ZHA_GW_MSG, { ATTR_TYPE: ZHA_GW_MSG_RAW_INIT, @@ -344,7 +337,7 @@ class ZHAGateway: def device_initialized(self, device: zigpy.device.Device) -> None: """Handle device joined and basic information discovered.""" - self._hass.async_create_task(self.async_device_initialized(device)) + self.hass.async_create_task(self.async_device_initialized(device)) def device_left(self, device: zigpy.device.Device) -> None: """Handle device leaving the network.""" @@ -359,7 +352,7 @@ class ZHAGateway: zha_group.info("group_member_removed - endpoint: %s", endpoint) self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_REMOVED) async_dispatcher_send( - self._hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_0x{zigpy_group.group_id:04x}" + self.hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_0x{zigpy_group.group_id:04x}" ) def group_member_added( @@ -371,7 +364,7 @@ class ZHAGateway: zha_group.info("group_member_added - endpoint: %s", endpoint) self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_ADDED) async_dispatcher_send( - self._hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_0x{zigpy_group.group_id:04x}" + self.hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_0x{zigpy_group.group_id:04x}" ) if len(zha_group.members) == 2: # we need to do this because there wasn't already @@ -399,7 +392,7 @@ class ZHAGateway: zha_group = self._groups.get(zigpy_group.group_id) if zha_group is not None: async_dispatcher_send( - self._hass, + self.hass, ZHA_GW_MSG, { ATTR_TYPE: gateway_message_type, @@ -416,9 +409,11 @@ class ZHAGateway: remove_tasks.append(entity_ref.remove_future) if remove_tasks: await asyncio.wait(remove_tasks) - reg_device = self.ha_device_registry.async_get(device.device_id) + + device_registry = dr.async_get(self.hass) + reg_device = device_registry.async_get(device.device_id) if reg_device is not None: - self.ha_device_registry.async_remove_device(reg_device.id) + device_registry.async_remove_device(reg_device.id) def device_removed(self, device: zigpy.device.Device) -> None: """Handle device being removed from the network.""" @@ -427,14 +422,14 @@ class ZHAGateway: if zha_device is not None: device_info = zha_device.zha_device_info zha_device.async_cleanup_handles() - async_dispatcher_send(self._hass, f"{SIGNAL_REMOVE}_{str(zha_device.ieee)}") - self._hass.async_create_task( + async_dispatcher_send(self.hass, f"{SIGNAL_REMOVE}_{str(zha_device.ieee)}") + self.hass.async_create_task( self._async_remove_device(zha_device, entity_refs), "ZHAGateway._async_remove_device", ) if device_info is not None: async_dispatcher_send( - self._hass, + self.hass, ZHA_GW_MSG, { ATTR_TYPE: ZHA_GW_MSG_DEVICE_REMOVED, @@ -488,9 +483,10 @@ class ZHAGateway: ] # then we get all group entity entries tied to the coordinator + entity_registry = er.async_get(self.hass) assert self.coordinator_zha_device all_group_entity_entries = er.async_entries_for_device( - self.ha_entity_registry, + entity_registry, self.coordinator_zha_device.device_id, include_disabled_entities=True, ) @@ -508,7 +504,7 @@ class ZHAGateway: _LOGGER.debug( "cleaning up entity registry entry for entity: %s", entry.entity_id ) - self.ha_entity_registry.async_remove(entry.entity_id) + entity_registry.async_remove(entry.entity_id) @property def coordinator_ieee(self) -> EUI64: @@ -582,9 +578,11 @@ class ZHAGateway: ) -> ZHADevice: """Get or create a ZHA device.""" if (zha_device := self._devices.get(zigpy_device.ieee)) is None: - zha_device = ZHADevice.new(self._hass, zigpy_device, self, restored) + zha_device = ZHADevice.new(self.hass, zigpy_device, self, restored) self._devices[zigpy_device.ieee] = zha_device - device_registry_device = self.ha_device_registry.async_get_or_create( + + device_registry = dr.async_get(self.hass) + device_registry_device = device_registry.async_get_or_create( config_entry_id=self.config_entry.entry_id, connections={(dr.CONNECTION_ZIGBEE, str(zha_device.ieee))}, identifiers={(DOMAIN, str(zha_device.ieee))}, @@ -600,7 +598,7 @@ class ZHAGateway: """Get or create a ZHA group.""" zha_group = self._groups.get(zigpy_group.group_id) if zha_group is None: - zha_group = ZHAGroup(self._hass, self, zigpy_group) + zha_group = ZHAGroup(self.hass, self, zigpy_group) self._groups[zigpy_group.group_id] = zha_group return zha_group @@ -645,7 +643,7 @@ class ZHAGateway: device_info = zha_device.zha_device_info device_info[DEVICE_PAIRING_STATUS] = DevicePairingStatus.INITIALIZED.name async_dispatcher_send( - self._hass, + self.hass, ZHA_GW_MSG, { ATTR_TYPE: ZHA_GW_MSG_DEVICE_FULL_INIT, @@ -659,7 +657,7 @@ class ZHAGateway: await zha_device.async_configure() device_info[DEVICE_PAIRING_STATUS] = DevicePairingStatus.CONFIGURED.name async_dispatcher_send( - self._hass, + self.hass, ZHA_GW_MSG, { ATTR_TYPE: ZHA_GW_MSG_DEVICE_FULL_INIT, @@ -667,7 +665,7 @@ class ZHAGateway: }, ) await zha_device.async_initialize(from_cache=False) - async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES) + async_dispatcher_send(self.hass, SIGNAL_ADD_ENTITIES) async def _async_device_rejoined(self, zha_device: ZHADevice) -> None: _LOGGER.debug( @@ -681,7 +679,7 @@ class ZHAGateway: device_info = zha_device.device_info device_info[DEVICE_PAIRING_STATUS] = DevicePairingStatus.CONFIGURED.name async_dispatcher_send( - self._hass, + self.hass, ZHA_GW_MSG, { ATTR_TYPE: ZHA_GW_MSG_DEVICE_FULL_INIT, diff --git a/homeassistant/components/zha/core/group.py b/homeassistant/components/zha/core/group.py index ebea2f4ac41..519668052e0 100644 --- a/homeassistant/components/zha/core/group.py +++ b/homeassistant/components/zha/core/group.py @@ -11,6 +11,7 @@ import zigpy.group from zigpy.types.named import EUI64 from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.entity_registry import async_entries_for_device from .helpers import LogMixin @@ -32,8 +33,8 @@ class GroupMember(NamedTuple): class GroupEntityReference(NamedTuple): """Reference to a group entity.""" - name: str - original_name: str + name: str | None + original_name: str | None entity_id: int @@ -80,20 +81,30 @@ class ZHAGroupMember(LogMixin): @property def associated_entities(self) -> list[dict[str, Any]]: """Return the list of entities that were derived from this endpoint.""" - ha_entity_registry = self.device.gateway.ha_entity_registry + entity_registry = er.async_get(self._zha_device.hass) zha_device_registry = self.device.gateway.device_registry - return [ - GroupEntityReference( - ha_entity_registry.async_get(entity_ref.reference_id).name, - ha_entity_registry.async_get(entity_ref.reference_id).original_name, - entity_ref.reference_id, - )._asdict() - for entity_ref in zha_device_registry.get(self.device.ieee) - if list(entity_ref.cluster_handlers.values())[ - 0 - ].cluster.endpoint.endpoint_id - == self.endpoint_id - ] + + entity_info = [] + + for entity_ref in zha_device_registry.get(self.device.ieee): + entity = entity_registry.async_get(entity_ref.reference_id) + handler = list(entity_ref.cluster_handlers.values())[0] + + if ( + entity is None + or handler.cluster.endpoint.endpoint_id != self.endpoint_id + ): + continue + + entity_info.append( + GroupEntityReference( + name=entity.name, + original_name=entity.original_name, + entity_id=entity_ref.reference_id, + )._asdict() + ) + + return entity_info async def async_remove_from_group(self) -> None: """Remove the device endpoint from the provided zigbee group.""" @@ -204,12 +215,14 @@ class ZHAGroup(LogMixin): def get_domain_entity_ids(self, domain: str) -> list[str]: """Return entity ids from the entity domain for this group.""" + entity_registry = er.async_get(self.hass) domain_entity_ids: list[str] = [] + for member in self.members: if member.device.is_coordinator: continue entities = async_entries_for_device( - self._zha_gateway.ha_entity_registry, + entity_registry, member.device.device_id, include_disabled_entities=True, ) diff --git a/homeassistant/components/zha/core/helpers.py b/homeassistant/components/zha/core/helpers.py index 7b0d062738b..4df546b449c 100644 --- a/homeassistant/components/zha/core/helpers.py +++ b/homeassistant/components/zha/core/helpers.py @@ -7,7 +7,9 @@ from __future__ import annotations import asyncio import binascii +import collections from collections.abc import Callable, Iterator +import dataclasses from dataclasses import dataclass import enum import functools @@ -26,16 +28,12 @@ from zigpy.zcl.foundation import CommandSchema import zigpy.zdo.types as zdo_types from homeassistant.config_entries import ConfigEntry +from homeassistant.const import Platform from homeassistant.core import HomeAssistant, State, callback from homeassistant.helpers import config_validation as cv, device_registry as dr +from homeassistant.helpers.typing import ConfigType -from .const import ( - CLUSTER_TYPE_IN, - CLUSTER_TYPE_OUT, - CUSTOM_CONFIGURATION, - DATA_ZHA, - DATA_ZHA_GATEWAY, -) +from .const import CLUSTER_TYPE_IN, CLUSTER_TYPE_OUT, CUSTOM_CONFIGURATION, DATA_ZHA from .registries import BINDABLE_CLUSTERS if TYPE_CHECKING: @@ -221,7 +219,7 @@ def async_get_zha_config_value( def async_cluster_exists(hass, cluster_id, skip_coordinator=True): """Determine if a device containing the specified in cluster is paired.""" - zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) zha_devices = zha_gateway.devices.values() for zha_device in zha_devices: if skip_coordinator and zha_device.is_coordinator: @@ -244,7 +242,7 @@ def async_get_zha_device(hass: HomeAssistant, device_id: str) -> ZHADevice: if not registry_device: _LOGGER.error("Device id `%s` not found in registry", device_id) raise KeyError(f"Device id `{device_id}` not found in registry.") - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) try: ieee_address = list(registry_device.identifiers)[0][1] ieee = zigpy.types.EUI64.convert(ieee_address) @@ -421,3 +419,30 @@ def qr_to_install_code(qr_code: str) -> tuple[zigpy.types.EUI64, bytes]: return ieee, install_code raise vol.Invalid(f"couldn't convert qr code: {qr_code}") + + +@dataclasses.dataclass(kw_only=True, slots=True) +class ZHAData: + """ZHA component data stored in `hass.data`.""" + + yaml_config: ConfigType = dataclasses.field(default_factory=dict) + platforms: collections.defaultdict[Platform, list] = dataclasses.field( + default_factory=lambda: collections.defaultdict(list) + ) + gateway: ZHAGateway | None = dataclasses.field(default=None) + device_trigger_cache: dict[str, tuple[str, dict]] = dataclasses.field( + default_factory=dict + ) + + +def get_zha_data(hass: HomeAssistant) -> ZHAData: + """Get the global ZHA data object.""" + return hass.data.get(DATA_ZHA, ZHAData()) + + +def get_zha_gateway(hass: HomeAssistant) -> ZHAGateway: + """Get the ZHA gateway object.""" + if (zha_gateway := get_zha_data(hass).gateway) is None: + raise ValueError("No gateway object exists") + + return zha_gateway diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index 713d10ddf70..74f724bdc49 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -269,15 +269,15 @@ class ZHAEntityRegistry: def __init__(self) -> None: """Initialize Registry instance.""" self._strict_registry: dict[ - str, dict[MatchRule, type[ZhaEntity]] + Platform, dict[MatchRule, type[ZhaEntity]] ] = collections.defaultdict(dict) self._multi_entity_registry: dict[ - str, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]] + Platform, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]] ] = collections.defaultdict( lambda: collections.defaultdict(lambda: collections.defaultdict(list)) ) self._config_diagnostic_entity_registry: dict[ - str, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]] + Platform, dict[int | str | None, dict[MatchRule, list[type[ZhaEntity]]]] ] = collections.defaultdict( lambda: collections.defaultdict(lambda: collections.defaultdict(list)) ) @@ -288,7 +288,7 @@ class ZHAEntityRegistry: def get_entity( self, - component: str, + component: Platform, manufacturer: str, model: str, cluster_handlers: list[ClusterHandler], @@ -310,10 +310,12 @@ class ZHAEntityRegistry: model: str, cluster_handlers: list[ClusterHandler], quirk_class: str, - ) -> tuple[dict[str, list[EntityClassAndClusterHandlers]], list[ClusterHandler]]: + ) -> tuple[ + dict[Platform, list[EntityClassAndClusterHandlers]], list[ClusterHandler] + ]: """Match ZHA cluster handlers to potentially multiple ZHA Entity classes.""" result: dict[ - str, list[EntityClassAndClusterHandlers] + Platform, list[EntityClassAndClusterHandlers] ] = collections.defaultdict(list) all_claimed: set[ClusterHandler] = set() for component, stop_match_groups in self._multi_entity_registry.items(): @@ -341,10 +343,12 @@ class ZHAEntityRegistry: model: str, cluster_handlers: list[ClusterHandler], quirk_class: str, - ) -> tuple[dict[str, list[EntityClassAndClusterHandlers]], list[ClusterHandler]]: + ) -> tuple[ + dict[Platform, list[EntityClassAndClusterHandlers]], list[ClusterHandler] + ]: """Match ZHA cluster handlers to potentially multiple ZHA Entity classes.""" result: dict[ - str, list[EntityClassAndClusterHandlers] + Platform, list[EntityClassAndClusterHandlers] ] = collections.defaultdict(list) all_claimed: set[ClusterHandler] = set() for ( @@ -375,7 +379,7 @@ class ZHAEntityRegistry: def strict_match( self, - component: str, + component: Platform, cluster_handler_names: set[str] | str | None = None, generic_ids: set[str] | str | None = None, manufacturers: Callable | set[str] | str | None = None, @@ -406,7 +410,7 @@ class ZHAEntityRegistry: def multipass_match( self, - component: str, + component: Platform, cluster_handler_names: set[str] | str | None = None, generic_ids: set[str] | str | None = None, manufacturers: Callable | set[str] | str | None = None, @@ -441,7 +445,7 @@ class ZHAEntityRegistry: def config_diagnostic_match( self, - component: str, + component: Platform, cluster_handler_names: set[str] | str | None = None, generic_ids: set[str] | str | None = None, manufacturers: Callable | set[str] | str | None = None, @@ -475,7 +479,7 @@ class ZHAEntityRegistry: return decorator def group_match( - self, component: str + self, component: Platform ) -> Callable[[_ZhaGroupEntityT], _ZhaGroupEntityT]: """Decorate a group match rule.""" diff --git a/homeassistant/components/zha/cover.py b/homeassistant/components/zha/cover.py index 0d7062173ca..f2aed0390f3 100644 --- a/homeassistant/components/zha/cover.py +++ b/homeassistant/components/zha/cover.py @@ -33,11 +33,11 @@ from .core.const import ( CLUSTER_HANDLER_LEVEL, CLUSTER_HANDLER_ON_OFF, CLUSTER_HANDLER_SHADE, - DATA_ZHA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, SIGNAL_SET_LEVEL, ) +from .core.helpers import get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -56,7 +56,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation cover from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.COVER] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.COVER] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/device_tracker.py b/homeassistant/components/zha/device_tracker.py index bda346624dd..ea27c58eb19 100644 --- a/homeassistant/components/zha/device_tracker.py +++ b/homeassistant/components/zha/device_tracker.py @@ -15,10 +15,10 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from .core import discovery from .core.const import ( CLUSTER_HANDLER_POWER_CONFIGURATION, - DATA_ZHA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, ) +from .core.helpers import get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity from .sensor import Battery @@ -32,7 +32,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation device tracker from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.DEVICE_TRACKER] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.DEVICE_TRACKER] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/device_trigger.py b/homeassistant/components/zha/device_trigger.py index 7a479443377..a2ae734b8fc 100644 --- a/homeassistant/components/zha/device_trigger.py +++ b/homeassistant/components/zha/device_trigger.py @@ -14,8 +14,8 @@ from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.typing import ConfigType from . import DOMAIN as ZHA_DOMAIN -from .core.const import DATA_ZHA, DATA_ZHA_DEVICE_TRIGGER_CACHE, ZHA_EVENT -from .core.helpers import async_get_zha_device +from .core.const import ZHA_EVENT +from .core.helpers import async_get_zha_device, get_zha_data CONF_SUBTYPE = "subtype" DEVICE = "device" @@ -32,13 +32,13 @@ def _get_device_trigger_data(hass: HomeAssistant, device_id: str) -> tuple[str, # First, try checking to see if the device itself is accessible try: zha_device = async_get_zha_device(hass, device_id) - except KeyError: + except ValueError: pass else: return str(zha_device.ieee), zha_device.device_automation_triggers # If not, check the trigger cache but allow any `KeyError`s to propagate - return hass.data[DATA_ZHA][DATA_ZHA_DEVICE_TRIGGER_CACHE][device_id] + return get_zha_data(hass).device_trigger_cache[device_id] async def async_validate_trigger_config( diff --git a/homeassistant/components/zha/diagnostics.py b/homeassistant/components/zha/diagnostics.py index 966f35fe98b..0fa1de5ff0e 100644 --- a/homeassistant/components/zha/diagnostics.py +++ b/homeassistant/components/zha/diagnostics.py @@ -25,14 +25,10 @@ from .core.const import ( ATTR_PROFILE_ID, ATTR_VALUE, CONF_ALARM_MASTER_CODE, - DATA_ZHA, - DATA_ZHA_CONFIG, - DATA_ZHA_GATEWAY, UNKNOWN, ) from .core.device import ZHADevice -from .core.gateway import ZHAGateway -from .core.helpers import async_get_zha_device +from .core.helpers import async_get_zha_device, get_zha_data, get_zha_gateway KEYS_TO_REDACT = { ATTR_IEEE, @@ -66,18 +62,18 @@ async def async_get_config_entry_diagnostics( hass: HomeAssistant, config_entry: ConfigEntry ) -> dict[str, Any]: """Return diagnostics for a config entry.""" - config: dict = hass.data[DATA_ZHA].get(DATA_ZHA_CONFIG, {}) - gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_data = get_zha_data(hass) + app = get_zha_gateway(hass).application_controller - energy_scan = await gateway.application_controller.energy_scan( + energy_scan = await app.energy_scan( channels=Channels.ALL_CHANNELS, duration_exp=4, count=1 ) return async_redact_data( { - "config": config, + "config": zha_data.yaml_config, "config_entry": config_entry.as_dict(), - "application_state": shallow_asdict(gateway.application_controller.state), + "application_state": shallow_asdict(app.state), "energy_scan": { channel: 100 * energy / 255 for channel, energy in energy_scan.items() }, diff --git a/homeassistant/components/zha/entity.py b/homeassistant/components/zha/entity.py index f2b16a37834..5722d91116a 100644 --- a/homeassistant/components/zha/entity.py +++ b/homeassistant/components/zha/entity.py @@ -26,14 +26,12 @@ from homeassistant.helpers.typing import EventType from .core.const import ( ATTR_MANUFACTURER, ATTR_MODEL, - DATA_ZHA, - DATA_ZHA_BRIDGE_ID, DOMAIN, SIGNAL_GROUP_ENTITY_REMOVED, SIGNAL_GROUP_MEMBERSHIP_CHANGE, SIGNAL_REMOVE, ) -from .core.helpers import LogMixin +from .core.helpers import LogMixin, get_zha_gateway if TYPE_CHECKING: from .core.cluster_handlers import ClusterHandler @@ -83,13 +81,16 @@ class BaseZhaEntity(LogMixin, entity.Entity): """Return a device description for device registry.""" zha_device_info = self._zha_device.device_info ieee = zha_device_info["ieee"] + + zha_gateway = get_zha_gateway(self.hass) + return DeviceInfo( connections={(CONNECTION_ZIGBEE, ieee)}, identifiers={(DOMAIN, ieee)}, manufacturer=zha_device_info[ATTR_MANUFACTURER], model=zha_device_info[ATTR_MODEL], name=zha_device_info[ATTR_NAME], - via_device=(DOMAIN, self.hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID]), + via_device=(DOMAIN, zha_gateway.coordinator_ieee), ) @callback diff --git a/homeassistant/components/zha/fan.py b/homeassistant/components/zha/fan.py index a24272c9a7a..73b128db109 100644 --- a/homeassistant/components/zha/fan.py +++ b/homeassistant/components/zha/fan.py @@ -28,12 +28,8 @@ from homeassistant.util.percentage import ( from .core import discovery from .core.cluster_handlers import wrap_zigpy_exceptions -from .core.const import ( - CLUSTER_HANDLER_FAN, - DATA_ZHA, - SIGNAL_ADD_ENTITIES, - SIGNAL_ATTR_UPDATED, -) +from .core.const import CLUSTER_HANDLER_FAN, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED +from .core.helpers import get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity, ZhaGroupEntity @@ -65,7 +61,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation fan from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.FAN] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.FAN] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/light.py b/homeassistant/components/zha/light.py index 2ec42431498..967d0fc9134 100644 --- a/homeassistant/components/zha/light.py +++ b/homeassistant/components/zha/light.py @@ -47,13 +47,12 @@ from .core.const import ( CONF_ENABLE_ENHANCED_LIGHT_TRANSITION, CONF_ENABLE_LIGHT_TRANSITIONING_FLAG, CONF_GROUP_MEMBERS_ASSUME_STATE, - DATA_ZHA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, SIGNAL_SET_LEVEL, ZHA_OPTIONS, ) -from .core.helpers import LogMixin, async_get_zha_config_value +from .core.helpers import LogMixin, async_get_zha_config_value, get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity, ZhaGroupEntity @@ -97,7 +96,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation light from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.LIGHT] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.LIGHT] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/lock.py b/homeassistant/components/zha/lock.py index 1e68e95c881..9bac9a59a38 100644 --- a/homeassistant/components/zha/lock.py +++ b/homeassistant/components/zha/lock.py @@ -20,10 +20,10 @@ from homeassistant.helpers.typing import StateType from .core import discovery from .core.const import ( CLUSTER_HANDLER_DOORLOCK, - DATA_ZHA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, ) +from .core.helpers import get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -45,7 +45,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation Door Lock from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.LOCK] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.LOCK] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/number.py b/homeassistant/components/zha/number.py index c12060eb2a8..b6876155312 100644 --- a/homeassistant/components/zha/number.py +++ b/homeassistant/components/zha/number.py @@ -20,10 +20,10 @@ from .core.const import ( CLUSTER_HANDLER_COLOR, CLUSTER_HANDLER_INOVELLI, CLUSTER_HANDLER_LEVEL, - DATA_ZHA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, ) +from .core.helpers import get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -258,7 +258,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation Analog Output from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.NUMBER] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.NUMBER] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/radio_manager.py b/homeassistant/components/zha/radio_manager.py index df30a85cd7b..ca030600751 100644 --- a/homeassistant/components/zha/radio_manager.py +++ b/homeassistant/components/zha/radio_manager.py @@ -26,12 +26,11 @@ from .core.const import ( CONF_DATABASE, CONF_RADIO_TYPE, CONF_ZIGPY, - DATA_ZHA, - DATA_ZHA_CONFIG, DEFAULT_DATABASE_NAME, EZSP_OVERWRITE_EUI64, RadioType, ) +from .core.helpers import get_zha_data # Only the common radio types will be autoprobed, ordered by new device popularity. # XBee takes too long to probe since it scans through all possible bauds and likely has @@ -145,7 +144,7 @@ class ZhaRadioManager: """Connect to the radio with the current config and then clean up.""" assert self.radio_type is not None - config = self.hass.data.get(DATA_ZHA, {}).get(DATA_ZHA_CONFIG, {}) + config = get_zha_data(self.hass).yaml_config app_config = config.get(CONF_ZIGPY, {}).copy() database_path = config.get( diff --git a/homeassistant/components/zha/select.py b/homeassistant/components/zha/select.py index 018f24675e7..fa2e124fd05 100644 --- a/homeassistant/components/zha/select.py +++ b/homeassistant/components/zha/select.py @@ -23,11 +23,11 @@ from .core.const import ( CLUSTER_HANDLER_IAS_WD, CLUSTER_HANDLER_INOVELLI, CLUSTER_HANDLER_ON_OFF, - DATA_ZHA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, Strobe, ) +from .core.helpers import get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -48,7 +48,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation siren from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.SELECT] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.SELECT] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/sensor.py b/homeassistant/components/zha/sensor.py index 535733230b9..1e166675b5b 100644 --- a/homeassistant/components/zha/sensor.py +++ b/homeassistant/components/zha/sensor.py @@ -57,10 +57,10 @@ from .core.const import ( CLUSTER_HANDLER_SOIL_MOISTURE, CLUSTER_HANDLER_TEMPERATURE, CLUSTER_HANDLER_THERMOSTAT, - DATA_ZHA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, ) +from .core.helpers import get_zha_data from .core.registries import SMARTTHINGS_HUMIDITY_CLUSTER, ZHA_ENTITIES from .entity import ZhaEntity @@ -99,7 +99,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation sensor from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.SENSOR] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.SENSOR] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/siren.py b/homeassistant/components/zha/siren.py index a4c699d515b..86cadb62519 100644 --- a/homeassistant/components/zha/siren.py +++ b/homeassistant/components/zha/siren.py @@ -25,7 +25,6 @@ from .core import discovery from .core.cluster_handlers.security import IasWd from .core.const import ( CLUSTER_HANDLER_IAS_WD, - DATA_ZHA, SIGNAL_ADD_ENTITIES, WARNING_DEVICE_MODE_BURGLAR, WARNING_DEVICE_MODE_EMERGENCY, @@ -39,6 +38,7 @@ from .core.const import ( WARNING_DEVICE_STROBE_NO, Strobe, ) +from .core.helpers import get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -56,7 +56,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation siren from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.SIREN] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.SIREN] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/switch.py b/homeassistant/components/zha/switch.py index 8707dda629f..eff8f727c1c 100644 --- a/homeassistant/components/zha/switch.py +++ b/homeassistant/components/zha/switch.py @@ -20,10 +20,10 @@ from .core.const import ( CLUSTER_HANDLER_BASIC, CLUSTER_HANDLER_INOVELLI, CLUSTER_HANDLER_ON_OFF, - DATA_ZHA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, ) +from .core.helpers import get_zha_data from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity, ZhaGroupEntity @@ -46,7 +46,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Zigbee Home Automation switch from config entry.""" - entities_to_create = hass.data[DATA_ZHA][Platform.SWITCH] + zha_data = get_zha_data(hass) + entities_to_create = zha_data.platforms[Platform.SWITCH] unsub = async_dispatcher_connect( hass, diff --git a/homeassistant/components/zha/websocket_api.py b/homeassistant/components/zha/websocket_api.py index 97862bd36f0..51941248f03 100644 --- a/homeassistant/components/zha/websocket_api.py +++ b/homeassistant/components/zha/websocket_api.py @@ -16,6 +16,7 @@ import zigpy.zdo.types as zdo_types from homeassistant.components import websocket_api from homeassistant.const import ATTR_COMMAND, ATTR_ID, ATTR_NAME from homeassistant.core import HomeAssistant, ServiceCall, callback +from homeassistant.helpers import entity_registry as er import homeassistant.helpers.config_validation as cv from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.service import async_register_admin_service @@ -52,8 +53,6 @@ from .core.const import ( CLUSTER_TYPE_IN, CLUSTER_TYPE_OUT, CUSTOM_CONFIGURATION, - DATA_ZHA, - DATA_ZHA_GATEWAY, DOMAIN, EZSP_OVERWRITE_EUI64, GROUP_ID, @@ -77,6 +76,7 @@ from .core.helpers import ( cluster_command_schema_to_vol_schema, convert_install_code, get_matched_clusters, + get_zha_gateway, qr_to_install_code, ) @@ -301,7 +301,7 @@ async def websocket_permit_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Permit ZHA zigbee devices.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) duration: int = msg[ATTR_DURATION] ieee: EUI64 | None = msg.get(ATTR_IEEE) @@ -348,7 +348,7 @@ async def websocket_get_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA devices.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) devices = [device.zha_device_info for device in zha_gateway.devices.values()] connection.send_result(msg[ID], devices) @@ -357,7 +357,8 @@ async def websocket_get_devices( def _get_entity_name( zha_gateway: ZHAGateway, entity_ref: EntityReference ) -> str | None: - entry = zha_gateway.ha_entity_registry.async_get(entity_ref.reference_id) + entity_registry = er.async_get(zha_gateway.hass) + entry = entity_registry.async_get(entity_ref.reference_id) return entry.name if entry else None @@ -365,7 +366,8 @@ def _get_entity_name( def _get_entity_original_name( zha_gateway: ZHAGateway, entity_ref: EntityReference ) -> str | None: - entry = zha_gateway.ha_entity_registry.async_get(entity_ref.reference_id) + entity_registry = er.async_get(zha_gateway.hass) + entry = entity_registry.async_get(entity_ref.reference_id) return entry.original_name if entry else None @@ -376,7 +378,7 @@ async def websocket_get_groupable_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA devices that can be grouped.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) devices = [device for device in zha_gateway.devices.values() if device.is_groupable] groupable_devices = [] @@ -414,7 +416,7 @@ async def websocket_get_groups( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA groups.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) groups = [group.group_info for group in zha_gateway.groups.values()] connection.send_result(msg[ID], groups) @@ -431,7 +433,7 @@ async def websocket_get_device( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA devices.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) ieee: EUI64 = msg[ATTR_IEEE] if not (zha_device := zha_gateway.devices.get(ieee)): @@ -458,7 +460,7 @@ async def websocket_get_group( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA group.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) group_id: int = msg[GROUP_ID] if not (zha_group := zha_gateway.groups.get(group_id)): @@ -487,7 +489,7 @@ async def websocket_add_group( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Add a new ZHA group.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) group_name: str = msg[GROUP_NAME] group_id: int | None = msg.get(GROUP_ID) members: list[GroupMember] | None = msg.get(ATTR_MEMBERS) @@ -508,7 +510,7 @@ async def websocket_remove_groups( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Remove the specified ZHA groups.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) group_ids: list[int] = msg[GROUP_IDS] if len(group_ids) > 1: @@ -535,7 +537,7 @@ async def websocket_add_group_members( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Add members to a ZHA group.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) group_id: int = msg[GROUP_ID] members: list[GroupMember] = msg[ATTR_MEMBERS] @@ -565,7 +567,7 @@ async def websocket_remove_group_members( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Remove members from a ZHA group.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) group_id: int = msg[GROUP_ID] members: list[GroupMember] = msg[ATTR_MEMBERS] @@ -594,7 +596,7 @@ async def websocket_reconfigure_node( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Reconfigure a ZHA nodes entities by its ieee address.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) ieee: EUI64 = msg[ATTR_IEEE] device: ZHADevice | None = zha_gateway.get_device(ieee) @@ -629,7 +631,7 @@ async def websocket_update_topology( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Update the ZHA network topology.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) hass.async_create_task(zha_gateway.application_controller.topology.scan()) @@ -645,7 +647,7 @@ async def websocket_device_clusters( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Return a list of device clusters.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) ieee: EUI64 = msg[ATTR_IEEE] zha_device = zha_gateway.get_device(ieee) response_clusters = [] @@ -689,7 +691,7 @@ async def websocket_device_cluster_attributes( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Return a list of cluster attributes.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) ieee: EUI64 = msg[ATTR_IEEE] endpoint_id: int = msg[ATTR_ENDPOINT_ID] cluster_id: int = msg[ATTR_CLUSTER_ID] @@ -736,7 +738,7 @@ async def websocket_device_cluster_commands( """Return a list of cluster commands.""" import voluptuous_serialize # pylint: disable=import-outside-toplevel - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) ieee: EUI64 = msg[ATTR_IEEE] endpoint_id: int = msg[ATTR_ENDPOINT_ID] cluster_id: int = msg[ATTR_CLUSTER_ID] @@ -806,7 +808,7 @@ async def websocket_read_zigbee_cluster_attributes( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Read zigbee attribute for cluster on ZHA entity.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) ieee: EUI64 = msg[ATTR_IEEE] endpoint_id: int = msg[ATTR_ENDPOINT_ID] cluster_id: int = msg[ATTR_CLUSTER_ID] @@ -860,7 +862,7 @@ async def websocket_get_bindable_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Directly bind devices.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) source_ieee: EUI64 = msg[ATTR_IEEE] source_device = zha_gateway.get_device(source_ieee) @@ -894,7 +896,7 @@ async def websocket_bind_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Directly bind devices.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] target_ieee: EUI64 = msg[ATTR_TARGET_IEEE] await async_binding_operation( @@ -923,7 +925,7 @@ async def websocket_unbind_devices( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Remove a direct binding between devices.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] target_ieee: EUI64 = msg[ATTR_TARGET_IEEE] await async_binding_operation( @@ -953,7 +955,7 @@ async def websocket_bind_group( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Directly bind a device to a group.""" - zha_gateway: ZHAGateway = get_gateway(hass) + zha_gateway = get_zha_gateway(hass) source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] group_id: int = msg[GROUP_ID] bindings: list[ClusterBinding] = msg[BINDINGS] @@ -977,7 +979,7 @@ async def websocket_unbind_group( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Unbind a device from a group.""" - zha_gateway: ZHAGateway = get_gateway(hass) + zha_gateway = get_zha_gateway(hass) source_ieee: EUI64 = msg[ATTR_SOURCE_IEEE] group_id: int = msg[GROUP_ID] bindings: list[ClusterBinding] = msg[BINDINGS] @@ -987,11 +989,6 @@ async def websocket_unbind_group( connection.send_result(msg[ID]) -def get_gateway(hass: HomeAssistant) -> ZHAGateway: - """Return Gateway, mainly as fixture for mocking during testing.""" - return hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] - - async def async_binding_operation( zha_gateway: ZHAGateway, source_ieee: EUI64, @@ -1047,7 +1044,7 @@ async def websocket_get_configuration( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA configuration.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) import voluptuous_serialize # pylint: disable=import-outside-toplevel def custom_serializer(schema: Any) -> Any: @@ -1094,7 +1091,7 @@ async def websocket_update_zha_configuration( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Update the ZHA configuration.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) options = zha_gateway.config_entry.options data_to_save = {**options, **{CUSTOM_CONFIGURATION: msg["data"]}} @@ -1141,7 +1138,7 @@ async def websocket_get_network_settings( ) -> None: """Get ZHA network settings.""" backup = async_get_active_network_settings(hass) - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) connection.send_result( msg[ID], { @@ -1159,7 +1156,7 @@ async def websocket_list_network_backups( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Get ZHA network settings.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) application_controller = zha_gateway.application_controller # Serialize known backups @@ -1175,7 +1172,7 @@ async def websocket_create_network_backup( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Create a ZHA network backup.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) application_controller = zha_gateway.application_controller # This can take 5-30s @@ -1202,7 +1199,7 @@ async def websocket_restore_network_backup( hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] ) -> None: """Restore a ZHA network backup.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) application_controller = zha_gateway.application_controller backup = msg["backup"] @@ -1240,7 +1237,7 @@ async def websocket_change_channel( @callback def async_load_api(hass: HomeAssistant) -> None: """Set up the web socket API.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) application_controller = zha_gateway.application_controller async def permit(service: ServiceCall) -> None: @@ -1278,7 +1275,7 @@ def async_load_api(hass: HomeAssistant) -> None: async def remove(service: ServiceCall) -> None: """Remove a node from the network.""" - zha_gateway: ZHAGateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) ieee: EUI64 = service.data[ATTR_IEEE] zha_device: ZHADevice | None = zha_gateway.get_device(ieee) if zha_device is not None and zha_device.is_active_coordinator: diff --git a/tests/components/zha/common.py b/tests/components/zha/common.py index db1da3721ee..44155d741b7 100644 --- a/tests/components/zha/common.py +++ b/tests/components/zha/common.py @@ -9,7 +9,10 @@ import zigpy.zcl import zigpy.zcl.foundation as zcl_f import homeassistant.components.zha.core.const as zha_const -from homeassistant.components.zha.core.helpers import async_get_zha_config_value +from homeassistant.components.zha.core.helpers import ( + async_get_zha_config_value, + get_zha_gateway, +) from homeassistant.helpers import entity_registry as er import homeassistant.util.dt as dt_util @@ -85,11 +88,6 @@ def update_attribute_cache(cluster): cluster.handle_message(hdr, msg) -def get_zha_gateway(hass): - """Return ZHA gateway from hass.data.""" - return hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] - - def make_attribute(attrid, value, status=0): """Make an attribute.""" attr = zcl_f.Attribute() diff --git a/tests/components/zha/conftest.py b/tests/components/zha/conftest.py index 7d391872a77..e7dc7316f73 100644 --- a/tests/components/zha/conftest.py +++ b/tests/components/zha/conftest.py @@ -22,9 +22,10 @@ import zigpy.zdo.types as zdo_t import homeassistant.components.zha.core.const as zha_const import homeassistant.components.zha.core.device as zha_core_device +from homeassistant.components.zha.core.helpers import get_zha_gateway from homeassistant.setup import async_setup_component -from . import common +from .common import patch_cluster as common_patch_cluster from tests.common import MockConfigEntry from tests.components.light.conftest import mock_light_profiles # noqa: F401 @@ -277,7 +278,7 @@ def zigpy_device_mock(zigpy_app_controller): for cluster in itertools.chain( endpoint.in_clusters.values(), endpoint.out_clusters.values() ): - common.patch_cluster(cluster) + common_patch_cluster(cluster) if attributes is not None: for ep_id, clusters in attributes.items(): @@ -305,7 +306,7 @@ def zha_device_joined(hass, setup_zha): if setup_zha: await setup_zha_fixture() - zha_gateway = common.get_zha_gateway(hass) + zha_gateway = get_zha_gateway(hass) zha_gateway.application_controller.devices[zigpy_dev.ieee] = zigpy_dev await zha_gateway.async_device_initialized(zigpy_dev) await hass.async_block_till_done() @@ -329,7 +330,7 @@ def zha_device_restored(hass, zigpy_app_controller, setup_zha): if setup_zha: await setup_zha_fixture() - zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] + zha_gateway = get_zha_gateway(hass) return zha_gateway.get_device(zigpy_dev.ieee) return _zha_device diff --git a/tests/components/zha/test_api.py b/tests/components/zha/test_api.py index c2cb16efcc8..89742fb1e49 100644 --- a/tests/components/zha/test_api.py +++ b/tests/components/zha/test_api.py @@ -11,6 +11,7 @@ import zigpy.state from homeassistant.components import zha from homeassistant.components.zha import api from homeassistant.components.zha.core.const import RadioType +from homeassistant.components.zha.core.helpers import get_zha_gateway from homeassistant.core import HomeAssistant if TYPE_CHECKING: @@ -40,7 +41,7 @@ async def test_async_get_network_settings_inactive( """Test reading settings with an inactive ZHA installation.""" await setup_zha() - gateway = api._get_gateway(hass) + gateway = get_zha_gateway(hass) await zha.async_unload_entry(hass, gateway.config_entry) backup = zigpy.backups.NetworkBackup() @@ -70,7 +71,7 @@ async def test_async_get_network_settings_missing( """Test reading settings with an inactive ZHA installation, no valid channel.""" await setup_zha() - gateway = api._get_gateway(hass) + gateway = get_zha_gateway(hass) await gateway.config_entry.async_unload(hass) # Network settings were never loaded for whatever reason diff --git a/tests/components/zha/test_cluster_handlers.py b/tests/components/zha/test_cluster_handlers.py index 7e0e8eaab85..24162296cd5 100644 --- a/tests/components/zha/test_cluster_handlers.py +++ b/tests/components/zha/test_cluster_handlers.py @@ -20,11 +20,12 @@ import homeassistant.components.zha.core.cluster_handlers as cluster_handlers import homeassistant.components.zha.core.const as zha_const from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.endpoint import Endpoint +from homeassistant.components.zha.core.helpers import get_zha_gateway import homeassistant.components.zha.core.registries as registries from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from .common import get_zha_gateway, make_zcl_header +from .common import make_zcl_header from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE from tests.common import async_capture_events diff --git a/tests/components/zha/test_device_action.py b/tests/components/zha/test_device_action.py index 31ffe9449e2..229fde89f15 100644 --- a/tests/components/zha/test_device_action.py +++ b/tests/components/zha/test_device_action.py @@ -108,21 +108,19 @@ async def test_get_actions(hass: HomeAssistant, device_ias) -> None: ieee_address = str(device_ias[0].ieee) - ha_device_registry = dr.async_get(hass) - reg_device = ha_device_registry.async_get_device( - identifiers={(DOMAIN, ieee_address)} - ) - ha_entity_registry = er.async_get(hass) - siren_level_select = ha_entity_registry.async_get( + device_registry = dr.async_get(hass) + reg_device = device_registry.async_get_device(identifiers={(DOMAIN, ieee_address)}) + entity_registry = er.async_get(hass) + siren_level_select = entity_registry.async_get( "select.fakemanufacturer_fakemodel_default_siren_level" ) - siren_tone_select = ha_entity_registry.async_get( + siren_tone_select = entity_registry.async_get( "select.fakemanufacturer_fakemodel_default_siren_tone" ) - strobe_level_select = ha_entity_registry.async_get( + strobe_level_select = entity_registry.async_get( "select.fakemanufacturer_fakemodel_default_strobe_level" ) - strobe_select = ha_entity_registry.async_get( + strobe_select = entity_registry.async_get( "select.fakemanufacturer_fakemodel_default_strobe" ) @@ -171,13 +169,13 @@ async def test_get_inovelli_actions(hass: HomeAssistant, device_inovelli) -> Non """Test we get the expected actions from a ZHA device.""" inovelli_ieee_address = str(device_inovelli[0].ieee) - ha_device_registry = dr.async_get(hass) - inovelli_reg_device = ha_device_registry.async_get_device( + device_registry = dr.async_get(hass) + inovelli_reg_device = device_registry.async_get_device( identifiers={(DOMAIN, inovelli_ieee_address)} ) - ha_entity_registry = er.async_get(hass) - inovelli_button = ha_entity_registry.async_get("button.inovelli_vzm31_sn_identify") - inovelli_light = ha_entity_registry.async_get("light.inovelli_vzm31_sn_light") + entity_registry = er.async_get(hass) + inovelli_button = entity_registry.async_get("button.inovelli_vzm31_sn_identify") + inovelli_light = entity_registry.async_get("light.inovelli_vzm31_sn_light") actions = await async_get_device_automations( hass, DeviceAutomationType.ACTION, inovelli_reg_device.id @@ -262,11 +260,9 @@ async def test_action(hass: HomeAssistant, device_ias, device_inovelli) -> None: ieee_address = str(zha_device.ieee) inovelli_ieee_address = str(inovelli_zha_device.ieee) - ha_device_registry = dr.async_get(hass) - reg_device = ha_device_registry.async_get_device( - identifiers={(DOMAIN, ieee_address)} - ) - inovelli_reg_device = ha_device_registry.async_get_device( + device_registry = dr.async_get(hass) + reg_device = device_registry.async_get_device(identifiers={(DOMAIN, ieee_address)}) + inovelli_reg_device = device_registry.async_get_device( identifiers={(DOMAIN, inovelli_ieee_address)} ) diff --git a/tests/components/zha/test_device_trigger.py b/tests/components/zha/test_device_trigger.py index 491e2d96d4f..096d83567fe 100644 --- a/tests/components/zha/test_device_trigger.py +++ b/tests/components/zha/test_device_trigger.py @@ -477,6 +477,7 @@ async def test_validate_trigger_config_unloaded_bad_info( # Reload ZHA to persist the device info in the cache await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() await hass.config_entries.async_unload(config_entry.entry_id) ha_device_registry = dr.async_get(hass) diff --git a/tests/components/zha/test_diagnostics.py b/tests/components/zha/test_diagnostics.py index 6bcb321ab14..c13bb36c1c0 100644 --- a/tests/components/zha/test_diagnostics.py +++ b/tests/components/zha/test_diagnostics.py @@ -6,8 +6,8 @@ import zigpy.profiles.zha as zha import zigpy.zcl.clusters.security as security from homeassistant.components.diagnostics import REDACTED -from homeassistant.components.zha.core.const import DATA_ZHA, DATA_ZHA_GATEWAY from homeassistant.components.zha.core.device import ZHADevice +from homeassistant.components.zha.core.helpers import get_zha_gateway from homeassistant.components.zha.diagnostics import KEYS_TO_REDACT from homeassistant.const import Platform from homeassistant.core import HomeAssistant @@ -65,7 +65,7 @@ async def test_diagnostics_for_config_entry( """Test diagnostics for config entry.""" await zha_device_joined(zigpy_device) - gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] + gateway = get_zha_gateway(hass) scan = {c: c for c in range(11, 26 + 1)} with patch.object(gateway.application_controller, "energy_scan", return_value=scan): diff --git a/tests/components/zha/test_discover.py b/tests/components/zha/test_discover.py index e0785601b4f..768f974d928 100644 --- a/tests/components/zha/test_discover.py +++ b/tests/components/zha/test_discover.py @@ -20,12 +20,12 @@ import homeassistant.components.zha.core.const as zha_const from homeassistant.components.zha.core.device import ZHADevice import homeassistant.components.zha.core.discovery as disc from homeassistant.components.zha.core.endpoint import Endpoint +from homeassistant.components.zha.core.helpers import get_zha_gateway import homeassistant.components.zha.core.registries as zha_regs from homeassistant.const import Platform from homeassistant.core import HomeAssistant import homeassistant.helpers.entity_registry as er -from .common import get_zha_gateway from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE from .zha_devices_list import ( DEV_SIG_ATTRIBUTES, diff --git a/tests/components/zha/test_fan.py b/tests/components/zha/test_fan.py index 3d0b065ab18..81ab1c2e0f5 100644 --- a/tests/components/zha/test_fan.py +++ b/tests/components/zha/test_fan.py @@ -21,6 +21,7 @@ from homeassistant.components.fan import ( from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.discovery import GROUP_PROBE from homeassistant.components.zha.core.group import GroupMember +from homeassistant.components.zha.core.helpers import get_zha_gateway from homeassistant.components.zha.fan import ( PRESET_MODE_AUTO, PRESET_MODE_ON, @@ -45,7 +46,6 @@ from .common import ( async_test_rejoin, async_wait_for_updates, find_entity_id, - get_zha_gateway, send_attributes_report, ) from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE diff --git a/tests/components/zha/test_gateway.py b/tests/components/zha/test_gateway.py index 0f791a08955..214bfcad9f0 100644 --- a/tests/components/zha/test_gateway.py +++ b/tests/components/zha/test_gateway.py @@ -11,11 +11,12 @@ import zigpy.zcl.clusters.lighting as lighting from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.group import GroupMember +from homeassistant.components.zha.core.helpers import get_zha_gateway from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady -from .common import async_find_group_entity_id, get_zha_gateway +from .common import async_find_group_entity_id from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8" diff --git a/tests/components/zha/test_light.py b/tests/components/zha/test_light.py index c1f5cf04e35..da91340b864 100644 --- a/tests/components/zha/test_light.py +++ b/tests/components/zha/test_light.py @@ -20,9 +20,11 @@ from homeassistant.components.zha.core.const import ( ZHA_OPTIONS, ) from homeassistant.components.zha.core.group import GroupMember +from homeassistant.components.zha.core.helpers import get_zha_gateway from homeassistant.components.zha.light import FLASH_EFFECTS from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, Platform from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er import homeassistant.util.dt as dt_util from .common import ( @@ -32,7 +34,6 @@ from .common import ( async_test_rejoin, async_wait_for_updates, find_entity_id, - get_zha_gateway, patch_zha_config, send_attributes_report, update_attribute_cache, @@ -1781,7 +1782,8 @@ async def test_zha_group_light_entity( assert device_3_entity_id not in zha_group.member_entity_ids # make sure the entity registry entry is still there - assert zha_gateway.ha_entity_registry.async_get(group_entity_id) is not None + entity_registry = er.async_get(hass) + assert entity_registry.async_get(group_entity_id) is not None # add a member back and ensure that the group entity was created again await zha_group.async_add_members([GroupMember(device_light_3.ieee, 1)]) @@ -1811,10 +1813,10 @@ async def test_zha_group_light_entity( assert len(zha_group.members) == 3 # remove the group and ensure that there is no entity and that the entity registry is cleaned up - assert zha_gateway.ha_entity_registry.async_get(group_entity_id) is not None + assert entity_registry.async_get(group_entity_id) is not None await zha_gateway.async_remove_zigpy_group(zha_group.group_id) assert hass.states.get(group_entity_id) is None - assert zha_gateway.ha_entity_registry.async_get(group_entity_id) is None + assert entity_registry.async_get(group_entity_id) is None @patch( @@ -1914,7 +1916,8 @@ async def test_group_member_assume_state( assert hass.states.get(group_entity_id).state == STATE_OFF # remove the group and ensure that there is no entity and that the entity registry is cleaned up - assert zha_gateway.ha_entity_registry.async_get(group_entity_id) is not None + entity_registry = er.async_get(hass) + assert entity_registry.async_get(group_entity_id) is not None await zha_gateway.async_remove_zigpy_group(zha_group.group_id) assert hass.states.get(group_entity_id) is None - assert zha_gateway.ha_entity_registry.async_get(group_entity_id) is None + assert entity_registry.async_get(group_entity_id) is None diff --git a/tests/components/zha/test_silabs_multiprotocol.py b/tests/components/zha/test_silabs_multiprotocol.py index beae0230901..4d11ae81b08 100644 --- a/tests/components/zha/test_silabs_multiprotocol.py +++ b/tests/components/zha/test_silabs_multiprotocol.py @@ -9,7 +9,8 @@ import zigpy.backups import zigpy.state from homeassistant.components import zha -from homeassistant.components.zha import api, silabs_multiprotocol +from homeassistant.components.zha import silabs_multiprotocol +from homeassistant.components.zha.core.helpers import get_zha_gateway from homeassistant.core import HomeAssistant if TYPE_CHECKING: @@ -36,7 +37,7 @@ async def test_async_get_channel_missing( """Test reading channel with an inactive ZHA installation, no valid channel.""" await setup_zha() - gateway = api._get_gateway(hass) + gateway = get_zha_gateway(hass) await zha.async_unload_entry(hass, gateway.config_entry) # Network settings were never loaded for whatever reason diff --git a/tests/components/zha/test_switch.py b/tests/components/zha/test_switch.py index fe7450eff67..b07b34763d1 100644 --- a/tests/components/zha/test_switch.py +++ b/tests/components/zha/test_switch.py @@ -19,6 +19,7 @@ import zigpy.zcl.foundation as zcl_f from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN from homeassistant.components.zha.core.group import GroupMember +from homeassistant.components.zha.core.helpers import get_zha_gateway from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -30,7 +31,6 @@ from .common import ( async_test_rejoin, async_wait_for_updates, find_entity_id, - get_zha_gateway, send_attributes_report, ) from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE diff --git a/tests/components/zha/test_websocket_api.py b/tests/components/zha/test_websocket_api.py index 740ffd6c06c..b0e15a01318 100644 --- a/tests/components/zha/test_websocket_api.py +++ b/tests/components/zha/test_websocket_api.py @@ -940,6 +940,7 @@ async def test_websocket_bind_unbind_devices( @pytest.mark.parametrize("command_type", ["bind", "unbind"]) async def test_websocket_bind_unbind_group( command_type: str, + hass: HomeAssistant, app_controller: ControllerApplication, zha_client, ) -> None: @@ -947,8 +948,9 @@ async def test_websocket_bind_unbind_group( test_group_id = 0x0001 gateway_mock = MagicMock() + with patch( - "homeassistant.components.zha.websocket_api.get_gateway", + "homeassistant.components.zha.websocket_api.get_zha_gateway", return_value=gateway_mock, ): device_mock = MagicMock()