mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Keep track of a context for each listener (#72702)
* Remove async_remove_listener This avoids the ambuigity as to what happens if same callback is added multiple times. * Keep track of a context for each listener This allow a update coordinator to adapt what data to request on update from the backing service based on which entities are enabled. * Clone list before calling callbacks The callbacks can end up unregistering and modifying the dict while iterating. * Only yield actual values * Add a test for update context * Factor out iteration of _listeners to helper * Verify context is passed to coordinator * Switch to Any as type instead of object * Remove function which use was dropped earliers The use was removed in 8bee25c938a123f0da7569b4e2753598d478b900
This commit is contained in:
parent
a28fa5377a
commit
8910d265d6
@ -131,4 +131,4 @@ class BMWButton(BMWBaseEntity, ButtonEntity):
|
|||||||
# Always update HA states after a button was executed.
|
# Always update HA states after a button was executed.
|
||||||
# BMW remote services that change the vehicle's state update the local object
|
# BMW remote services that change the vehicle's state update the local object
|
||||||
# when executing the service, so only the HA state machine needs further updates.
|
# when executing the service, so only the HA state machine needs further updates.
|
||||||
self.coordinator.notify_listeners()
|
self.coordinator.async_update_listeners()
|
||||||
|
@ -74,8 +74,3 @@ class BMWDataUpdateCoordinator(DataUpdateCoordinator):
|
|||||||
if not refresh_token:
|
if not refresh_token:
|
||||||
data.pop(CONF_REFRESH_TOKEN)
|
data.pop(CONF_REFRESH_TOKEN)
|
||||||
self.hass.config_entries.async_update_entry(self._entry, data=data)
|
self.hass.config_entries.async_update_entry(self._entry, data=data)
|
||||||
|
|
||||||
def notify_listeners(self) -> None:
|
|
||||||
"""Notify all listeners to refresh HA state machine."""
|
|
||||||
for update_callback in self._listeners:
|
|
||||||
update_callback()
|
|
||||||
|
@ -74,12 +74,12 @@ def modernforms_exception_handler(func):
|
|||||||
async def handler(self, *args, **kwargs):
|
async def handler(self, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
await func(self, *args, **kwargs)
|
await func(self, *args, **kwargs)
|
||||||
self.coordinator.update_listeners()
|
self.coordinator.async_update_listeners()
|
||||||
|
|
||||||
except ModernFormsConnectionError as error:
|
except ModernFormsConnectionError as error:
|
||||||
_LOGGER.error("Error communicating with API: %s", error)
|
_LOGGER.error("Error communicating with API: %s", error)
|
||||||
self.coordinator.last_update_success = False
|
self.coordinator.last_update_success = False
|
||||||
self.coordinator.update_listeners()
|
self.coordinator.async_update_listeners()
|
||||||
|
|
||||||
except ModernFormsError as error:
|
except ModernFormsError as error:
|
||||||
_LOGGER.error("Invalid response from API: %s", error)
|
_LOGGER.error("Invalid response from API: %s", error)
|
||||||
@ -108,11 +108,6 @@ class ModernFormsDataUpdateCoordinator(DataUpdateCoordinator[ModernFormsDeviceSt
|
|||||||
update_interval=SCAN_INTERVAL,
|
update_interval=SCAN_INTERVAL,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_listeners(self) -> None:
|
|
||||||
"""Call update on all listeners."""
|
|
||||||
for update_callback in self._listeners:
|
|
||||||
update_callback()
|
|
||||||
|
|
||||||
async def _async_update_data(self) -> ModernFormsDevice:
|
async def _async_update_data(self) -> ModernFormsDevice:
|
||||||
"""Fetch data from Modern Forms."""
|
"""Fetch data from Modern Forms."""
|
||||||
try:
|
try:
|
||||||
|
@ -83,8 +83,7 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
|
|||||||
async def async_set_cooling(self, enabled: bool) -> None:
|
async def async_set_cooling(self, enabled: bool) -> None:
|
||||||
"""Enable or disable cooling mode."""
|
"""Enable or disable cooling mode."""
|
||||||
await self.base.set_cooling(enabled)
|
await self.base.set_cooling(enabled)
|
||||||
for update_callback in self._listeners:
|
self.async_update_listeners()
|
||||||
update_callback()
|
|
||||||
|
|
||||||
async def async_set_target_temperature(
|
async def async_set_target_temperature(
|
||||||
self, heat_area_id: str, target_temperature: float
|
self, heat_area_id: str, target_temperature: float
|
||||||
@ -117,8 +116,7 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
|
|||||||
"Failed to set target temperature, communication error with alpha2 base"
|
"Failed to set target temperature, communication error with alpha2 base"
|
||||||
) from http_err
|
) from http_err
|
||||||
self.data["heat_areas"][heat_area_id].update(update_data)
|
self.data["heat_areas"][heat_area_id].update(update_data)
|
||||||
for update_callback in self._listeners:
|
self.async_update_listeners()
|
||||||
update_callback()
|
|
||||||
|
|
||||||
async def async_set_heat_area_mode(
|
async def async_set_heat_area_mode(
|
||||||
self, heat_area_id: str, heat_area_mode: int
|
self, heat_area_id: str, heat_area_mode: int
|
||||||
@ -161,5 +159,5 @@ class Alpha2BaseCoordinator(DataUpdateCoordinator[dict[str, dict]]):
|
|||||||
self.data["heat_areas"][heat_area_id]["T_TARGET"] = self.data[
|
self.data["heat_areas"][heat_area_id]["T_TARGET"] = self.data[
|
||||||
"heat_areas"
|
"heat_areas"
|
||||||
][heat_area_id]["T_HEAT_NIGHT"]
|
][heat_area_id]["T_HEAT_NIGHT"]
|
||||||
for update_callback in self._listeners:
|
|
||||||
update_callback()
|
self.async_update_listeners()
|
||||||
|
@ -19,14 +19,7 @@ from homeassistant.const import (
|
|||||||
CONF_USERNAME,
|
CONF_USERNAME,
|
||||||
Platform,
|
Platform,
|
||||||
)
|
)
|
||||||
from homeassistant.core import (
|
from homeassistant.core import Context, HassJob, HomeAssistant, callback
|
||||||
CALLBACK_TYPE,
|
|
||||||
Context,
|
|
||||||
Event,
|
|
||||||
HassJob,
|
|
||||||
HomeAssistant,
|
|
||||||
callback,
|
|
||||||
)
|
|
||||||
from homeassistant.helpers.debounce import Debouncer
|
from homeassistant.helpers.debounce import Debouncer
|
||||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
|
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
|
||||||
|
|
||||||
@ -121,12 +114,7 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]):
|
|||||||
self.options = options
|
self.options = options
|
||||||
self._notify_future: asyncio.Task | None = None
|
self._notify_future: asyncio.Task | None = None
|
||||||
|
|
||||||
@callback
|
self.turn_on = PluggableAction(self.async_update_listeners)
|
||||||
def _update_listeners():
|
|
||||||
for update_callback in self._listeners:
|
|
||||||
update_callback()
|
|
||||||
|
|
||||||
self.turn_on = PluggableAction(_update_listeners)
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hass,
|
hass,
|
||||||
@ -193,15 +181,9 @@ class PhilipsTVDataUpdateCoordinator(DataUpdateCoordinator[None]):
|
|||||||
self._notify_future = asyncio.create_task(self._notify_task())
|
self._notify_future = asyncio.create_task(self._notify_task())
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove_listener(self, update_callback: CALLBACK_TYPE) -> None:
|
def _unschedule_refresh(self) -> None:
|
||||||
"""Remove data update."""
|
"""Remove data update."""
|
||||||
super().async_remove_listener(update_callback)
|
super()._unschedule_refresh()
|
||||||
if not self._listeners:
|
|
||||||
self._async_notify_stop()
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def _async_stop_refresh(self, event: Event) -> None:
|
|
||||||
super()._async_stop_refresh(event)
|
|
||||||
self._async_notify_stop()
|
self._async_notify_stop()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -75,11 +75,6 @@ class SystemBridgeDataUpdateCoordinator(
|
|||||||
hass, LOGGER, name=DOMAIN, update_interval=timedelta(seconds=30)
|
hass, LOGGER, name=DOMAIN, update_interval=timedelta(seconds=30)
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_listeners(self) -> None:
|
|
||||||
"""Call update on all listeners."""
|
|
||||||
for update_callback in self._listeners:
|
|
||||||
update_callback()
|
|
||||||
|
|
||||||
async def async_get_data(
|
async def async_get_data(
|
||||||
self,
|
self,
|
||||||
modules: list[str],
|
modules: list[str],
|
||||||
@ -113,7 +108,7 @@ class SystemBridgeDataUpdateCoordinator(
|
|||||||
self.unsub()
|
self.unsub()
|
||||||
self.unsub = None
|
self.unsub = None
|
||||||
self.last_update_success = False
|
self.last_update_success = False
|
||||||
self.update_listeners()
|
self.async_update_listeners()
|
||||||
except (ConnectionClosedException, ConnectionResetError) as exception:
|
except (ConnectionClosedException, ConnectionResetError) as exception:
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"Websocket connection closed for %s. Will retry: %s",
|
"Websocket connection closed for %s. Will retry: %s",
|
||||||
@ -124,7 +119,7 @@ class SystemBridgeDataUpdateCoordinator(
|
|||||||
self.unsub()
|
self.unsub()
|
||||||
self.unsub = None
|
self.unsub = None
|
||||||
self.last_update_success = False
|
self.last_update_success = False
|
||||||
self.update_listeners()
|
self.async_update_listeners()
|
||||||
except ConnectionErrorException as exception:
|
except ConnectionErrorException as exception:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"Connection error occurred for %s. Will retry: %s",
|
"Connection error occurred for %s. Will retry: %s",
|
||||||
@ -135,7 +130,7 @@ class SystemBridgeDataUpdateCoordinator(
|
|||||||
self.unsub()
|
self.unsub()
|
||||||
self.unsub = None
|
self.unsub = None
|
||||||
self.last_update_success = False
|
self.last_update_success = False
|
||||||
self.update_listeners()
|
self.async_update_listeners()
|
||||||
|
|
||||||
async def _setup_websocket(self) -> None:
|
async def _setup_websocket(self) -> None:
|
||||||
"""Use WebSocket for updates."""
|
"""Use WebSocket for updates."""
|
||||||
@ -151,7 +146,7 @@ class SystemBridgeDataUpdateCoordinator(
|
|||||||
self.unsub()
|
self.unsub()
|
||||||
self.unsub = None
|
self.unsub = None
|
||||||
self.last_update_success = False
|
self.last_update_success = False
|
||||||
self.update_listeners()
|
self.async_update_listeners()
|
||||||
except ConnectionErrorException as exception:
|
except ConnectionErrorException as exception:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"Connection error occurred for %s. Will retry: %s",
|
"Connection error occurred for %s. Will retry: %s",
|
||||||
@ -159,7 +154,7 @@ class SystemBridgeDataUpdateCoordinator(
|
|||||||
exception,
|
exception,
|
||||||
)
|
)
|
||||||
self.last_update_success = False
|
self.last_update_success = False
|
||||||
self.update_listeners()
|
self.async_update_listeners()
|
||||||
except asyncio.TimeoutError as exception:
|
except asyncio.TimeoutError as exception:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"Timed out waiting for %s. Will retry: %s",
|
"Timed out waiting for %s. Will retry: %s",
|
||||||
@ -167,11 +162,11 @@ class SystemBridgeDataUpdateCoordinator(
|
|||||||
exception,
|
exception,
|
||||||
)
|
)
|
||||||
self.last_update_success = False
|
self.last_update_success = False
|
||||||
self.update_listeners()
|
self.async_update_listeners()
|
||||||
|
|
||||||
self.hass.async_create_task(self._listen_for_data())
|
self.hass.async_create_task(self._listen_for_data())
|
||||||
self.last_update_success = True
|
self.last_update_success = True
|
||||||
self.update_listeners()
|
self.async_update_listeners()
|
||||||
|
|
||||||
async def close_websocket(_) -> None:
|
async def close_websocket(_) -> None:
|
||||||
"""Close WebSocket connection."""
|
"""Close WebSocket connection."""
|
||||||
|
@ -47,11 +47,6 @@ class ToonDataUpdateCoordinator(DataUpdateCoordinator[Status]):
|
|||||||
hass, _LOGGER, name=DOMAIN, update_interval=DEFAULT_SCAN_INTERVAL
|
hass, _LOGGER, name=DOMAIN, update_interval=DEFAULT_SCAN_INTERVAL
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_listeners(self) -> None:
|
|
||||||
"""Call update on all listeners."""
|
|
||||||
for update_callback in self._listeners:
|
|
||||||
update_callback()
|
|
||||||
|
|
||||||
async def register_webhook(self, event: Event | None = None) -> None:
|
async def register_webhook(self, event: Event | None = None) -> None:
|
||||||
"""Register a webhook with Toon to get live updates."""
|
"""Register a webhook with Toon to get live updates."""
|
||||||
if CONF_WEBHOOK_ID not in self.entry.data:
|
if CONF_WEBHOOK_ID not in self.entry.data:
|
||||||
@ -128,7 +123,7 @@ class ToonDataUpdateCoordinator(DataUpdateCoordinator[Status]):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self.toon.update(data["updateDataSet"])
|
await self.toon.update(data["updateDataSet"])
|
||||||
self.update_listeners()
|
self.async_update_listeners()
|
||||||
except ToonError as err:
|
except ToonError as err:
|
||||||
_LOGGER.error("Could not process data received from Toon webhook - %s", err)
|
_LOGGER.error("Could not process data received from Toon webhook - %s", err)
|
||||||
|
|
||||||
|
@ -16,12 +16,12 @@ def toon_exception_handler(func):
|
|||||||
async def handler(self, *args, **kwargs):
|
async def handler(self, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
await func(self, *args, **kwargs)
|
await func(self, *args, **kwargs)
|
||||||
self.coordinator.update_listeners()
|
self.coordinator.async_update_listeners()
|
||||||
|
|
||||||
except ToonConnectionError as error:
|
except ToonConnectionError as error:
|
||||||
_LOGGER.error("Error communicating with API: %s", error)
|
_LOGGER.error("Error communicating with API: %s", error)
|
||||||
self.coordinator.last_update_success = False
|
self.coordinator.last_update_success = False
|
||||||
self.coordinator.update_listeners()
|
self.coordinator.async_update_listeners()
|
||||||
|
|
||||||
except ToonError as error:
|
except ToonError as error:
|
||||||
_LOGGER.error("Invalid response from API: %s", error)
|
_LOGGER.error("Invalid response from API: %s", error)
|
||||||
|
@ -123,12 +123,6 @@ class DeviceCoordinator(DataUpdateCoordinator):
|
|||||||
except ActionException as err:
|
except ActionException as err:
|
||||||
raise UpdateFailed("WeMo update failed") from err
|
raise UpdateFailed("WeMo update failed") from err
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_update_listeners(self) -> None:
|
|
||||||
"""Update all listeners."""
|
|
||||||
for update_callback in self._listeners:
|
|
||||||
update_callback()
|
|
||||||
|
|
||||||
|
|
||||||
def _device_info(wemo: WeMoDevice) -> DeviceInfo:
|
def _device_info(wemo: WeMoDevice) -> DeviceInfo:
|
||||||
return DeviceInfo(
|
return DeviceInfo(
|
||||||
|
@ -54,11 +54,6 @@ class WLEDDataUpdateCoordinator(DataUpdateCoordinator[WLEDDevice]):
|
|||||||
self.data is not None and len(self.data.state.segments) > 1
|
self.data is not None and len(self.data.state.segments) > 1
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_listeners(self) -> None:
|
|
||||||
"""Call update on all listeners."""
|
|
||||||
for update_callback in self._listeners:
|
|
||||||
update_callback()
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _use_websocket(self) -> None:
|
def _use_websocket(self) -> None:
|
||||||
"""Use WebSocket for updates, instead of polling."""
|
"""Use WebSocket for updates, instead of polling."""
|
||||||
@ -81,7 +76,7 @@ class WLEDDataUpdateCoordinator(DataUpdateCoordinator[WLEDDevice]):
|
|||||||
self.logger.info(err)
|
self.logger.info(err)
|
||||||
except WLEDError as err:
|
except WLEDError as err:
|
||||||
self.last_update_success = False
|
self.last_update_success = False
|
||||||
self.update_listeners()
|
self.async_update_listeners()
|
||||||
self.logger.error(err)
|
self.logger.error(err)
|
||||||
|
|
||||||
# Ensure we are disconnected
|
# Ensure we are disconnected
|
||||||
|
@ -15,11 +15,11 @@ def wled_exception_handler(func):
|
|||||||
async def handler(self, *args, **kwargs):
|
async def handler(self, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
await func(self, *args, **kwargs)
|
await func(self, *args, **kwargs)
|
||||||
self.coordinator.update_listeners()
|
self.coordinator.async_update_listeners()
|
||||||
|
|
||||||
except WLEDConnectionError as error:
|
except WLEDConnectionError as error:
|
||||||
self.coordinator.last_update_success = False
|
self.coordinator.last_update_success = False
|
||||||
self.coordinator.update_listeners()
|
self.coordinator.async_update_listeners()
|
||||||
raise HomeAssistantError("Error communicating with WLED API") from error
|
raise HomeAssistantError("Error communicating with WLED API") from error
|
||||||
|
|
||||||
except WLEDError as error:
|
except WLEDError as error:
|
||||||
|
@ -106,7 +106,9 @@ class MusicCastMediaPlayer(MusicCastDeviceEntity, MediaPlayerEntity):
|
|||||||
self.coordinator.musiccast.register_group_update_callback(
|
self.coordinator.musiccast.register_group_update_callback(
|
||||||
self.update_all_mc_entities
|
self.update_all_mc_entities
|
||||||
)
|
)
|
||||||
self.coordinator.async_add_listener(self.async_schedule_check_client_list)
|
self.async_on_remove(
|
||||||
|
self.coordinator.async_add_listener(self.async_schedule_check_client_list)
|
||||||
|
)
|
||||||
|
|
||||||
async def async_will_remove_from_hass(self):
|
async def async_will_remove_from_hass(self):
|
||||||
"""Entity being removed from hass."""
|
"""Entity being removed from hass."""
|
||||||
@ -116,7 +118,6 @@ class MusicCastMediaPlayer(MusicCastDeviceEntity, MediaPlayerEntity):
|
|||||||
self.coordinator.musiccast.remove_group_update_callback(
|
self.coordinator.musiccast.remove_group_update_callback(
|
||||||
self.update_all_mc_entities
|
self.update_all_mc_entities
|
||||||
)
|
)
|
||||||
self.coordinator.async_remove_listener(self.async_schedule_check_client_list)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def should_poll(self):
|
def should_poll(self):
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable, Generator
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import logging
|
import logging
|
||||||
from time import monotonic
|
from time import monotonic
|
||||||
@ -13,7 +13,7 @@ import aiohttp
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.core import CALLBACK_TYPE, Event, HassJob, HomeAssistant, callback
|
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
@ -61,7 +61,7 @@ class DataUpdateCoordinator(Generic[_T]):
|
|||||||
# when it was already checked during setup.
|
# when it was already checked during setup.
|
||||||
self.data: _T = None # type: ignore[assignment]
|
self.data: _T = None # type: ignore[assignment]
|
||||||
|
|
||||||
self._listeners: list[CALLBACK_TYPE] = []
|
self._listeners: dict[CALLBACK_TYPE, tuple[CALLBACK_TYPE, object | None]] = {}
|
||||||
self._job = HassJob(self._handle_refresh_interval)
|
self._job = HassJob(self._handle_refresh_interval)
|
||||||
self._unsub_refresh: CALLBACK_TYPE | None = None
|
self._unsub_refresh: CALLBACK_TYPE | None = None
|
||||||
self._request_refresh_task: asyncio.TimerHandle | None = None
|
self._request_refresh_task: asyncio.TimerHandle | None = None
|
||||||
@ -82,32 +82,46 @@ class DataUpdateCoordinator(Generic[_T]):
|
|||||||
self._debounced_refresh = request_refresh_debouncer
|
self._debounced_refresh = request_refresh_debouncer
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_add_listener(self, update_callback: CALLBACK_TYPE) -> Callable[[], None]:
|
def async_add_listener(
|
||||||
|
self, update_callback: CALLBACK_TYPE, context: Any = None
|
||||||
|
) -> Callable[[], None]:
|
||||||
"""Listen for data updates."""
|
"""Listen for data updates."""
|
||||||
schedule_refresh = not self._listeners
|
schedule_refresh = not self._listeners
|
||||||
|
|
||||||
self._listeners.append(update_callback)
|
@callback
|
||||||
|
def remove_listener() -> None:
|
||||||
|
"""Remove update listener."""
|
||||||
|
self._listeners.pop(remove_listener)
|
||||||
|
if not self._listeners:
|
||||||
|
self._unschedule_refresh()
|
||||||
|
|
||||||
|
self._listeners[remove_listener] = (update_callback, context)
|
||||||
|
|
||||||
# This is the first listener, set up interval.
|
# This is the first listener, set up interval.
|
||||||
if schedule_refresh:
|
if schedule_refresh:
|
||||||
self._schedule_refresh()
|
self._schedule_refresh()
|
||||||
|
|
||||||
@callback
|
|
||||||
def remove_listener() -> None:
|
|
||||||
"""Remove update listener."""
|
|
||||||
self.async_remove_listener(update_callback)
|
|
||||||
|
|
||||||
return remove_listener
|
return remove_listener
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove_listener(self, update_callback: CALLBACK_TYPE) -> None:
|
def async_update_listeners(self) -> None:
|
||||||
"""Remove data update."""
|
"""Update all registered listeners."""
|
||||||
self._listeners.remove(update_callback)
|
for update_callback, _ in list(self._listeners.values()):
|
||||||
|
update_callback()
|
||||||
|
|
||||||
if not self._listeners and self._unsub_refresh:
|
@callback
|
||||||
|
def _unschedule_refresh(self) -> None:
|
||||||
|
"""Unschedule any pending refresh since there is no longer any listeners."""
|
||||||
|
if self._unsub_refresh:
|
||||||
self._unsub_refresh()
|
self._unsub_refresh()
|
||||||
self._unsub_refresh = None
|
self._unsub_refresh = None
|
||||||
|
|
||||||
|
def async_contexts(self) -> Generator[Any, None, None]:
|
||||||
|
"""Return all registered contexts."""
|
||||||
|
yield from (
|
||||||
|
context for _, context in self._listeners.values() if context is not None
|
||||||
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _schedule_refresh(self) -> None:
|
def _schedule_refresh(self) -> None:
|
||||||
"""Schedule a refresh."""
|
"""Schedule a refresh."""
|
||||||
@ -266,8 +280,7 @@ class DataUpdateCoordinator(Generic[_T]):
|
|||||||
if not auth_failed and self._listeners and not self.hass.is_stopping:
|
if not auth_failed and self._listeners and not self.hass.is_stopping:
|
||||||
self._schedule_refresh()
|
self._schedule_refresh()
|
||||||
|
|
||||||
for update_callback in self._listeners:
|
self.async_update_listeners()
|
||||||
update_callback()
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_set_updated_data(self, data: _T) -> None:
|
def async_set_updated_data(self, data: _T) -> None:
|
||||||
@ -288,24 +301,18 @@ class DataUpdateCoordinator(Generic[_T]):
|
|||||||
if self._listeners:
|
if self._listeners:
|
||||||
self._schedule_refresh()
|
self._schedule_refresh()
|
||||||
|
|
||||||
for update_callback in self._listeners:
|
self.async_update_listeners()
|
||||||
update_callback()
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def _async_stop_refresh(self, _: Event) -> None:
|
|
||||||
"""Stop refreshing when Home Assistant is stopping."""
|
|
||||||
self.update_interval = None
|
|
||||||
if self._unsub_refresh:
|
|
||||||
self._unsub_refresh()
|
|
||||||
self._unsub_refresh = None
|
|
||||||
|
|
||||||
|
|
||||||
class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]):
|
class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]):
|
||||||
"""A class for entities using DataUpdateCoordinator."""
|
"""A class for entities using DataUpdateCoordinator."""
|
||||||
|
|
||||||
def __init__(self, coordinator: _DataUpdateCoordinatorT) -> None:
|
def __init__(
|
||||||
|
self, coordinator: _DataUpdateCoordinatorT, context: Any = None
|
||||||
|
) -> None:
|
||||||
"""Create the entity with a DataUpdateCoordinator."""
|
"""Create the entity with a DataUpdateCoordinator."""
|
||||||
self.coordinator = coordinator
|
self.coordinator = coordinator
|
||||||
|
self.coordinator_context = context
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def should_poll(self) -> bool:
|
def should_poll(self) -> bool:
|
||||||
@ -321,7 +328,9 @@ class CoordinatorEntity(entity.Entity, Generic[_DataUpdateCoordinatorT]):
|
|||||||
"""When entity is added to hass."""
|
"""When entity is added to hass."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
self.async_on_remove(
|
self.async_on_remove(
|
||||||
self.coordinator.async_add_listener(self._handle_coordinator_update)
|
self.coordinator.async_add_listener(
|
||||||
|
self._handle_coordinator_update, self.coordinator_context
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -109,11 +109,29 @@ async def test_async_refresh(crd):
|
|||||||
await crd.async_refresh()
|
await crd.async_refresh()
|
||||||
assert updates == [2]
|
assert updates == [2]
|
||||||
|
|
||||||
# Test unsubscribing through method
|
|
||||||
crd.async_add_listener(update_callback)
|
async def test_update_context(crd: update_coordinator.DataUpdateCoordinator[int]):
|
||||||
crd.async_remove_listener(update_callback)
|
"""Test update contexts for the update coordinator."""
|
||||||
await crd.async_refresh()
|
await crd.async_refresh()
|
||||||
assert updates == [2]
|
assert not set(crd.async_contexts())
|
||||||
|
|
||||||
|
def update_callback1():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def update_callback2():
|
||||||
|
pass
|
||||||
|
|
||||||
|
unsub1 = crd.async_add_listener(update_callback1, 1)
|
||||||
|
assert set(crd.async_contexts()) == {1}
|
||||||
|
|
||||||
|
unsub2 = crd.async_add_listener(update_callback2, 2)
|
||||||
|
assert set(crd.async_contexts()) == {1, 2}
|
||||||
|
|
||||||
|
unsub1()
|
||||||
|
assert set(crd.async_contexts()) == {2}
|
||||||
|
|
||||||
|
unsub2()
|
||||||
|
assert not set(crd.async_contexts())
|
||||||
|
|
||||||
|
|
||||||
async def test_request_refresh(crd):
|
async def test_request_refresh(crd):
|
||||||
@ -191,7 +209,7 @@ async def test_update_interval(hass, crd):
|
|||||||
|
|
||||||
# Add subscriber
|
# Add subscriber
|
||||||
update_callback = Mock()
|
update_callback = Mock()
|
||||||
crd.async_add_listener(update_callback)
|
unsub = crd.async_add_listener(update_callback)
|
||||||
|
|
||||||
# Test twice we update with subscriber
|
# Test twice we update with subscriber
|
||||||
async_fire_time_changed(hass, utcnow() + crd.update_interval)
|
async_fire_time_changed(hass, utcnow() + crd.update_interval)
|
||||||
@ -203,7 +221,7 @@ async def test_update_interval(hass, crd):
|
|||||||
assert crd.data == 2
|
assert crd.data == 2
|
||||||
|
|
||||||
# Test removing listener
|
# Test removing listener
|
||||||
crd.async_remove_listener(update_callback)
|
unsub()
|
||||||
|
|
||||||
async_fire_time_changed(hass, utcnow() + crd.update_interval)
|
async_fire_time_changed(hass, utcnow() + crd.update_interval)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
@ -222,7 +240,7 @@ async def test_update_interval_not_present(hass, crd_without_update_interval):
|
|||||||
|
|
||||||
# Add subscriber
|
# Add subscriber
|
||||||
update_callback = Mock()
|
update_callback = Mock()
|
||||||
crd.async_add_listener(update_callback)
|
unsub = crd.async_add_listener(update_callback)
|
||||||
|
|
||||||
# Test twice we don't update with subscriber with no update interval
|
# Test twice we don't update with subscriber with no update interval
|
||||||
async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL)
|
async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL)
|
||||||
@ -234,7 +252,7 @@ async def test_update_interval_not_present(hass, crd_without_update_interval):
|
|||||||
assert crd.data is None
|
assert crd.data is None
|
||||||
|
|
||||||
# Test removing listener
|
# Test removing listener
|
||||||
crd.async_remove_listener(update_callback)
|
unsub()
|
||||||
|
|
||||||
async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL)
|
async_fire_time_changed(hass, utcnow() + DEFAULT_UPDATE_INTERVAL)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
@ -253,9 +271,10 @@ async def test_refresh_recover(crd, caplog):
|
|||||||
assert "Fetching test data recovered" in caplog.text
|
assert "Fetching test data recovered" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
async def test_coordinator_entity(crd):
|
async def test_coordinator_entity(crd: update_coordinator.DataUpdateCoordinator[int]):
|
||||||
"""Test the CoordinatorEntity class."""
|
"""Test the CoordinatorEntity class."""
|
||||||
entity = update_coordinator.CoordinatorEntity(crd)
|
context = object()
|
||||||
|
entity = update_coordinator.CoordinatorEntity(crd, context)
|
||||||
|
|
||||||
assert entity.should_poll is False
|
assert entity.should_poll is False
|
||||||
|
|
||||||
@ -278,6 +297,8 @@ async def test_coordinator_entity(crd):
|
|||||||
await entity.async_update()
|
await entity.async_update()
|
||||||
assert entity.available is False
|
assert entity.available is False
|
||||||
|
|
||||||
|
assert list(crd.async_contexts()) == [context]
|
||||||
|
|
||||||
|
|
||||||
async def test_async_set_updated_data(crd):
|
async def test_async_set_updated_data(crd):
|
||||||
"""Test async_set_updated_data for update coordinator."""
|
"""Test async_set_updated_data for update coordinator."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user