Fix memory leak when unloading DataUpdateCoordinator (#137338)

* check wiz

* Fix memory leak when unloading DataUpdateCoordinator

fixes #137237

* handle namespace conflict

* handle namespace conflict

* address review comments
This commit is contained in:
J. Nick Koston 2025-02-05 02:29:23 -06:00 committed by GitHub
parent 03de3aec15
commit 3fc13db7e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 11 deletions

View File

@ -6,6 +6,7 @@ from abc import abstractmethod
import asyncio import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Generator from collections.abc import Awaitable, Callable, Coroutine, Generator
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial
import logging import logging
from random import randint from random import randint
from time import monotonic from time import monotonic
@ -103,7 +104,8 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
randint(event.RANDOM_MICROSECOND_MIN, event.RANDOM_MICROSECOND_MAX) / 10**6 randint(event.RANDOM_MICROSECOND_MIN, event.RANDOM_MICROSECOND_MAX) / 10**6
) )
self._listeners: dict[CALLBACK_TYPE, tuple[CALLBACK_TYPE, object | None]] = {} self._listeners: dict[int, tuple[CALLBACK_TYPE, object | None]] = {}
self._last_listener_id: int = 0
self._unsub_refresh: CALLBACK_TYPE | None = None self._unsub_refresh: CALLBACK_TYPE | None = None
self._unsub_shutdown: CALLBACK_TYPE | None = None self._unsub_shutdown: CALLBACK_TYPE | None = None
self._request_refresh_task: asyncio.TimerHandle | None = None self._request_refresh_task: asyncio.TimerHandle | None = None
@ -148,21 +150,26 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Listen for data updates.""" """Listen for data updates."""
schedule_refresh = not self._listeners schedule_refresh = not self._listeners
self._last_listener_id += 1
@callback self._listeners[self._last_listener_id] = (update_callback, context)
def remove_listener() -> None:
"""Remove update listener."""
self._listeners.pop(remove_listener)
if not self._listeners:
self._unschedule_refresh()
self._listeners[remove_listener] = (update_callback, context)
# This is the first listener, set up interval. # This is the first listener, set up interval.
if schedule_refresh: if schedule_refresh:
self._schedule_refresh() self._schedule_refresh()
return remove_listener return partial(self.__async_remove_listener_internal, self._last_listener_id)
@callback
def __async_remove_listener_internal(self, listener_id: int) -> None:
"""Remove a listener.
This is an internal function that is not to be overridden
in subclasses as it may change in the future.
"""
self._listeners.pop(listener_id)
if not self._listeners:
self._unschedule_refresh()
self._debounced_refresh.async_cancel()
@callback @callback
def async_update_listeners(self) -> None: def async_update_listeners(self) -> None:

View File

@ -2,6 +2,7 @@
from datetime import timedelta from datetime import timedelta
from unittest.mock import MagicMock from unittest.mock import MagicMock
import weakref
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
from homewizard_energy.errors import DisabledError, UnauthorizedError from homewizard_energy.errors import DisabledError, UnauthorizedError
@ -25,6 +26,9 @@ async def test_load_unload_v1(
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
weak_ref = weakref.ref(mock_config_entry.runtime_data)
assert weak_ref() is not None
assert mock_config_entry.state is ConfigEntryState.LOADED assert mock_config_entry.state is ConfigEntryState.LOADED
assert len(mock_homewizardenergy.combined.mock_calls) == 1 assert len(mock_homewizardenergy.combined.mock_calls) == 1
@ -32,6 +36,7 @@ async def test_load_unload_v1(
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_config_entry.state is ConfigEntryState.NOT_LOADED assert mock_config_entry.state is ConfigEntryState.NOT_LOADED
assert weak_ref() is None
async def test_load_unload_v2( async def test_load_unload_v2(