Refactor entity_platform polling to avoid double time fetch (#116877)

* Refactor entity_platform polling to avoid double time fetch

Replace async_track_time_interval with loop.call_later
to avoid the useless time fetch every time the listener
fired since we always throw it away

* fix test
This commit is contained in:
J. Nick Koston 2024-05-05 15:28:01 -05:00 committed by GitHub
parent 76cd498c44
commit b41b1bb998
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 36 additions and 38 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Iterable
from contextvars import ContextVar
from datetime import datetime, timedelta
from datetime import timedelta
from functools import partial
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Protocol
@ -43,7 +43,7 @@ from . import (
translation,
)
from .entity_registry import EntityRegistry, RegistryEntryDisabler, RegistryEntryHider
from .event import async_call_later, async_track_time_interval
from .event import async_call_later
from .issue_registry import IssueSeverity, async_create_issue
from .typing import UNDEFINED, ConfigType, DiscoveryInfoType
@ -125,6 +125,7 @@ class EntityPlatform:
self.platform_name = platform_name
self.platform = platform
self.scan_interval = scan_interval
self.scan_interval_seconds = scan_interval.total_seconds()
self.entity_namespace = entity_namespace
self.config_entry: config_entries.ConfigEntry | None = None
# Storage for entities for this specific platform only
@ -138,7 +139,7 @@ class EntityPlatform:
# Stop tracking tasks after setup is completed
self._setup_complete = False
# Method to cancel the state change listener
self._async_unsub_polling: CALLBACK_TYPE | None = None
self._async_polling_timer: asyncio.TimerHandle | None = None
# Method to cancel the retry of setup
self._async_cancel_retry_setup: CALLBACK_TYPE | None = None
self._process_updates: asyncio.Lock | None = None
@ -630,7 +631,7 @@ class EntityPlatform:
if (
(self.config_entry and self.config_entry.pref_disable_polling)
or self._async_unsub_polling is not None
or self._async_polling_timer is not None
or not any(
# Entity may have failed to add or called `add_to_platform_abort`
# so we check if the entity is in self.entities before
@ -644,26 +645,28 @@ class EntityPlatform:
):
return
self._async_unsub_polling = async_track_time_interval(
self.hass,
self._async_polling_timer = self.hass.loop.call_later(
self.scan_interval_seconds,
self._async_handle_interval_callback,
self.scan_interval,
name=f"EntityPlatform poll {self.domain}.{self.platform_name}",
)
@callback
def _async_handle_interval_callback(self, now: datetime) -> None:
def _async_handle_interval_callback(self) -> None:
"""Update all the entity states in a single platform."""
self._async_polling_timer = self.hass.loop.call_later(
self.scan_interval_seconds,
self._async_handle_interval_callback,
)
if self.config_entry:
self.config_entry.async_create_background_task(
self.hass,
self._update_entity_states(now),
self._async_update_entity_states(),
name=f"EntityPlatform poll {self.domain}.{self.platform_name}",
eager_start=True,
)
else:
self.hass.async_create_background_task(
self._update_entity_states(now),
self._async_update_entity_states(),
name=f"EntityPlatform poll {self.domain}.{self.platform_name}",
eager_start=True,
)
@ -919,9 +922,9 @@ class EntityPlatform:
@callback
def async_unsub_polling(self) -> None:
"""Stop polling."""
if self._async_unsub_polling is not None:
self._async_unsub_polling()
self._async_unsub_polling = None
if self._async_polling_timer is not None:
self._async_polling_timer.cancel()
self._async_polling_timer = None
@callback
def async_prepare(self) -> None:
@ -943,11 +946,10 @@ class EntityPlatform:
await self.entities[entity_id].async_remove()
# Clean up polling job if no longer needed
if self._async_unsub_polling is not None and not any(
if self._async_polling_timer is not None and not any(
entity.should_poll for entity in self.entities.values()
):
self._async_unsub_polling()
self._async_unsub_polling = None
self.async_unsub_polling()
async def async_extract_from_service(
self, service_call: ServiceCall, expand_group: bool = True
@ -998,7 +1000,7 @@ class EntityPlatform:
supports_response,
)
async def _update_entity_states(self, now: datetime) -> None:
async def _async_update_entity_states(self) -> None:
"""Update the states of all the polling entities.
To protect from flooding the executor, we will update async entities

View File

@ -115,10 +115,7 @@ async def test_setup_does_discovery(
assert ("platform_test", {}, {"msg": "discovery_info"}) == mock_setup.call_args[0]
@patch("homeassistant.helpers.entity_platform.async_track_time_interval")
async def test_set_scan_interval_via_config(
mock_track: Mock, hass: HomeAssistant
) -> None:
async def test_set_scan_interval_via_config(hass: HomeAssistant) -> None:
"""Test the setting of the scan interval via configuration."""
def platform_setup(
@ -134,13 +131,14 @@ async def test_set_scan_interval_via_config(
component = EntityComponent(_LOGGER, DOMAIN, hass)
component.setup(
{DOMAIN: {"platform": "platform", "scan_interval": timedelta(seconds=30)}}
)
with patch.object(hass.loop, "call_later") as mock_track:
component.setup(
{DOMAIN: {"platform": "platform", "scan_interval": timedelta(seconds=30)}}
)
await hass.async_block_till_done()
await hass.async_block_till_done()
assert mock_track.called
assert timedelta(seconds=30) == mock_track.call_args[0][2]
assert mock_track.call_args[0][0] == 30.0
async def test_set_entity_namespace_via_config(hass: HomeAssistant) -> None:

View File

@ -120,7 +120,7 @@ async def test_polling_disabled_by_config_entry(hass: HomeAssistant) -> None:
poll_ent = MockEntity(should_poll=True)
await entity_platform.async_add_entities([poll_ent])
assert entity_platform._async_unsub_polling is None
assert entity_platform._async_polling_timer is None
async def test_polling_updates_entities_with_exception(hass: HomeAssistant) -> None:
@ -213,10 +213,7 @@ async def test_update_state_adds_entities_with_update_before_add_false(
assert not ent.update.called
@patch("homeassistant.helpers.entity_platform.async_track_time_interval")
async def test_set_scan_interval_via_platform(
mock_track: Mock, hass: HomeAssistant
) -> None:
async def test_set_scan_interval_via_platform(hass: HomeAssistant) -> None:
"""Test the setting of the scan interval via platform."""
def platform_setup(
@ -235,11 +232,12 @@ async def test_set_scan_interval_via_platform(
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_setup({DOMAIN: {"platform": "platform"}})
with patch.object(hass.loop, "call_later") as mock_track:
await component.async_setup({DOMAIN: {"platform": "platform"}})
await hass.async_block_till_done()
await hass.async_block_till_done()
assert mock_track.called
assert timedelta(seconds=30) == mock_track.call_args[0][2]
assert mock_track.call_args[0][0] == 30.0
async def test_adding_entities_with_generator_and_thread_callback(
@ -505,7 +503,7 @@ async def test_parallel_updates_async_platform_updates_in_parallel(
assert handle._update_in_sequence is False
await handle._update_entity_states(dt_util.utcnow())
await handle._async_update_entity_states()
assert peak_update_count > 1
@ -555,7 +553,7 @@ async def test_parallel_updates_sync_platform_updates_in_sequence(
assert handle._update_in_sequence is True
await handle._update_entity_states(dt_util.utcnow())
await handle._async_update_entity_states()
assert peak_update_count == 1
@ -1017,7 +1015,7 @@ async def test_stop_shutdown_cancels_retry_setup_and_interval_listener(
ent_platform.async_shutdown()
assert len(mock_call_later.return_value.mock_calls) == 1
assert ent_platform._async_unsub_polling is None
assert ent_platform._async_polling_timer is None
assert ent_platform._async_cancel_retry_setup is None