Ensure config entry setup lock is held when removing a config entry (#117086)

This commit is contained in:
J. Nick Koston 2024-05-10 19:47:26 -05:00 committed by GitHub
parent c74c2f3652
commit 2e60e09ba2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 16 deletions

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
from collections.abc import Mapping
from datetime import timedelta
from math import ceil
@ -307,15 +306,19 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# domain:
new_entry_data = {**entry.data}
new_entry_data.pop(CONF_INTEGRATION_TYPE)
tasks = [
# Schedule the removal in a task to avoid a deadlock
# since we cannot remove a config entry that is in
# the process of being setup.
hass.async_create_background_task(
hass.config_entries.async_remove(entry.entry_id),
hass.config_entries.flow.async_init(
DOMAIN_AIRVISUAL_PRO,
context={"source": SOURCE_IMPORT},
data=new_entry_data,
),
]
await asyncio.gather(*tasks)
name="remove config legacy airvisual entry {entry.title}",
)
await hass.config_entries.flow.async_init(
DOMAIN_AIRVISUAL_PRO,
context={"source": SOURCE_IMPORT},
data=new_entry_data,
)
# After the migration has occurred, grab the new config and device entries
# (now under the `airvisual_pro` domain):

View File

@ -1621,15 +1621,16 @@ class ConfigEntries:
if (entry := self.async_get_entry(entry_id)) is None:
raise UnknownEntry
if not entry.state.recoverable:
unload_success = entry.state is not ConfigEntryState.FAILED_UNLOAD
else:
unload_success = await self.async_unload(entry_id)
async with entry.setup_lock:
if not entry.state.recoverable:
unload_success = entry.state is not ConfigEntryState.FAILED_UNLOAD
else:
unload_success = await self.async_unload(entry_id)
await entry.async_remove(self.hass)
await entry.async_remove(self.hass)
del self._entries[entry.entry_id]
self._async_schedule_save()
del self._entries[entry.entry_id]
self._async_schedule_save()
dev_reg = device_registry.async_get(self.hass)
ent_reg = entity_registry.async_get(self.hass)