diff --git a/homeassistant/components/zha/core/cluster_handlers/__init__.py b/homeassistant/components/zha/core/cluster_handlers/__init__.py
index 2b78c90aa19..00439343e81 100644
--- a/homeassistant/components/zha/core/cluster_handlers/__init__.py
+++ b/homeassistant/components/zha/core/cluster_handlers/__init__.py
@@ -42,7 +42,7 @@ from ..const import (
     ZHA_CLUSTER_HANDLER_MSG_DATA,
     ZHA_CLUSTER_HANDLER_READS_PER_REQ,
 )
-from ..helpers import LogMixin, retryable_req, safe_read
+from ..helpers import LogMixin, safe_read
 
 if TYPE_CHECKING:
     from ..endpoint import Endpoint
@@ -362,7 +362,6 @@ class ClusterHandler(LogMixin):
             self.debug("skipping cluster handler configuration")
             self._status = ClusterHandlerStatus.CONFIGURED
 
-    @retryable_req(delays=(1, 1, 3))
     async def async_initialize(self, from_cache: bool) -> None:
         """Initialize cluster handler."""
         if not from_cache and self._endpoint.device.skip_configuration:
diff --git a/homeassistant/components/zha/core/device.py b/homeassistant/components/zha/core/device.py
index 1a3d3a2da1f..468e89fbbf0 100644
--- a/homeassistant/components/zha/core/device.py
+++ b/homeassistant/components/zha/core/device.py
@@ -592,12 +592,17 @@ class ZHADevice(LogMixin):
         self.debug("started initialization")
         await self._zdo_handler.async_initialize(from_cache)
         self._zdo_handler.debug("'async_initialize' stage succeeded")
-        await asyncio.gather(
-            *(
-                endpoint.async_initialize(from_cache)
-                for endpoint in self._endpoints.values()
-            )
-        )
+
+        # We intentionally do not use `gather` here! This is so that if, for example,
+        # three `device.async_initialize()`s are spawned, only three concurrent requests
+        # will ever be in flight at once. Startup concurrency is managed at the device
+        # level.
+        for endpoint in self._endpoints.values():
+            try:
+                await endpoint.async_initialize(from_cache)
+            except Exception:  # pylint: disable=broad-exception-caught
+                self.debug("Failed to initialize endpoint", exc_info=True)
+
         self.debug("power source: %s", self.power_source)
         self.status = DeviceStatus.INITIALIZED
         self.debug("completed initialization")
diff --git a/homeassistant/components/zha/core/endpoint.py b/homeassistant/components/zha/core/endpoint.py
index 04c253128ee..4dbfccf6f25 100644
--- a/homeassistant/components/zha/core/endpoint.py
+++ b/homeassistant/components/zha/core/endpoint.py
@@ -2,7 +2,8 @@
 from __future__ import annotations
 
 import asyncio
-from collections.abc import Callable
+from collections.abc import Awaitable, Callable
+import functools
 import logging
 from typing import TYPE_CHECKING, Any, Final, TypeVar
 
@@ -11,6 +12,7 @@ from zigpy.typing import EndpointType as ZigpyEndpointType
 from homeassistant.const import Platform
 from homeassistant.core import callback
 from homeassistant.helpers.dispatcher import async_dispatcher_send
+from homeassistant.util.async_ import gather_with_limited_concurrency
 
 from . import const, discovery, registries
 from .cluster_handlers import ClusterHandler
@@ -169,20 +171,32 @@ class Endpoint:
 
     async def async_initialize(self, from_cache: bool = False) -> None:
         """Initialize claimed cluster handlers."""
-        await self._execute_handler_tasks("async_initialize", from_cache)
+        await self._execute_handler_tasks(
+            "async_initialize", from_cache, max_concurrency=1
+        )
 
     async def async_configure(self) -> None:
         """Configure claimed cluster handlers."""
         await self._execute_handler_tasks("async_configure")
 
-    async def _execute_handler_tasks(self, func_name: str, *args: Any) -> None:
+    async def _execute_handler_tasks(
+        self, func_name: str, *args: Any, max_concurrency: int | None = None
+    ) -> None:
         """Add a throttled cluster handler task and swallow exceptions."""
         cluster_handlers = [
             *self.claimed_cluster_handlers.values(),
             *self.client_cluster_handlers.values(),
         ]
         tasks = [getattr(ch, func_name)(*args) for ch in cluster_handlers]
-        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        gather: Callable[..., Awaitable]
+
+        if max_concurrency is None:
+            gather = asyncio.gather
+        else:
+            gather = functools.partial(gather_with_limited_concurrency, max_concurrency)
+
+        results = await gather(*tasks, return_exceptions=True)
         for cluster_handler, outcome in zip(cluster_handlers, results):
             if isinstance(outcome, Exception):
                 cluster_handler.warning(
diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py
index 3efdc77934a..cca8aa93e99 100644
--- a/homeassistant/components/zha/core/gateway.py
+++ b/homeassistant/components/zha/core/gateway.py
@@ -11,7 +11,7 @@ import itertools
 import logging
 import re
 import time
-from typing import TYPE_CHECKING, Any, NamedTuple, Self
+from typing import TYPE_CHECKING, Any, NamedTuple, Self, cast
 
 from zigpy.application import ControllerApplication
 from zigpy.config import (
@@ -36,6 +36,7 @@ from homeassistant.helpers import device_registry as dr, entity_registry as er
 from homeassistant.helpers.device_registry import DeviceInfo
 from homeassistant.helpers.dispatcher import async_dispatcher_send
 from homeassistant.helpers.typing import ConfigType
+from homeassistant.util.async_ import gather_with_limited_concurrency
 
 from . import discovery
 from .const import (
@@ -292,6 +293,39 @@ class ZHAGateway:
             # entity registry tied to the devices
             discovery.GROUP_PROBE.discover_group_entities(zha_group)
 
+    @property
+    def radio_concurrency(self) -> int:
+        """Maximum configured radio concurrency."""
+        return self.application_controller._concurrent_requests_semaphore.max_value  # pylint: disable=protected-access
+
+    async def async_fetch_updated_state_mains(self) -> None:
+        """Fetch updated state for mains powered devices."""
+        _LOGGER.debug("Fetching current state for mains powered devices")
+
+        now = time.time()
+
+        # Only delay startup to poll mains-powered devices that are online
+        online_devices = [
+            dev
+            for dev in self.devices.values()
+            if dev.is_mains_powered
+            and dev.last_seen is not None
+            and (now - dev.last_seen) < dev.consider_unavailable_time
+        ]
+
+        # Prioritize devices that have recently been contacted
+        online_devices.sort(key=lambda dev: cast(float, dev.last_seen), reverse=True)
+
+        # Make sure that we always leave slots for non-startup requests
+        max_poll_concurrency = max(1, self.radio_concurrency - 4)
+
+        await gather_with_limited_concurrency(
+            max_poll_concurrency,
+            *(dev.async_initialize(from_cache=False) for dev in online_devices),
+        )
+
+        _LOGGER.debug("completed fetching current state for mains powered devices")
+
     async def async_initialize_devices_and_entities(self) -> None:
         """Initialize devices and load entities."""
 
@@ -302,17 +336,8 @@ class ZHAGateway:
 
         async def fetch_updated_state() -> None:
             """Fetch updated state for mains powered devices."""
-            _LOGGER.debug("Fetching current state for mains powered devices")
-            await asyncio.gather(
-                *(
-                    dev.async_initialize(from_cache=False)
-                    for dev in self.devices.values()
-                    if dev.is_mains_powered
-                )
-            )
-            _LOGGER.debug(
-                "completed fetching current state for mains powered devices - allowing polled requests"
-            )
+            await self.async_fetch_updated_state_mains()
+            _LOGGER.debug("Allowing polled requests")
             self.hass.data[DATA_ZHA].allow_polling = True
 
         # background the fetching of state for mains powered devices
diff --git a/homeassistant/components/zha/core/helpers.py b/homeassistant/components/zha/core/helpers.py
index 8e518d805c6..6f0167827e8 100644
--- a/homeassistant/components/zha/core/helpers.py
+++ b/homeassistant/components/zha/core/helpers.py
@@ -5,19 +5,15 @@ https://home-assistant.io/integrations/zha/
 """
 from __future__ import annotations
 
-import asyncio
 import binascii
 import collections
-from collections.abc import Callable, Collection, Coroutine, Iterator
+from collections.abc import Callable, Iterator
 import dataclasses
 from dataclasses import dataclass
 import enum
-import functools
-import itertools
 import logging
-from random import uniform
 import re
-from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar
+from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
 
 import voluptuous as vol
 import zigpy.exceptions
@@ -322,58 +318,6 @@ class LogMixin:
         return self.log(logging.ERROR, msg, *args, **kwargs)
 
 
-def retryable_req(
-    delays: Collection[float] = (1, 5, 10, 15, 30, 60, 120, 180, 360, 600, 900, 1800),
-    raise_: bool = False,
-) -> Callable[
-    [Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R]]],
-    Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R | None]],
-]:
-    """Make a method with ZCL requests retryable.
-
-    This adds delays keyword argument to function.
-    len(delays) is number of tries.
-    raise_ if the final attempt should raise the exception.
-    """
-
-    def decorator(
-        func: Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R]],
-    ) -> Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R | None]]:
-        @functools.wraps(func)
-        async def wrapper(
-            cluster_handler: _ClusterHandlerT, *args: _P.args, **kwargs: _P.kwargs
-        ) -> _R | None:
-            exceptions = (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError)
-            try_count, errors = 1, []
-            for delay in itertools.chain(delays, [None]):
-                try:
-                    return await func(cluster_handler, *args, **kwargs)
-                except exceptions as ex:
-                    errors.append(ex)
-                    if delay:
-                        delay = uniform(delay * 0.75, delay * 1.25)
-                        cluster_handler.debug(
-                            "%s: retryable request #%d failed: %s. Retrying in %ss",
-                            func.__name__,
-                            try_count,
-                            ex,
-                            round(delay, 1),
-                        )
-                        try_count += 1
-                        await asyncio.sleep(delay)
-                    else:
-                        cluster_handler.warning(
-                            "%s: all attempts have failed: %s", func.__name__, errors
-                        )
-                        if raise_:
-                            raise
-                        return None
-
-        return wrapper
-
-    return decorator
-
-
 def convert_install_code(value: str) -> bytes:
     """Convert string to install code bytes and validate length."""
 
diff --git a/tests/components/zha/conftest.py b/tests/components/zha/conftest.py
index 55405d0a51c..a30c6f35052 100644
--- a/tests/components/zha/conftest.py
+++ b/tests/components/zha/conftest.py
@@ -135,7 +135,7 @@ def _wrap_mock_instance(obj: Any) -> MagicMock:
         real_attr = getattr(obj, attr_name)
         mock_attr = getattr(mock, attr_name)
 
-        if callable(real_attr):
+        if callable(real_attr) and not hasattr(real_attr, "__aenter__"):
             mock_attr.side_effect = real_attr
         else:
             setattr(mock, attr_name, real_attr)
diff --git a/tests/components/zha/test_gateway.py b/tests/components/zha/test_gateway.py
index 9c3cf7aa2f8..f19ed9bd4a9 100644
--- a/tests/components/zha/test_gateway.py
+++ b/tests/components/zha/test_gateway.py
@@ -1,12 +1,14 @@
 """Test ZHA Gateway."""
 import asyncio
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, PropertyMock, patch
 
 import pytest
 from zigpy.application import ControllerApplication
 import zigpy.profiles.zha as zha
+import zigpy.types
 import zigpy.zcl.clusters.general as general
 import zigpy.zcl.clusters.lighting as lighting
+import zigpy.zdo.types
 
 from homeassistant.components.zha.core.gateway import ZHAGateway
 from homeassistant.components.zha.core.group import GroupMember
@@ -321,3 +323,89 @@ async def test_single_reload_on_multiple_connection_loss(
     assert len(mock_reload.mock_calls) == 1
 
     await hass.async_block_till_done()
+
+
+@pytest.mark.parametrize("radio_concurrency", [1, 2, 8])
+async def test_startup_concurrency_limit(
+    radio_concurrency: int,
+    hass: HomeAssistant,
+    zigpy_app_controller: ControllerApplication,
+    config_entry: MockConfigEntry,
+    zigpy_device_mock,
+):
+    """Test ZHA gateway limits concurrency on startup."""
+    config_entry.add_to_hass(hass)
+    zha_gateway = ZHAGateway(hass, {}, config_entry)
+
+    with patch(
+        "bellows.zigbee.application.ControllerApplication.new",
+        return_value=zigpy_app_controller,
+    ):
+        await zha_gateway.async_initialize()
+
+    for i in range(50):
+        zigpy_dev = zigpy_device_mock(
+            {
+                1: {
+                    SIG_EP_INPUT: [
+                        general.OnOff.cluster_id,
+                        general.LevelControl.cluster_id,
+                        lighting.Color.cluster_id,
+                        general.Groups.cluster_id,
+                    ],
+                    SIG_EP_OUTPUT: [],
+                    SIG_EP_TYPE: zha.DeviceType.COLOR_DIMMABLE_LIGHT,
+                    SIG_EP_PROFILE: zha.PROFILE_ID,
+                }
+            },
+            ieee=f"11:22:33:44:{i:08x}",
+            nwk=0x1234 + i,
+        )
+        zigpy_dev.node_desc.mac_capability_flags |= (
+            zigpy.zdo.types.NodeDescriptor.MACCapabilityFlags.MainsPowered
+        )
+
+        zha_gateway._async_get_or_create_device(zigpy_dev, restored=True)
+
+    # Keep track of request concurrency during initialization
+    current_concurrency = 0
+    concurrencies = []
+
+    async def mock_send_packet(*args, **kwargs):
+        nonlocal current_concurrency
+
+        current_concurrency += 1
+        concurrencies.append(current_concurrency)
+
+        await asyncio.sleep(0.001)
+
+        current_concurrency -= 1
+        concurrencies.append(current_concurrency)
+
+    # Override the property only while polling: assigning a PropertyMock to
+    # `type(zha_gateway)` would replace `ZHAGateway.radio_concurrency` for every
+    # test that runs afterwards, whereas `patch.object` restores it on exit.
+    with patch.object(
+        ZHAGateway,
+        "radio_concurrency",
+        new_callable=PropertyMock,
+        return_value=radio_concurrency,
+    ), patch(
+        "homeassistant.components.zha.core.device.ZHADevice.async_initialize",
+        side_effect=mock_send_packet,
+    ):
+        assert zha_gateway.radio_concurrency == radio_concurrency
+        await zha_gateway.async_fetch_updated_state_mains()
+
+    await zha_gateway.shutdown()
+
+    # Make sure concurrency was always limited
+    assert current_concurrency == 0
+    assert min(concurrencies) == 0
+
+    # Compare against the parametrized value: the property is no longer patched
+    # here and the real one reads the radio application, which is shut down by now.
+    if radio_concurrency > 1:
+        assert 1 <= max(concurrencies) < radio_concurrency
+    else:
+        assert 1 == max(concurrencies) == radio_concurrency