mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Add ability to get callback when a config entry state changes (#138943)
* Add entry_on_state_change_helper * undo black * remove unload * no coro * Add tests * Don't accept coro * Review feedback * Add error test * Make it callback type * Make it callback type * Removal test * change type
This commit is contained in:
parent
b35d252549
commit
e59ec8f867
@ -402,6 +402,7 @@ class ConfigEntry[_DataT = Any]:
|
|||||||
update_listeners: list[UpdateListenerType]
|
update_listeners: list[UpdateListenerType]
|
||||||
_async_cancel_retry_setup: Callable[[], Any] | None
|
_async_cancel_retry_setup: Callable[[], Any] | None
|
||||||
_on_unload: list[Callable[[], Coroutine[Any, Any, None] | None]] | None
|
_on_unload: list[Callable[[], Coroutine[Any, Any, None] | None]] | None
|
||||||
|
_on_state_change: list[CALLBACK_TYPE] | None
|
||||||
setup_lock: asyncio.Lock
|
setup_lock: asyncio.Lock
|
||||||
_reauth_lock: asyncio.Lock
|
_reauth_lock: asyncio.Lock
|
||||||
_tasks: set[asyncio.Future[Any]]
|
_tasks: set[asyncio.Future[Any]]
|
||||||
@ -526,6 +527,9 @@ class ConfigEntry[_DataT = Any]:
|
|||||||
# Hold list for actions to call on unload.
|
# Hold list for actions to call on unload.
|
||||||
_setter(self, "_on_unload", None)
|
_setter(self, "_on_unload", None)
|
||||||
|
|
||||||
|
# Hold list for actions to call on state change.
|
||||||
|
_setter(self, "_on_state_change", None)
|
||||||
|
|
||||||
# Reload lock to prevent conflicting reloads
|
# Reload lock to prevent conflicting reloads
|
||||||
_setter(self, "setup_lock", asyncio.Lock())
|
_setter(self, "setup_lock", asyncio.Lock())
|
||||||
# Reauth lock to prevent concurrent reauth flows
|
# Reauth lock to prevent concurrent reauth flows
|
||||||
@ -1058,6 +1062,8 @@ class ConfigEntry[_DataT = Any]:
|
|||||||
hass, SIGNAL_CONFIG_ENTRY_CHANGED, ConfigEntryChange.UPDATED, self
|
hass, SIGNAL_CONFIG_ENTRY_CHANGED, ConfigEntryChange.UPDATED, self
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._async_process_on_state_change()
|
||||||
|
|
||||||
async def async_migrate(self, hass: HomeAssistant) -> bool:
|
async def async_migrate(self, hass: HomeAssistant) -> bool:
|
||||||
"""Migrate an entry.
|
"""Migrate an entry.
|
||||||
|
|
||||||
@ -1172,6 +1178,28 @@ class ConfigEntry[_DataT = Any]:
|
|||||||
task,
|
task,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_on_state_change(self, func: CALLBACK_TYPE) -> CALLBACK_TYPE:
|
||||||
|
"""Add a function to call when a config entry changes its state."""
|
||||||
|
if self._on_state_change is None:
|
||||||
|
self._on_state_change = []
|
||||||
|
self._on_state_change.append(func)
|
||||||
|
return lambda: cast(list, self._on_state_change).remove(func)
|
||||||
|
|
||||||
|
def _async_process_on_state_change(self) -> None:
|
||||||
|
"""Process the on_state_change callbacks and wait for pending tasks."""
|
||||||
|
if self._on_state_change is None:
|
||||||
|
return
|
||||||
|
for func in self._on_state_change:
|
||||||
|
try:
|
||||||
|
func()
|
||||||
|
except Exception:
|
||||||
|
_LOGGER.exception(
|
||||||
|
"Error calling on_state_change callback for %s (%s)",
|
||||||
|
self.title,
|
||||||
|
self.domain,
|
||||||
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_start_reauth(
|
def async_start_reauth(
|
||||||
self,
|
self,
|
||||||
|
@ -4796,6 +4796,136 @@ async def test_entry_reload_calls_on_unload_listeners(
|
|||||||
assert entry.state is config_entries.ConfigEntryState.LOADED
|
assert entry.state is config_entries.ConfigEntryState.LOADED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("source_state", "target_state", "transition_method_name", "call_count"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
config_entries.ConfigEntryState.NOT_LOADED,
|
||||||
|
config_entries.ConfigEntryState.LOADED,
|
||||||
|
"async_setup",
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
config_entries.ConfigEntryState.LOADED,
|
||||||
|
config_entries.ConfigEntryState.NOT_LOADED,
|
||||||
|
"async_unload",
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
config_entries.ConfigEntryState.LOADED,
|
||||||
|
config_entries.ConfigEntryState.LOADED,
|
||||||
|
"async_reload",
|
||||||
|
4,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_entry_state_change_calls_listener(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
manager: config_entries.ConfigEntries,
|
||||||
|
source_state: config_entries.ConfigEntryState,
|
||||||
|
target_state: config_entries.ConfigEntryState,
|
||||||
|
transition_method_name: str,
|
||||||
|
call_count: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test listeners get called on entry state changes."""
|
||||||
|
entry = MockConfigEntry(domain="comp", state=source_state)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
mock_integration(
|
||||||
|
hass,
|
||||||
|
MockModule(
|
||||||
|
"comp",
|
||||||
|
async_setup=AsyncMock(return_value=True),
|
||||||
|
async_setup_entry=AsyncMock(return_value=True),
|
||||||
|
async_unload_entry=AsyncMock(return_value=True),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_platform(hass, "comp.config_flow", None)
|
||||||
|
hass.config.components.add("comp")
|
||||||
|
|
||||||
|
mock_state_change_callback = Mock()
|
||||||
|
entry.async_on_state_change(mock_state_change_callback)
|
||||||
|
|
||||||
|
transition_method = getattr(manager, transition_method_name)
|
||||||
|
await transition_method(entry.entry_id)
|
||||||
|
|
||||||
|
assert len(mock_state_change_callback.mock_calls) == call_count
|
||||||
|
assert entry.state is target_state
|
||||||
|
|
||||||
|
|
||||||
|
async def test_entry_state_change_listener_removed(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
manager: config_entries.ConfigEntries,
|
||||||
|
) -> None:
|
||||||
|
"""Test state_change listener can be removed."""
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain="comp", state=config_entries.ConfigEntryState.NOT_LOADED
|
||||||
|
)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
mock_integration(
|
||||||
|
hass,
|
||||||
|
MockModule(
|
||||||
|
"comp",
|
||||||
|
async_setup=AsyncMock(return_value=True),
|
||||||
|
async_setup_entry=AsyncMock(return_value=True),
|
||||||
|
async_unload_entry=AsyncMock(return_value=True),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_platform(hass, "comp.config_flow", None)
|
||||||
|
hass.config.components.add("comp")
|
||||||
|
|
||||||
|
mock_state_change_callback = Mock()
|
||||||
|
remove = entry.async_on_state_change(mock_state_change_callback)
|
||||||
|
|
||||||
|
await manager.async_setup(entry.entry_id)
|
||||||
|
|
||||||
|
assert len(mock_state_change_callback.mock_calls) == 2
|
||||||
|
assert entry.state is config_entries.ConfigEntryState.LOADED
|
||||||
|
|
||||||
|
remove()
|
||||||
|
|
||||||
|
await manager.async_unload(entry.entry_id)
|
||||||
|
|
||||||
|
# the listener should no longer be called
|
||||||
|
assert len(mock_state_change_callback.mock_calls) == 2
|
||||||
|
assert entry.state is config_entries.ConfigEntryState.NOT_LOADED
|
||||||
|
|
||||||
|
|
||||||
|
async def test_entry_state_change_error_does_not_block_transition(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
manager: config_entries.ConfigEntries,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test we transition states normally even if the callback throws in on_state_change."""
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
title="test", domain="comp", state=config_entries.ConfigEntryState.NOT_LOADED
|
||||||
|
)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
mock_integration(
|
||||||
|
hass,
|
||||||
|
MockModule(
|
||||||
|
"comp",
|
||||||
|
async_setup=AsyncMock(return_value=True),
|
||||||
|
async_setup_entry=AsyncMock(return_value=True),
|
||||||
|
async_unload_entry=AsyncMock(return_value=True),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_platform(hass, "comp.config_flow", None)
|
||||||
|
hass.config.components.add("comp")
|
||||||
|
|
||||||
|
mock_state_change_callback = Mock(side_effect=Exception())
|
||||||
|
|
||||||
|
entry.async_on_state_change(mock_state_change_callback)
|
||||||
|
|
||||||
|
await manager.async_setup(entry.entry_id)
|
||||||
|
|
||||||
|
assert len(mock_state_change_callback.mock_calls) == 2
|
||||||
|
assert entry.state is config_entries.ConfigEntryState.LOADED
|
||||||
|
assert "Error calling on_state_change callback for test (comp)" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_raise_entry_error(
|
async def test_setup_raise_entry_error(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
manager: config_entries.ConfigEntries,
|
manager: config_entries.ConfigEntries,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user