Fix race in TimestampDataUpdateCoordinator (#115542)

* Fix race in TimestampDataUpdateCoordinator

The last_update_success_time value was being set after the listeners
were fired which could lead to a loop because the listener may
re-trigger an update because it thinks the data is stale

* coverage

* docstring
This commit is contained in:
J. Nick Koston 2024-04-13 10:35:07 -10:00 committed by GitHub
parent 08e2b655be
commit edd75a9d5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 48 additions and 16 deletions

View File

@ -401,6 +401,8 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
if not auth_failed and self._listeners and not self.hass.is_stopping: if not auth_failed and self._listeners and not self.hass.is_stopping:
self._schedule_refresh() self._schedule_refresh()
self._async_refresh_finished()
if not self.last_update_success and not previous_update_success: if not self.last_update_success and not previous_update_success:
return return
@ -411,6 +413,15 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
): ):
self.async_update_listeners() self.async_update_listeners()
@callback
def _async_refresh_finished(self) -> None:
"""Handle when a refresh has finished.
Called when refresh is finished before listeners are updated.
To be overridden by subclasses.
"""
@callback @callback
def async_set_update_error(self, err: Exception) -> None: def async_set_update_error(self, err: Exception) -> None:
"""Manually set an error, log the message and notify listeners.""" """Manually set an error, log the message and notify listeners."""
@ -444,20 +455,9 @@ class TimestampDataUpdateCoordinator(DataUpdateCoordinator[_DataT]):
last_update_success_time: datetime | None = None last_update_success_time: datetime | None = None
async def _async_refresh( @callback
self, def _async_refresh_finished(self) -> None:
log_failures: bool = True, """Handle when a refresh has finished."""
raise_on_auth_failed: bool = False,
scheduled: bool = False,
raise_on_entry_error: bool = False,
) -> None:
"""Refresh data."""
await super()._async_refresh(
log_failures,
raise_on_auth_failed,
scheduled,
raise_on_entry_error,
)
if self.last_update_success: if self.last_update_success:
self.last_update_success_time = utcnow() self.last_update_success_time = utcnow()

View File

@ -1,6 +1,6 @@
"""Tests for the update coordinator.""" """Tests for the update coordinator."""
from datetime import timedelta from datetime import datetime, timedelta
import logging import logging
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import urllib.error import urllib.error
@ -12,7 +12,7 @@ import requests
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import CoreState, HomeAssistant from homeassistant.core import CoreState, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import update_coordinator from homeassistant.helpers import update_coordinator
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
@ -715,3 +715,35 @@ async def test_always_callback_when_always_update_is_true(
update_callback.reset_mock() update_callback.reset_mock()
remove_callbacks() remove_callbacks()
async def test_timestamp_date_update_coordinator(hass: HomeAssistant) -> None:
"""Test last_update_success_time is set before calling listeners."""
last_update_success_times: list[datetime | None] = []
async def refresh() -> int:
return 1
crd = update_coordinator.TimestampDataUpdateCoordinator[int](
hass,
_LOGGER,
name="test",
update_method=refresh,
update_interval=timedelta(seconds=10),
)
@callback
def listener():
last_update_success_times.append(crd.last_update_success_time)
unsub = crd.async_add_listener(listener)
await crd.async_refresh()
assert len(last_update_success_times) == 1
# Ensure the time is set before the listener is called
assert last_update_success_times != [None]
unsub()
await crd.async_refresh()
assert len(last_update_success_times) == 1