Refactor entity_platform polling to avoid double time fetch (#116877)

* Refactor entity_platform polling to avoid double time fetch

Replace async_track_time_interval with loop.call_later
to avoid the useless time fetch every time the listener
fired since we always throw it away

* fix test
This commit is contained in:
J. Nick Koston 2024-05-05 15:28:01 -05:00 committed by GitHub
parent 76cd498c44
commit b41b1bb998
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 36 additions and 38 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Iterable from collections.abc import Awaitable, Callable, Coroutine, Iterable
from contextvars import ContextVar from contextvars import ContextVar
from datetime import datetime, timedelta from datetime import timedelta
from functools import partial from functools import partial
from logging import Logger, getLogger from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Protocol from typing import TYPE_CHECKING, Any, Protocol
@ -43,7 +43,7 @@ from . import (
translation, translation,
) )
from .entity_registry import EntityRegistry, RegistryEntryDisabler, RegistryEntryHider from .entity_registry import EntityRegistry, RegistryEntryDisabler, RegistryEntryHider
from .event import async_call_later, async_track_time_interval from .event import async_call_later
from .issue_registry import IssueSeverity, async_create_issue from .issue_registry import IssueSeverity, async_create_issue
from .typing import UNDEFINED, ConfigType, DiscoveryInfoType from .typing import UNDEFINED, ConfigType, DiscoveryInfoType
@ -125,6 +125,7 @@ class EntityPlatform:
self.platform_name = platform_name self.platform_name = platform_name
self.platform = platform self.platform = platform
self.scan_interval = scan_interval self.scan_interval = scan_interval
self.scan_interval_seconds = scan_interval.total_seconds()
self.entity_namespace = entity_namespace self.entity_namespace = entity_namespace
self.config_entry: config_entries.ConfigEntry | None = None self.config_entry: config_entries.ConfigEntry | None = None
# Storage for entities for this specific platform only # Storage for entities for this specific platform only
@ -138,7 +139,7 @@ class EntityPlatform:
# Stop tracking tasks after setup is completed # Stop tracking tasks after setup is completed
self._setup_complete = False self._setup_complete = False
# Method to cancel the state change listener # Method to cancel the state change listener
self._async_unsub_polling: CALLBACK_TYPE | None = None self._async_polling_timer: asyncio.TimerHandle | None = None
# Method to cancel the retry of setup # Method to cancel the retry of setup
self._async_cancel_retry_setup: CALLBACK_TYPE | None = None self._async_cancel_retry_setup: CALLBACK_TYPE | None = None
self._process_updates: asyncio.Lock | None = None self._process_updates: asyncio.Lock | None = None
@ -630,7 +631,7 @@ class EntityPlatform:
if ( if (
(self.config_entry and self.config_entry.pref_disable_polling) (self.config_entry and self.config_entry.pref_disable_polling)
or self._async_unsub_polling is not None or self._async_polling_timer is not None
or not any( or not any(
# Entity may have failed to add or called `add_to_platform_abort` # Entity may have failed to add or called `add_to_platform_abort`
# so we check if the entity is in self.entities before # so we check if the entity is in self.entities before
@ -644,26 +645,28 @@ class EntityPlatform:
): ):
return return
self._async_unsub_polling = async_track_time_interval( self._async_polling_timer = self.hass.loop.call_later(
self.hass, self.scan_interval_seconds,
self._async_handle_interval_callback, self._async_handle_interval_callback,
self.scan_interval,
name=f"EntityPlatform poll {self.domain}.{self.platform_name}",
) )
@callback @callback
def _async_handle_interval_callback(self, now: datetime) -> None: def _async_handle_interval_callback(self) -> None:
"""Update all the entity states in a single platform.""" """Update all the entity states in a single platform."""
self._async_polling_timer = self.hass.loop.call_later(
self.scan_interval_seconds,
self._async_handle_interval_callback,
)
if self.config_entry: if self.config_entry:
self.config_entry.async_create_background_task( self.config_entry.async_create_background_task(
self.hass, self.hass,
self._update_entity_states(now), self._async_update_entity_states(),
name=f"EntityPlatform poll {self.domain}.{self.platform_name}", name=f"EntityPlatform poll {self.domain}.{self.platform_name}",
eager_start=True, eager_start=True,
) )
else: else:
self.hass.async_create_background_task( self.hass.async_create_background_task(
self._update_entity_states(now), self._async_update_entity_states(),
name=f"EntityPlatform poll {self.domain}.{self.platform_name}", name=f"EntityPlatform poll {self.domain}.{self.platform_name}",
eager_start=True, eager_start=True,
) )
@ -919,9 +922,9 @@ class EntityPlatform:
@callback @callback
def async_unsub_polling(self) -> None: def async_unsub_polling(self) -> None:
"""Stop polling.""" """Stop polling."""
if self._async_unsub_polling is not None: if self._async_polling_timer is not None:
self._async_unsub_polling() self._async_polling_timer.cancel()
self._async_unsub_polling = None self._async_polling_timer = None
@callback @callback
def async_prepare(self) -> None: def async_prepare(self) -> None:
@ -943,11 +946,10 @@ class EntityPlatform:
await self.entities[entity_id].async_remove() await self.entities[entity_id].async_remove()
# Clean up polling job if no longer needed # Clean up polling job if no longer needed
if self._async_unsub_polling is not None and not any( if self._async_polling_timer is not None and not any(
entity.should_poll for entity in self.entities.values() entity.should_poll for entity in self.entities.values()
): ):
self._async_unsub_polling() self.async_unsub_polling()
self._async_unsub_polling = None
async def async_extract_from_service( async def async_extract_from_service(
self, service_call: ServiceCall, expand_group: bool = True self, service_call: ServiceCall, expand_group: bool = True
@ -998,7 +1000,7 @@ class EntityPlatform:
supports_response, supports_response,
) )
async def _update_entity_states(self, now: datetime) -> None: async def _async_update_entity_states(self) -> None:
"""Update the states of all the polling entities. """Update the states of all the polling entities.
To protect from flooding the executor, we will update async entities To protect from flooding the executor, we will update async entities

View File

@ -115,10 +115,7 @@ async def test_setup_does_discovery(
assert ("platform_test", {}, {"msg": "discovery_info"}) == mock_setup.call_args[0] assert ("platform_test", {}, {"msg": "discovery_info"}) == mock_setup.call_args[0]
@patch("homeassistant.helpers.entity_platform.async_track_time_interval") async def test_set_scan_interval_via_config(hass: HomeAssistant) -> None:
async def test_set_scan_interval_via_config(
mock_track: Mock, hass: HomeAssistant
) -> None:
"""Test the setting of the scan interval via configuration.""" """Test the setting of the scan interval via configuration."""
def platform_setup( def platform_setup(
@ -134,13 +131,14 @@ async def test_set_scan_interval_via_config(
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
with patch.object(hass.loop, "call_later") as mock_track:
component.setup( component.setup(
{DOMAIN: {"platform": "platform", "scan_interval": timedelta(seconds=30)}} {DOMAIN: {"platform": "platform", "scan_interval": timedelta(seconds=30)}}
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_track.called assert mock_track.called
assert timedelta(seconds=30) == mock_track.call_args[0][2] assert mock_track.call_args[0][0] == 30.0
async def test_set_entity_namespace_via_config(hass: HomeAssistant) -> None: async def test_set_entity_namespace_via_config(hass: HomeAssistant) -> None:

View File

@ -120,7 +120,7 @@ async def test_polling_disabled_by_config_entry(hass: HomeAssistant) -> None:
poll_ent = MockEntity(should_poll=True) poll_ent = MockEntity(should_poll=True)
await entity_platform.async_add_entities([poll_ent]) await entity_platform.async_add_entities([poll_ent])
assert entity_platform._async_unsub_polling is None assert entity_platform._async_polling_timer is None
async def test_polling_updates_entities_with_exception(hass: HomeAssistant) -> None: async def test_polling_updates_entities_with_exception(hass: HomeAssistant) -> None:
@ -213,10 +213,7 @@ async def test_update_state_adds_entities_with_update_before_add_false(
assert not ent.update.called assert not ent.update.called
@patch("homeassistant.helpers.entity_platform.async_track_time_interval") async def test_set_scan_interval_via_platform(hass: HomeAssistant) -> None:
async def test_set_scan_interval_via_platform(
mock_track: Mock, hass: HomeAssistant
) -> None:
"""Test the setting of the scan interval via platform.""" """Test the setting of the scan interval via platform."""
def platform_setup( def platform_setup(
@ -235,11 +232,12 @@ async def test_set_scan_interval_via_platform(
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
with patch.object(hass.loop, "call_later") as mock_track:
await component.async_setup({DOMAIN: {"platform": "platform"}}) await component.async_setup({DOMAIN: {"platform": "platform"}})
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_track.called assert mock_track.called
assert timedelta(seconds=30) == mock_track.call_args[0][2] assert mock_track.call_args[0][0] == 30.0
async def test_adding_entities_with_generator_and_thread_callback( async def test_adding_entities_with_generator_and_thread_callback(
@ -505,7 +503,7 @@ async def test_parallel_updates_async_platform_updates_in_parallel(
assert handle._update_in_sequence is False assert handle._update_in_sequence is False
await handle._update_entity_states(dt_util.utcnow()) await handle._async_update_entity_states()
assert peak_update_count > 1 assert peak_update_count > 1
@ -555,7 +553,7 @@ async def test_parallel_updates_sync_platform_updates_in_sequence(
assert handle._update_in_sequence is True assert handle._update_in_sequence is True
await handle._update_entity_states(dt_util.utcnow()) await handle._async_update_entity_states()
assert peak_update_count == 1 assert peak_update_count == 1
@ -1017,7 +1015,7 @@ async def test_stop_shutdown_cancels_retry_setup_and_interval_listener(
ent_platform.async_shutdown() ent_platform.async_shutdown()
assert len(mock_call_later.return_value.mock_calls) == 1 assert len(mock_call_later.return_value.mock_calls) == 1
assert ent_platform._async_unsub_polling is None assert ent_platform._async_polling_timer is None
assert ent_platform._async_cancel_retry_setup is None assert ent_platform._async_cancel_retry_setup is None