mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Speed up ZHA initialization and improve startup responsiveness (#108103)
* Limit concurrency of startup traffic to allow for interactive usage * Drop `retryable_req`, we already have request retrying * Oops, `min` -> `max` * Add a comment describing why `async_initialize` is not concurrent * Fix existing unit tests * Break out fetching mains state into its own function to unit test
This commit is contained in:
parent
3ae858e3bf
commit
867caab70a
@ -42,7 +42,7 @@ from ..const import (
|
|||||||
ZHA_CLUSTER_HANDLER_MSG_DATA,
|
ZHA_CLUSTER_HANDLER_MSG_DATA,
|
||||||
ZHA_CLUSTER_HANDLER_READS_PER_REQ,
|
ZHA_CLUSTER_HANDLER_READS_PER_REQ,
|
||||||
)
|
)
|
||||||
from ..helpers import LogMixin, retryable_req, safe_read
|
from ..helpers import LogMixin, safe_read
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..endpoint import Endpoint
|
from ..endpoint import Endpoint
|
||||||
@ -362,7 +362,6 @@ class ClusterHandler(LogMixin):
|
|||||||
self.debug("skipping cluster handler configuration")
|
self.debug("skipping cluster handler configuration")
|
||||||
self._status = ClusterHandlerStatus.CONFIGURED
|
self._status = ClusterHandlerStatus.CONFIGURED
|
||||||
|
|
||||||
@retryable_req(delays=(1, 1, 3))
|
|
||||||
async def async_initialize(self, from_cache: bool) -> None:
|
async def async_initialize(self, from_cache: bool) -> None:
|
||||||
"""Initialize cluster handler."""
|
"""Initialize cluster handler."""
|
||||||
if not from_cache and self._endpoint.device.skip_configuration:
|
if not from_cache and self._endpoint.device.skip_configuration:
|
||||||
|
@ -592,12 +592,17 @@ class ZHADevice(LogMixin):
|
|||||||
self.debug("started initialization")
|
self.debug("started initialization")
|
||||||
await self._zdo_handler.async_initialize(from_cache)
|
await self._zdo_handler.async_initialize(from_cache)
|
||||||
self._zdo_handler.debug("'async_initialize' stage succeeded")
|
self._zdo_handler.debug("'async_initialize' stage succeeded")
|
||||||
await asyncio.gather(
|
|
||||||
*(
|
# We intentionally do not use `gather` here! This is so that if, for example,
|
||||||
endpoint.async_initialize(from_cache)
|
# three `device.async_initialize()`s are spawned, only three concurrent requests
|
||||||
for endpoint in self._endpoints.values()
|
# will ever be in flight at once. Startup concurrency is managed at the device
|
||||||
)
|
# level.
|
||||||
)
|
for endpoint in self._endpoints.values():
|
||||||
|
try:
|
||||||
|
await endpoint.async_initialize(from_cache)
|
||||||
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
self.debug("Failed to initialize endpoint", exc_info=True)
|
||||||
|
|
||||||
self.debug("power source: %s", self.power_source)
|
self.debug("power source: %s", self.power_source)
|
||||||
self.status = DeviceStatus.INITIALIZED
|
self.status = DeviceStatus.INITIALIZED
|
||||||
self.debug("completed initialization")
|
self.debug("completed initialization")
|
||||||
|
@ -2,7 +2,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Final, TypeVar
|
from typing import TYPE_CHECKING, Any, Final, TypeVar
|
||||||
|
|
||||||
@ -11,6 +12,7 @@ from zigpy.typing import EndpointType as ZigpyEndpointType
|
|||||||
from homeassistant.const import Platform
|
from homeassistant.const import Platform
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||||
|
from homeassistant.util.async_ import gather_with_limited_concurrency
|
||||||
|
|
||||||
from . import const, discovery, registries
|
from . import const, discovery, registries
|
||||||
from .cluster_handlers import ClusterHandler
|
from .cluster_handlers import ClusterHandler
|
||||||
@ -169,20 +171,32 @@ class Endpoint:
|
|||||||
|
|
||||||
async def async_initialize(self, from_cache: bool = False) -> None:
|
async def async_initialize(self, from_cache: bool = False) -> None:
|
||||||
"""Initialize claimed cluster handlers."""
|
"""Initialize claimed cluster handlers."""
|
||||||
await self._execute_handler_tasks("async_initialize", from_cache)
|
await self._execute_handler_tasks(
|
||||||
|
"async_initialize", from_cache, max_concurrency=1
|
||||||
|
)
|
||||||
|
|
||||||
async def async_configure(self) -> None:
|
async def async_configure(self) -> None:
|
||||||
"""Configure claimed cluster handlers."""
|
"""Configure claimed cluster handlers."""
|
||||||
await self._execute_handler_tasks("async_configure")
|
await self._execute_handler_tasks("async_configure")
|
||||||
|
|
||||||
async def _execute_handler_tasks(self, func_name: str, *args: Any) -> None:
|
async def _execute_handler_tasks(
|
||||||
|
self, func_name: str, *args: Any, max_concurrency: int | None = None
|
||||||
|
) -> None:
|
||||||
"""Add a throttled cluster handler task and swallow exceptions."""
|
"""Add a throttled cluster handler task and swallow exceptions."""
|
||||||
cluster_handlers = [
|
cluster_handlers = [
|
||||||
*self.claimed_cluster_handlers.values(),
|
*self.claimed_cluster_handlers.values(),
|
||||||
*self.client_cluster_handlers.values(),
|
*self.client_cluster_handlers.values(),
|
||||||
]
|
]
|
||||||
tasks = [getattr(ch, func_name)(*args) for ch in cluster_handlers]
|
tasks = [getattr(ch, func_name)(*args) for ch in cluster_handlers]
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
|
gather: Callable[..., Awaitable]
|
||||||
|
|
||||||
|
if max_concurrency is None:
|
||||||
|
gather = asyncio.gather
|
||||||
|
else:
|
||||||
|
gather = functools.partial(gather_with_limited_concurrency, max_concurrency)
|
||||||
|
|
||||||
|
results = await gather(*tasks, return_exceptions=True)
|
||||||
for cluster_handler, outcome in zip(cluster_handlers, results):
|
for cluster_handler, outcome in zip(cluster_handlers, results):
|
||||||
if isinstance(outcome, Exception):
|
if isinstance(outcome, Exception):
|
||||||
cluster_handler.warning(
|
cluster_handler.warning(
|
||||||
|
@ -11,7 +11,7 @@ import itertools
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, NamedTuple, Self
|
from typing import TYPE_CHECKING, Any, NamedTuple, Self, cast
|
||||||
|
|
||||||
from zigpy.application import ControllerApplication
|
from zigpy.application import ControllerApplication
|
||||||
from zigpy.config import (
|
from zigpy.config import (
|
||||||
@ -36,6 +36,7 @@ from homeassistant.helpers import device_registry as dr, entity_registry as er
|
|||||||
from homeassistant.helpers.device_registry import DeviceInfo
|
from homeassistant.helpers.device_registry import DeviceInfo
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
from homeassistant.util.async_ import gather_with_limited_concurrency
|
||||||
|
|
||||||
from . import discovery
|
from . import discovery
|
||||||
from .const import (
|
from .const import (
|
||||||
@ -292,6 +293,39 @@ class ZHAGateway:
|
|||||||
# entity registry tied to the devices
|
# entity registry tied to the devices
|
||||||
discovery.GROUP_PROBE.discover_group_entities(zha_group)
|
discovery.GROUP_PROBE.discover_group_entities(zha_group)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def radio_concurrency(self) -> int:
|
||||||
|
"""Maximum configured radio concurrency."""
|
||||||
|
return self.application_controller._concurrent_requests_semaphore.max_value # pylint: disable=protected-access
|
||||||
|
|
||||||
|
async def async_fetch_updated_state_mains(self) -> None:
|
||||||
|
"""Fetch updated state for mains powered devices."""
|
||||||
|
_LOGGER.debug("Fetching current state for mains powered devices")
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
# Only delay startup to poll mains-powered devices that are online
|
||||||
|
online_devices = [
|
||||||
|
dev
|
||||||
|
for dev in self.devices.values()
|
||||||
|
if dev.is_mains_powered
|
||||||
|
and dev.last_seen is not None
|
||||||
|
and (now - dev.last_seen) < dev.consider_unavailable_time
|
||||||
|
]
|
||||||
|
|
||||||
|
# Prioritize devices that have recently been contacted
|
||||||
|
online_devices.sort(key=lambda dev: cast(float, dev.last_seen), reverse=True)
|
||||||
|
|
||||||
|
# Make sure that we always leave slots for non-startup requests
|
||||||
|
max_poll_concurrency = max(1, self.radio_concurrency - 4)
|
||||||
|
|
||||||
|
await gather_with_limited_concurrency(
|
||||||
|
max_poll_concurrency,
|
||||||
|
*(dev.async_initialize(from_cache=False) for dev in online_devices),
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER.debug("completed fetching current state for mains powered devices")
|
||||||
|
|
||||||
async def async_initialize_devices_and_entities(self) -> None:
|
async def async_initialize_devices_and_entities(self) -> None:
|
||||||
"""Initialize devices and load entities."""
|
"""Initialize devices and load entities."""
|
||||||
|
|
||||||
@ -302,17 +336,8 @@ class ZHAGateway:
|
|||||||
|
|
||||||
async def fetch_updated_state() -> None:
|
async def fetch_updated_state() -> None:
|
||||||
"""Fetch updated state for mains powered devices."""
|
"""Fetch updated state for mains powered devices."""
|
||||||
_LOGGER.debug("Fetching current state for mains powered devices")
|
await self.async_fetch_updated_state_mains()
|
||||||
await asyncio.gather(
|
_LOGGER.debug("Allowing polled requests")
|
||||||
*(
|
|
||||||
dev.async_initialize(from_cache=False)
|
|
||||||
for dev in self.devices.values()
|
|
||||||
if dev.is_mains_powered
|
|
||||||
)
|
|
||||||
)
|
|
||||||
_LOGGER.debug(
|
|
||||||
"completed fetching current state for mains powered devices - allowing polled requests"
|
|
||||||
)
|
|
||||||
self.hass.data[DATA_ZHA].allow_polling = True
|
self.hass.data[DATA_ZHA].allow_polling = True
|
||||||
|
|
||||||
# background the fetching of state for mains powered devices
|
# background the fetching of state for mains powered devices
|
||||||
|
@ -5,19 +5,15 @@ https://home-assistant.io/integrations/zha/
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import binascii
|
import binascii
|
||||||
import collections
|
import collections
|
||||||
from collections.abc import Callable, Collection, Coroutine, Iterator
|
from collections.abc import Callable, Iterator
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import enum
|
import enum
|
||||||
import functools
|
|
||||||
import itertools
|
|
||||||
import logging
|
import logging
|
||||||
from random import uniform
|
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar
|
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
import zigpy.exceptions
|
import zigpy.exceptions
|
||||||
@ -322,58 +318,6 @@ class LogMixin:
|
|||||||
return self.log(logging.ERROR, msg, *args, **kwargs)
|
return self.log(logging.ERROR, msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def retryable_req(
|
|
||||||
delays: Collection[float] = (1, 5, 10, 15, 30, 60, 120, 180, 360, 600, 900, 1800),
|
|
||||||
raise_: bool = False,
|
|
||||||
) -> Callable[
|
|
||||||
[Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R]]],
|
|
||||||
Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R | None]],
|
|
||||||
]:
|
|
||||||
"""Make a method with ZCL requests retryable.
|
|
||||||
|
|
||||||
This adds delays keyword argument to function.
|
|
||||||
len(delays) is number of tries.
|
|
||||||
raise_ if the final attempt should raise the exception.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(
|
|
||||||
func: Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R]],
|
|
||||||
) -> Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R | None]]:
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def wrapper(
|
|
||||||
cluster_handler: _ClusterHandlerT, *args: _P.args, **kwargs: _P.kwargs
|
|
||||||
) -> _R | None:
|
|
||||||
exceptions = (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError)
|
|
||||||
try_count, errors = 1, []
|
|
||||||
for delay in itertools.chain(delays, [None]):
|
|
||||||
try:
|
|
||||||
return await func(cluster_handler, *args, **kwargs)
|
|
||||||
except exceptions as ex:
|
|
||||||
errors.append(ex)
|
|
||||||
if delay:
|
|
||||||
delay = uniform(delay * 0.75, delay * 1.25)
|
|
||||||
cluster_handler.debug(
|
|
||||||
"%s: retryable request #%d failed: %s. Retrying in %ss",
|
|
||||||
func.__name__,
|
|
||||||
try_count,
|
|
||||||
ex,
|
|
||||||
round(delay, 1),
|
|
||||||
)
|
|
||||||
try_count += 1
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
else:
|
|
||||||
cluster_handler.warning(
|
|
||||||
"%s: all attempts have failed: %s", func.__name__, errors
|
|
||||||
)
|
|
||||||
if raise_:
|
|
||||||
raise
|
|
||||||
return None
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def convert_install_code(value: str) -> bytes:
|
def convert_install_code(value: str) -> bytes:
|
||||||
"""Convert string to install code bytes and validate length."""
|
"""Convert string to install code bytes and validate length."""
|
||||||
|
|
||||||
|
@ -135,7 +135,7 @@ def _wrap_mock_instance(obj: Any) -> MagicMock:
|
|||||||
real_attr = getattr(obj, attr_name)
|
real_attr = getattr(obj, attr_name)
|
||||||
mock_attr = getattr(mock, attr_name)
|
mock_attr = getattr(mock, attr_name)
|
||||||
|
|
||||||
if callable(real_attr):
|
if callable(real_attr) and not hasattr(real_attr, "__aenter__"):
|
||||||
mock_attr.side_effect = real_attr
|
mock_attr.side_effect = real_attr
|
||||||
else:
|
else:
|
||||||
setattr(mock, attr_name, real_attr)
|
setattr(mock, attr_name, real_attr)
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
"""Test ZHA Gateway."""
|
"""Test ZHA Gateway."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, PropertyMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from zigpy.application import ControllerApplication
|
from zigpy.application import ControllerApplication
|
||||||
import zigpy.profiles.zha as zha
|
import zigpy.profiles.zha as zha
|
||||||
|
import zigpy.types
|
||||||
import zigpy.zcl.clusters.general as general
|
import zigpy.zcl.clusters.general as general
|
||||||
import zigpy.zcl.clusters.lighting as lighting
|
import zigpy.zcl.clusters.lighting as lighting
|
||||||
|
import zigpy.zdo.types
|
||||||
|
|
||||||
from homeassistant.components.zha.core.gateway import ZHAGateway
|
from homeassistant.components.zha.core.gateway import ZHAGateway
|
||||||
from homeassistant.components.zha.core.group import GroupMember
|
from homeassistant.components.zha.core.group import GroupMember
|
||||||
@ -321,3 +323,81 @@ async def test_single_reload_on_multiple_connection_loss(
|
|||||||
assert len(mock_reload.mock_calls) == 1
|
assert len(mock_reload.mock_calls) == 1
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("radio_concurrency", [1, 2, 8])
|
||||||
|
async def test_startup_concurrency_limit(
|
||||||
|
radio_concurrency: int,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
zigpy_app_controller: ControllerApplication,
|
||||||
|
config_entry: MockConfigEntry,
|
||||||
|
zigpy_device_mock,
|
||||||
|
):
|
||||||
|
"""Test ZHA gateway limits concurrency on startup."""
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
zha_gateway = ZHAGateway(hass, {}, config_entry)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"bellows.zigbee.application.ControllerApplication.new",
|
||||||
|
return_value=zigpy_app_controller,
|
||||||
|
):
|
||||||
|
await zha_gateway.async_initialize()
|
||||||
|
|
||||||
|
for i in range(50):
|
||||||
|
zigpy_dev = zigpy_device_mock(
|
||||||
|
{
|
||||||
|
1: {
|
||||||
|
SIG_EP_INPUT: [
|
||||||
|
general.OnOff.cluster_id,
|
||||||
|
general.LevelControl.cluster_id,
|
||||||
|
lighting.Color.cluster_id,
|
||||||
|
general.Groups.cluster_id,
|
||||||
|
],
|
||||||
|
SIG_EP_OUTPUT: [],
|
||||||
|
SIG_EP_TYPE: zha.DeviceType.COLOR_DIMMABLE_LIGHT,
|
||||||
|
SIG_EP_PROFILE: zha.PROFILE_ID,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ieee=f"11:22:33:44:{i:08x}",
|
||||||
|
nwk=0x1234 + i,
|
||||||
|
)
|
||||||
|
zigpy_dev.node_desc.mac_capability_flags |= (
|
||||||
|
zigpy.zdo.types.NodeDescriptor.MACCapabilityFlags.MainsPowered
|
||||||
|
)
|
||||||
|
|
||||||
|
zha_gateway._async_get_or_create_device(zigpy_dev, restored=True)
|
||||||
|
|
||||||
|
# Keep track of request concurrency during initialization
|
||||||
|
current_concurrency = 0
|
||||||
|
concurrencies = []
|
||||||
|
|
||||||
|
async def mock_send_packet(*args, **kwargs):
|
||||||
|
nonlocal current_concurrency
|
||||||
|
|
||||||
|
current_concurrency += 1
|
||||||
|
concurrencies.append(current_concurrency)
|
||||||
|
|
||||||
|
await asyncio.sleep(0.001)
|
||||||
|
|
||||||
|
current_concurrency -= 1
|
||||||
|
concurrencies.append(current_concurrency)
|
||||||
|
|
||||||
|
type(zha_gateway).radio_concurrency = PropertyMock(return_value=radio_concurrency)
|
||||||
|
assert zha_gateway.radio_concurrency == radio_concurrency
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.zha.core.device.ZHADevice.async_initialize",
|
||||||
|
side_effect=mock_send_packet,
|
||||||
|
):
|
||||||
|
await zha_gateway.async_fetch_updated_state_mains()
|
||||||
|
|
||||||
|
await zha_gateway.shutdown()
|
||||||
|
|
||||||
|
# Make sure concurrency was always limited
|
||||||
|
assert current_concurrency == 0
|
||||||
|
assert min(concurrencies) == 0
|
||||||
|
|
||||||
|
if radio_concurrency > 1:
|
||||||
|
assert 1 <= max(concurrencies) < zha_gateway.radio_concurrency
|
||||||
|
else:
|
||||||
|
assert 1 == max(concurrencies) == zha_gateway.radio_concurrency
|
||||||
|
Loading…
x
Reference in New Issue
Block a user