mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Fix race in starting reauth flows (#103130)
This commit is contained in:
parent
d75a6a3b4b
commit
9b27552238
@ -223,6 +223,7 @@ class ConfigEntry:
|
|||||||
"_async_cancel_retry_setup",
|
"_async_cancel_retry_setup",
|
||||||
"_on_unload",
|
"_on_unload",
|
||||||
"reload_lock",
|
"reload_lock",
|
||||||
|
"_reauth_lock",
|
||||||
"_tasks",
|
"_tasks",
|
||||||
"_background_tasks",
|
"_background_tasks",
|
||||||
"_integration_for_domain",
|
"_integration_for_domain",
|
||||||
@ -321,6 +322,8 @@ class ConfigEntry:
|
|||||||
|
|
||||||
# Reload lock to prevent conflicting reloads
|
# Reload lock to prevent conflicting reloads
|
||||||
self.reload_lock = asyncio.Lock()
|
self.reload_lock = asyncio.Lock()
|
||||||
|
# Reauth lock to prevent concurrent reauth flows
|
||||||
|
self._reauth_lock = asyncio.Lock()
|
||||||
|
|
||||||
self._tasks: set[asyncio.Future[Any]] = set()
|
self._tasks: set[asyncio.Future[Any]] = set()
|
||||||
self._background_tasks: set[asyncio.Future[Any]] = set()
|
self._background_tasks: set[asyncio.Future[Any]] = set()
|
||||||
@ -727,12 +730,28 @@ class ConfigEntry:
|
|||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a reauth flow."""
|
"""Start a reauth flow."""
|
||||||
|
# We will check this again in the task when we hold the lock,
|
||||||
|
# but we also check it now to try to avoid creating the task.
|
||||||
if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})):
|
if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})):
|
||||||
# Reauth flow already in progress for this entry
|
# Reauth flow already in progress for this entry
|
||||||
return
|
return
|
||||||
|
|
||||||
hass.async_create_task(
|
hass.async_create_task(
|
||||||
hass.config_entries.flow.async_init(
|
self._async_init_reauth(hass, context, data),
|
||||||
|
f"config entry reauth {self.title} {self.domain} {self.entry_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _async_init_reauth(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
context: dict[str, Any] | None = None,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Start a reauth flow."""
|
||||||
|
async with self._reauth_lock:
|
||||||
|
if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})):
|
||||||
|
# Reauth flow already in progress for this entry
|
||||||
|
return
|
||||||
|
await hass.config_entries.flow.async_init(
|
||||||
self.domain,
|
self.domain,
|
||||||
context={
|
context={
|
||||||
"source": SOURCE_REAUTH,
|
"source": SOURCE_REAUTH,
|
||||||
@ -742,9 +761,7 @@ class ConfigEntry:
|
|||||||
}
|
}
|
||||||
| (context or {}),
|
| (context or {}),
|
||||||
data=self.data | (data or {}),
|
data=self.data | (data or {}),
|
||||||
),
|
)
|
||||||
f"config entry reauth {self.title} {self.domain} {self.entry_id}",
|
|
||||||
)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_active_flows(
|
def async_get_active_flows(
|
||||||
@ -754,7 +771,9 @@ class ConfigEntry:
|
|||||||
return (
|
return (
|
||||||
flow
|
flow
|
||||||
for flow in hass.config_entries.flow.async_progress_by_handler(
|
for flow in hass.config_entries.flow.async_progress_by_handler(
|
||||||
self.domain, match_context={"entry_id": self.entry_id}
|
self.domain,
|
||||||
|
match_context={"entry_id": self.entry_id},
|
||||||
|
include_uninitialized=True,
|
||||||
)
|
)
|
||||||
if flow["context"].get("source") in sources
|
if flow["context"].get("source") in sources
|
||||||
)
|
)
|
||||||
|
@ -42,6 +42,7 @@ async def test_setup_auth_failed(
|
|||||||
config_entry.add_to_hass(hass)
|
config_entry.add_to_hass(hass)
|
||||||
with patch.object(hass.config_entries.flow, "async_init") as mock_flow_init:
|
with patch.object(hass.config_entries.flow, "async_init") as mock_flow_init:
|
||||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
assert config_entry.state is ConfigEntryState.SETUP_ERROR
|
assert config_entry.state is ConfigEntryState.SETUP_ERROR
|
||||||
mock_flow_init.assert_called_with(
|
mock_flow_init.assert_called_with(
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
@ -3791,6 +3791,20 @@ async def test_reauth(hass: HomeAssistant) -> None:
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert len(hass.config_entries.flow.async_progress()) == 2
|
assert len(hass.config_entries.flow.async_progress()) == 2
|
||||||
|
|
||||||
|
# Abort all existing flows
|
||||||
|
for flow in hass.config_entries.flow.async_progress():
|
||||||
|
hass.config_entries.flow.async_abort(flow["flow_id"])
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Check that we can't start duplicate reauth flows
|
||||||
|
# without blocking between flows
|
||||||
|
entry.async_start_reauth(hass, {"extra_context": "some_extra_context"})
|
||||||
|
entry.async_start_reauth(hass, {"extra_context": "some_extra_context"})
|
||||||
|
entry.async_start_reauth(hass, {"extra_context": "some_extra_context"})
|
||||||
|
entry.async_start_reauth(hass, {"extra_context": "some_extra_context"})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(hass.config_entries.flow.async_progress()) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_get_active_flows(hass: HomeAssistant) -> None:
|
async def test_get_active_flows(hass: HomeAssistant) -> None:
|
||||||
"""Test the async_get_active_flows helper."""
|
"""Test the async_get_active_flows helper."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user