Keep track of a context for each listener (#72702)

* Remove async_remove_listener

This avoids the ambuigity as to what happens if same callback is added multiple times.

* Keep track of a context for each listener

This allow a update coordinator to adapt what data to request on update from the backing service based on which entities are enabled.

* Clone list before calling callbacks

The callbacks can end up unregistering and modifying the dict while iterating.

* Only yield actual values

* Add a test for update context

* Factor out iteration of _listeners to helper

* Verify context is passed to coordinator

* Switch to Any as type instead of object

* Remove function which use was dropped earliers

The use was removed in 8bee25c938a123f0da7569b4e2753598d478b900
This commit is contained in:
Joakim Plate 2022-06-03 13:55:57 +02:00 committed by GitHub
parent a28fa5377a
commit 8910d265d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 95 additions and 115 deletions

View File

@ -131,4 +131,4 @@ class BMWButton(BMWBaseEntity, ButtonEntity):
# Always update HA states after a button was executed. # Always update HA states after a button was executed.
# BMW remote services that change the vehicle's state update the local object # BMW remote services that change the vehicle's state update the local object
# when executing the service, so only the HA state machine needs further updates. # when executing the service, so only the HA state machine needs further updates.
self.coordinator.notify_listeners() self.coordinator.async_update_listeners()

View File

@ -74,8 +74,3 @@ class BMWDataUpdateCoordinator(DataUpdateCoordinator):
if not refresh_token: if not refresh_token:
data.pop(CONF_REFRESH_TOKEN) data.pop(CONF_REFRESH_TOKEN)
self.hass.config_entries.async_update_entry(self._entry, data=data) self.hass.config_entries.async_update_entry(self._entry, data=data)
def notify_listeners(self) -> None:
"""Notify all listeners to refresh HA state machine."""
for update_callback in self._listeners:
update_callback()

View File

@ -74,12 +74,12 @@ def modernforms_exception_handler(func):
async def handler(self, *args, **kwargs): async def handler(self, *args, **kwargs):
try: try:
await func(self, *args, **kwargs) await func(self, *args, **kwargs)
self.coordinator.update_listeners() self.coordinator.async_update_listeners()
except ModernFormsConnectionError as error: except ModernFormsConnectionError as error:
_LOGGER.error("Error communicating with API: %s", error) _LOGGER.error("Error communicating with API: %s", error)
self.coordinator.last_update_success = False self.coordinator.last_update_success = False
self.coordinator.update_listeners() self.coordinator.async_update_listeners()
except ModernFormsError as error: except ModernFormsError as error:
_LOGGER.error("Invalid response from API: %s", error) _LOGGER.error("Invalid response from API: %s", error)
@ -108,11 +108,6 @@ class ModernFormsDataUpdateCoordinator(DataUpdateCoordinator[ModernFormsDeviceSt
update_interval=SCAN_INTERVAL, update_interval=SCAN_INTERVAL,
) )
def update_listeners(self) -> None:
"""Call update on all listeners."""
for update_callback in self._listeners:
update_callback()
async def _async_update_data(self) -> ModernFormsDevice: async def _async_update_data(self) -> ModernFormsDevice:
"""Fetch data from Modern Forms.""" """Fetch data from Modern Forms."""
try: try:

View File

@ -83,8 +83,7 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
async def async_set_cooling(self, enabled: bool) -> None: async def async_set_cooling(self, enabled: bool) -> None:
"""Enable or disable cooling mode.""" """Enable or disable cooling mode."""
await self.base.set_cooling(enabled) await self.base.set_cooling(enabled)
for update_callback in self._listeners: self.async_update_listeners()
update_callback()
async def async_set_target_temperature( async def async_set_target_temperature(
self, heat_area_id: str, target_temperature: float self, heat_area_id: str, target_temperature: float
@ -117,8 +116,7 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
"Failed to set target temperature, communication error with alpha2 base" "Failed to set target temperature, communication error with alpha2 base"
) from http_err ) from http_err
self.data["heat_areas"][heat_area_id].update(update_data) self.data["heat_areas"][heat_area_id].update(update_data)
for update_callback in self._listeners: self.async_update_listeners()
update_callback()
async def async_set_heat_area_mode( async def async_set_heat_area_mode(
self, heat_area_id: str, heat_area_mode: int self, heat_area_id: str, heat_area_mode: int
@ -161,5 +159,5 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
self.data["heat_areas"][heat_area_id]["T_TARGET"] = self.data[ self.data["heat_areas"][heat_area_id]["T_TARGET"] = self.data[
"heat_areas" "heat_areas"
][heat_area_id]["T_HEAT_NIGHT"] ][heat_area_id]["T_HEAT_NIGHT"]
for update_callback in self._listeners:
update_callback() self.async_update_listeners()

View File

@ -19,14 +19,7 @@ from homeassistant.const import (
CONF_USERNAME, CONF_USERNAME,
Platform, Platform,
) )
from homeassistant.core import ( from homeassistant.core import Context, HassJob, HomeAssistant, callback
CALLBACK_TYPE,
Context,
Event,
HassJob,
HomeAssistant,
callback,
)
from homeassistant.helpers.debounce import Debouncer from homeassistant.helpers.debounce import Debouncer
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
@ -121,12 +114,7 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]):
self.options = options self.options = options
self._notify_future: asyncio.Task | None = None self._notify_future: asyncio.Task | None = None
@callback self.turn_on = PluggableAction(self.async_update_listeners)
def _update_listeners():
for update_callback in self._listeners:
update_callback()
self.turn_on = PluggableAction(_update_listeners)
super().__init__( super().__init__(
hass, hass,
@ -193,15 +181,9 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]):
self._notify_future = asyncio.create_task(self._notify_task()) self._notify_future = asyncio.create_task(self._notify_task())
@callback @callback
def async_remove_listener(self, update_callback: CALLBACK_TYPE) -> None: def _unschedule_refresh(self) -> None:
"""Remove data update.""" """Remove data update."""
super().async_remove_listener(update_callback) super()._unschedule_refresh()
if not self._listeners:
self._async_notify_stop()
@callback
def _async_stop_refresh(self, event: Event) -> None:
super()._async_stop_refresh(event)
self._async_notify_stop() self._async_notify_stop()
@callback @callback

View File

@ -75,11 +75,6 @@ class SystemBridgeDataUpdateCoordinator(
hass, LOGGER, name=DOMAIN, update_interval=timedelta(seconds=30) hass, LOGGER, name=DOMAIN, update_interval=timedelta(seconds=30)
) )
def update_listeners(self) -> None:
"""Call update on all listeners."""
for update_callback in self._listeners:
update_callback()
async def async_get_data( async def async_get_data(
self, self,
modules: list[str], modules: list[str],
@ -113,7 +108,7 @@ class SystemBridgeDataUpdateCoordinator(
self.unsub() self.unsub()
self.unsub = None self.unsub = None
self.last_update_success = False self.last_update_success = False
self.update_listeners() self.async_update_listeners()
except (ConnectionClosedException, ConnectionResetError) as exception: except (ConnectionClosedException, ConnectionResetError) as exception:
self.logger.info( self.logger.info(
"Websocket connection closed for %s. Will retry: %s", "Websocket connection closed for %s. Will retry: %s",
@ -124,7 +119,7 @@ class SystemBridgeDataUpdateCoordinator(
self.unsub() self.unsub()
self.unsub = None self.unsub = None
self.last_update_success = False self.last_update_success = False
self.update_listeners() self.async_update_listeners()
except ConnectionErrorException as exception: except ConnectionErrorException as exception:
self.logger.warning( self.logger.warning(
"Connection error occurred for %s. Will retry: %s", "Connection error occurred for %s. Will retry: %s",
@ -135,7 +130,7 @@ class SystemBridgeDataUpdateCoordinator(
self.unsub() self.unsub()
self.unsub = None self.unsub = None
self.last_update_success = False self.last_update_success = False
self.update_listeners() self.async_update_listeners()
async def _setup_websocket(self) -> None: async def _setup_websocket(self) -> None:
"""Use WebSocket for updates.""" """Use WebSocket for updates."""
@ -151,7 +146,7 @@ class SystemBridgeDataUpdateCoordinator(
self.unsub() self.unsub()
self.unsub = None self.unsub = None
self.last_update_success = False self.last_update_success = False
self.update_listeners() self.async_update_listeners()
except ConnectionErrorException as exception: except ConnectionErrorException as exception:
self.logger.warning( self.logger.warning(
"Connection error occurred for %s. Will retry: %s", "Connection error occurred for %s. Will retry: %s",
@ -159,7 +154,7 @@ class SystemBridgeDataUpdateCoordinator(
exception, exception,
) )
self.last_update_success = False self.last_update_success = False
self.update_listeners() self.async_update_listeners()
except asyncio.TimeoutError as exception: except asyncio.TimeoutError as exception:
self.logger.warning( self.logger.warning(
"Timed out waiting for %s. Will retry: %s", "Timed out waiting for %s. Will retry: %s",
@ -167,11 +162,11 @@ class SystemBridgeDataUpdateCoordinator(
exception, exception,
) )
self.last_update_success = False self.last_update_success = False
self.update_listeners() self.async_update_listeners()
self.hass.async_create_task(self._listen_for_data()) self.hass.async_create_task(self._listen_for_data())
self.last_update_success = True self.last_update_success = True
self.update_listeners() self.async_update_listeners()
async def close_websocket(_) -> None: async def close_websocket(_) -> None:
"""Close WebSocket connection.""" """Close WebSocket connection."""

View File

@ -47,11 +47,6 @@ class ToonDataUpdateCoordinator(DataUpdateCoordinator[Status]):
hass, _LOGGER, name=DOMAIN, update_interval=DEFAULT_SCAN_INTERVAL hass, _LOGGER, name=DOMAIN, update_interval=DEFAULT_SCAN_INTERVAL
) )
def update_listeners(self) -> None:
"""Call update on all listeners."""
for update_callback in self._listeners:
update_callback()
async def register_webhook(self, event: Event | None = None) -> None: async def register_webhook(self, event: Event | None = None) -> None:
"""Register a webhook with Toon to get live updates.""" """Register a webhook with Toon to get live updates."""
if CONF_WEBHOOK_ID not in self.entry.data: if CONF_WEBHOOK_ID not in self.entry.data:
@ -128,7 +123,7 @@ class ToonDataUpdateCoordinator(DataUpdateCoordinator[Status]):
try: try:
await self.toon.update(data["updateDataSet"]) await self.toon.update(data["updateDataSet"])
self.update_listeners() self.async_update_listeners()
except ToonError as err: except ToonError as err:
_LOGGER.error("Could not process data received from Toon webhook - %s", err) _LOGGER.error("Could not process data received from Toon webhook - %s", err)

View File

@ -16,12 +16,12 @@ def toon_exception_handler(func):
async def handler(self, *args, **kwargs): async def handler(self, *args, **kwargs):
try: try:
await func(self, *args, **kwargs) await func(self, *args, **kwargs)
self.coordinator.update_listeners() self.coordinator.async_update_listeners()
except ToonConnectionError as error: except ToonConnectionError as error:
_LOGGER.error("Error communicating with API: %s", error) _LOGGER.error("Error communicating with API: %s", error)
self.coordinator.last_update_success = False self.coordinator.last_update_success = False
self.coordinator.update_listeners() self.coordinator.async_update_listeners()
except ToonError as error: except ToonError as error:
_LOGGER.error("Invalid response from API: %s", error) _LOGGER.error("Invalid response from API: %s", error)

View File

@ -123,12 +123,6 @@ class DeviceCoordinator(DataUpdateCoordinator):
except ActionException as err: except ActionException as err:
raise UpdateFailed("WeMo update failed") from err raise UpdateFailed("WeMo update failed") from err
@callback
def async_update_listeners(self) -> None:
"""Update all listeners."""
for update_callback in self._listeners:
update_callback()
def _device_info(wemo: WeMoDevice) -> DeviceInfo: def _device_info(wemo: WeMoDevice) -> DeviceInfo:
return DeviceInfo( return DeviceInfo(

View File

@ -54,11 +54,6 @@ class WLEDDataUpdateCoordinator(DataUpdateCoordinator[WLEDDevice]):
self.data is not None and len(self.data.state.segments) > 1 self.data is not None and len(self.data.state.segments) > 1
) )
def update_listeners(self) -> None:
"""Call update on all listeners."""
for update_callback in self._listeners:
update_callback()
@callback @callback
def _use_websocket(self) -> None: def _use_websocket(self) -> None:
"""Use WebSocket for updates, instead of polling.""" """Use WebSocket for updates, instead of polling."""
@ -81,7 +76,7 @@ class WLEDDataUpdateCoordinator(DataUpdateCoordinator[WLEDDevice]):
self.logger.info(err) self.logger.info(err)
except WLEDError as err: except WLEDError as err:
self.last_update_success = False self.last_update_success = False
self.update_listeners() self.async_update_listeners()
self.logger.error(err) self.logger.error(err)
# Ensure we are disconnected # Ensure we are disconnected

View File

@ -15,11 +15,11 @@ def wled_exception_handler(func):
async def handler(self, *args, **kwargs): async def handler(self, *args, **kwargs):
try: try:
await func(self, *args, **kwargs) await func(self, *args, **kwargs)
self.coordinator.update_listeners() self.coordinator.async_update_listeners()
except WLEDConnectionError as error: except WLEDConnectionError as error:
self.coordinator.last_update_success = False self.coordinator.last_update_success = False
self.coordinator.update_listeners() self.coordinator.async_update_listeners()
raise HomeAssistantError("Error communicating with WLED API") from error raise HomeAssistantError("Error communicating with WLED API") from error
except WLEDError as error: except WLEDError as error:

View File

@ -106,7 +106,9 @@ class MusicCastMediaPlayer(MusicCastDeviceEntity, MediaPlayerEntity):
self.coordinator.musiccast.register_group_update_callback( self.coordinator.musiccast.register_group_update_callback(
self.update_all_mc_entities self.update_all_mc_entities
) )
self.coordinator.async_add_listener(self.async_schedule_check_client_list) self.async_on_remove(
self.coordinator.async_add_listener(self.async_schedule_check_client_list)
)
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Entity being removed from hass.""" """Entity being removed from hass."""
@ -116,7 +118,6 @@ class MusicCastMediaPlayer(MusicCastDeviceEntity, MediaPlayerEntity):
self.coordinator.musiccast.remove_group_update_callback( self.coordinator.musiccast.remove_group_update_callback(
self.update_all_mc_entities self.update_all_mc_entities
) )
self.coordinator.async_remove_listener(self.async_schedule_check_client_list)
@property @property
def should_poll(self): def should_poll(self):

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable, Generator
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging import logging
from time import monotonic from time import monotonic
@ -13,7 +13,7 @@ import aiohttp
import requests import requests
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.core import CALLBACK_TYPE, Event, HassJob, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
@ -61,7 +61,7 @@ class DataUpdateCoordinator(Generic[_T]):
# when it was already checked during setup. # when it was already checked during setup.
self.data: _T = None # type: ignore[assignment] self.data: _T = None # type: ignore[assignment]
self._listeners: list[CALLBACK_TYPE] = [] self._listeners: dict[CALLBACK_TYPE, tuple[CALLBACK_TYPE, object | None]] = {}
self._job = HassJob(self._handle_refresh_interval) self._job = HassJob(self._handle_refresh_interval)
self._unsub_refresh: CALLBACK_TYPE | None = None self._unsub_refresh: CALLBACK_TYPE | None = None
self._request_refresh_task: asyncio.TimerHandle | None = None self._request_refresh_task: asyncio.TimerHandle | None = None
@ -82,32 +82,46 @@ class DataUpdateCoordinator(Generic[_T]):
self._debounced_refresh = request_refresh_debouncer self._debounced_refresh = request_refresh_debouncer
@callback @callback
def async_add_listener(self, update_callback: CALLBACK_TYPE) -> Callable[[], None]: def async_add_listener(
self, update_callback: CALLBACK_TYPE, context: Any = None
) -> Callable[[], None]:
"""Listen for data updates.""" """Listen for data updates."""
schedule_refresh = not self._listeners schedule_refresh = not self._listeners
self._listeners.append(update_callback) @callback
def remove_listener() -> None:
"""Remove update listener."""
self._listeners.pop(remove_listener)
if not self._listeners:
self._unschedule_refresh()
self._listeners[remove_listener] = (update_callback, context)
# This is the first listener, set up interval. # This is the first listener, set up interval.
if schedule_refresh: if schedule_refresh:
self._schedule_refresh() self._schedule_refresh()
@callback
def remove_listener() -> None:
"""Remove update listener."""
self.async_remove_listener(update_callback)
return remove_listener return remove_listener
@callback @callback
def async_remove_listener(self, update_callback: CALLBACK_TYPE) -> None: def async_update_listeners(self) -> None:
"""Remove data update.""" """Update all registered listeners."""
self._listeners.remove(update_callback) for update_callback, _ in list(self._listeners.values()):
update_callback()
if not self._listeners and self._unsub_refresh: @callback
def _unschedule_refresh(self) -> None:
"""Unschedule any pending refresh since there is no longer any listeners."""
if self._unsub_refresh:
self._unsub_refresh() self._unsub_refresh()
self._unsub_refresh = None self._unsub_refresh = None
def async_contexts(self) -> Generator[Any, None, None]:
"""Return all registered contexts."""
yield from (
context for _, context in self._listeners.values() if context is not None
)
@callback @callback
def _schedule_refresh(self) -> None: def _schedule_refresh(self) -> None:
"""Schedule a refresh.""" """Schedule a refresh."""
@ -266,8 +280,7 @@ class DataUpdateCoordinator(Generic[_T]):
if not auth_failed and self._listeners and not self.hass.is_stopping: if not auth_failed and self._listeners and not self.hass.is_stopping:
self._schedule_refresh() self._schedule_refresh()
for update_callback in self._listeners: self.async_update_listeners()
update_callback()
@callback @callback
def async_set_updated_data(self, data: _T) -> None: def async_set_updated_data(self, data: _T) -> None:
@ -288,24 +301,18 @@ class DataUpdateCoordinator(Generic[_T]):
if self._listeners: if self._listeners:
self._schedule_refresh() self._schedule_refresh()
for update_callback in self._listeners: self.async_update_listeners()
update_callback()
@callback
def _async_stop_refresh(self, _: Event) -> None:
"""Stop refreshing when Home Assistant is stopping."""
self.update_interval = None
if self._unsub_refresh:
self._unsub_refresh()
self._unsub_refresh = None
class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]): class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]):
"""A class for entities using DataUpdateCoordinator.""" """A class for entities using DataUpdateCoordinator."""
def __init__(self, coordinator: _DataUpdateCoordinatorT) -> None: def __init__(
self, coordinator: _DataUpdateCoordinatorT, context: Any = None
) -> None:
"""Create the entity with a DataUpdateCoordinator.""" """Create the entity with a DataUpdateCoordinator."""
self.coordinator = coordinator self.coordinator = coordinator
self.coordinator_context = context
@property @property
def should_poll(self) -> bool: def should_poll(self) -> bool:
@ -321,7 +328,9 @@ class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]):
"""When entity is added to hass.""" """When entity is added to hass."""
await super().async_added_to_hass() await super().async_added_to_hass()
self.async_on_remove( self.async_on_remove(
self.coordinator.async_add_listener(self._handle_coordinator_update) self.coordinator.async_add_listener(
self._handle_coordinator_update, self.coordinator_context
)
) )
@callback @callback

View File

@ -109,11 +109,29 @@ async def test_async_refresh(crd):
await crd.async_refresh() await crd.async_refresh()
assert updates == [2] assert updates == [2]
# Test unsubscribing through method
crd.async_add_listener(update_callback) async def test_update_context(crd: update_coordinator.DataUpdateCoordinator[int]):
crd.async_remove_listener(update_callback) """Test update contexts for the update coordinator."""
await crd.async_refresh() await crd.async_refresh()
assert updates == [2] assert not set(crd.async_contexts())
def update_callback1():
pass
def update_callback2():
pass
unsub1 = crd.async_add_listener(update_callback1, 1)
assert set(crd.async_contexts()) == {1}
unsub2 = crd.async_add_listener(update_callback2, 2)
assert set(crd.async_contexts()) == {1, 2}
unsub1()
assert set(crd.async_contexts()) == {2}
unsub2()
assert not set(crd.async_contexts())
async def test_request_refresh(crd): async def test_request_refresh(crd):
@ -191,7 +209,7 @@ async def test_update_interval(hass, crd):
# Add subscriber # Add subscriber
update_callback = Mock() update_callback = Mock()
crd.async_add_listener(update_callback) unsub = crd.async_add_listener(update_callback)
# Test twice we update with subscriber # Test twice we update with subscriber
async_fire_time_changed(hass, utcnow() + crd.update_interval) async_fire_time_changed(hass, utcnow() + crd.update_interval)
@ -203,7 +221,7 @@ async def test_update_interval(hass, crd):
assert crd.data == 2 assert crd.data == 2
# Test removing listener # Test removing listener
crd.async_remove_listener(update_callback) unsub()
async_fire_time_changed(hass, utcnow() + crd.update_interval) async_fire_time_changed(hass, utcnow() + crd.update_interval)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -222,7 +240,7 @@ async def test_update_interval_not_present(hass, crd_without_update_interval):
# Add subscriber # Add subscriber
update_callback = Mock() update_callback = Mock()
crd.async_add_listener(update_callback) unsub = crd.async_add_listener(update_callback)
# Test twice we don't update with subscriber with no update interval # Test twice we don't update with subscriber with no update interval
async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL) async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL)
@ -234,7 +252,7 @@ async def test_update_interval_not_present(hass, crd_without_update_interval):
assert crd.data is None assert crd.data is None
# Test removing listener # Test removing listener
crd.async_remove_listener(update_callback) unsub()
async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL) async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -253,9 +271,10 @@ async def test_refresh_recover(crd, caplog):
assert "Fetching test data recovered" in caplog.text assert "Fetching test data recovered" in caplog.text
async def test_coordinator_entity(crd): async def test_coordinator_entity(crd: update_coordinator.DataUpdateCoordinator[int]):
"""Test the CoordinatorEntity class.""" """Test the CoordinatorEntity class."""
entity = update_coordinator.CoordinatorEntity(crd) context = object()
entity = update_coordinator.CoordinatorEntity(crd, context)
assert entity.should_poll is False assert entity.should_poll is False
@ -278,6 +297,8 @@ async def test_coordinator_entity(crd):
await entity.async_update() await entity.async_update()
assert entity.available is False assert entity.available is False
assert list(crd.async_contexts()) == [context]
async def test_async_set_updated_data(crd): async def test_async_set_updated_data(crd):
"""Test async_set_updated_data for update coordinator.""" """Test async_set_updated_data for update coordinator."""