Prevent leak of current_entry context variable (#128145)

This commit is contained in:
epenet 2024-10-16 18:02:37 +02:00 committed by GitHub
parent 494511e099
commit 350a27575f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 81 additions and 1 deletions

View File

@ -529,10 +529,21 @@ class ConfigEntry(Generic[_DataT]):
integration: loader.Integration | None = None,
) -> None:
"""Set up an entry."""
current_entry.set(self)
if self.source == SOURCE_IGNORE or self.disabled_by:
return
current_entry.set(self)
try:
await self.__async_setup_with_context(hass, integration)
finally:
current_entry.set(None)
async def __async_setup_with_context(
self,
hass: HomeAssistant,
integration: loader.Integration | None,
) -> None:
"""Set up an entry, with current_entry set."""
if integration is None and not (integration := self._integration_for_domain):
integration = await loader.async_get_integration(hass, self.domain)
self._integration_for_domain = integration

View File

@ -6937,3 +6937,72 @@ async def test_async_update_entry_unique_id_collision(
"Unique id of config entry 'Mock Title' from integration test changed to "
"'very unique' which is already in use"
) in caplog.text
async def test_context_no_leak(hass: HomeAssistant) -> None:
"""Test ensure that config entry context does not leak.
Unlikely to happen in real world, but occurs often in tests.
"""
connected_future = asyncio.Future()
bg_tasks = []
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Mock setup entry."""
async def _async_set_runtime_data():
# Show that config_entries.current_entry is preserved for child tasks
await connected_future
entry.runtime_data = config_entries.current_entry.get()
bg_tasks.append(hass.loop.create_task(_async_set_runtime_data()))
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Mock unload entry."""
return True
mock_integration(
hass,
MockModule(
"comp",
async_setup_entry=async_setup_entry,
async_unload_entry=async_unload_entry,
),
)
mock_platform(hass, "comp.config_flow", None)
entry1 = MockConfigEntry(domain="comp")
entry1.add_to_hass(hass)
await hass.config_entries.async_setup(entry1.entry_id)
assert entry1.state is config_entries.ConfigEntryState.LOADED
assert config_entries.current_entry.get() is None
# Load an existing config entry
entry2 = MockConfigEntry(domain="comp")
entry2.add_to_hass(hass)
await hass.config_entries.async_setup(entry2.entry_id)
assert entry2.state is config_entries.ConfigEntryState.LOADED
assert config_entries.current_entry.get() is None
# Add a new config entry (eg. from config flow)
entry3 = MockConfigEntry(domain="comp")
await hass.config_entries.async_add(entry3)
assert entry3.state is config_entries.ConfigEntryState.LOADED
assert config_entries.current_entry.get() is None
for entry in (entry1, entry2, entry3):
assert entry.state is config_entries.ConfigEntryState.LOADED
assert not hasattr(entry, "runtime_data")
assert config_entries.current_entry.get() is None
connected_future.set_result(None)
await asyncio.gather(*bg_tasks)
for entry in (entry1, entry2, entry3):
assert entry.state is config_entries.ConfigEntryState.LOADED
assert entry.runtime_data is entry
assert config_entries.current_entry.get() is None