mirror of
https://github.com/home-assistant/core.git
synced 2025-07-13 16:27:08 +00:00
Add async_get_active_reauth_flows
helper for config entries (#81881)
* Add `async_get_active_reauth_flows` helper for config entries * Code review * Code review + tests
This commit is contained in:
parent
0941ed076c
commit
adf84b0c62
@ -7,13 +7,13 @@ from typing import Any, cast
|
|||||||
|
|
||||||
from pyopenuv.errors import InvalidApiKeyError, OpenUvError
|
from pyopenuv.errors import InvalidApiKeyError, OpenUvError
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
from homeassistant.helpers.debounce import Debouncer
|
from homeassistant.helpers.debounce import Debouncer
|
||||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
||||||
|
|
||||||
from .const import DOMAIN, LOGGER
|
from .const import LOGGER
|
||||||
|
|
||||||
DEFAULT_DEBOUNCER_COOLDOWN_SECONDS = 15 * 60
|
DEFAULT_DEBOUNCER_COOLDOWN_SECONDS = 15 * 60
|
||||||
|
|
||||||
@ -56,19 +56,10 @@ class ReauthFlowManager:
|
|||||||
@callback
|
@callback
|
||||||
def _get_active_reauth_flow(self) -> FlowResult | None:
|
def _get_active_reauth_flow(self) -> FlowResult | None:
|
||||||
"""Get an active reauth flow (if it exists)."""
|
"""Get an active reauth flow (if it exists)."""
|
||||||
try:
|
return next(
|
||||||
[reauth_flow] = [
|
iter(self.entry.async_get_active_flows(self.hass, {SOURCE_REAUTH})),
|
||||||
flow
|
None,
|
||||||
for flow in self.hass.config_entries.flow.async_progress_by_handler(
|
)
|
||||||
DOMAIN
|
|
||||||
)
|
|
||||||
if flow["context"]["source"] == "reauth"
|
|
||||||
and flow["context"]["entry_id"] == self.entry.entry_id
|
|
||||||
]
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return reauth_flow
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def cancel_reauth(self) -> None:
|
def cancel_reauth(self) -> None:
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import ChainMap
|
from collections import ChainMap
|
||||||
from collections.abc import Callable, Coroutine, Iterable, Mapping
|
from collections.abc import Callable, Coroutine, Generator, Iterable, Mapping
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import functools
|
import functools
|
||||||
@ -19,6 +19,7 @@ from .backports.enum import StrEnum
|
|||||||
from .components import persistent_notification
|
from .components import persistent_notification
|
||||||
from .const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, Platform
|
from .const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, Platform
|
||||||
from .core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
|
from .core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
|
||||||
|
from .data_entry_flow import FlowResult
|
||||||
from .exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady, HomeAssistantError
|
from .exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady, HomeAssistantError
|
||||||
from .helpers import device_registry, entity_registry, storage
|
from .helpers import device_registry, entity_registry, storage
|
||||||
from .helpers.dispatcher import async_dispatcher_connect, async_dispatcher_send
|
from .helpers.dispatcher import async_dispatcher_connect, async_dispatcher_send
|
||||||
@ -662,12 +663,7 @@ class ConfigEntry:
|
|||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a reauth flow."""
|
"""Start a reauth flow."""
|
||||||
if any(
|
if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})):
|
||||||
flow
|
|
||||||
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
|
|
||||||
if flow["context"].get("source") == SOURCE_REAUTH
|
|
||||||
and flow["context"].get("entry_id") == self.entry_id
|
|
||||||
):
|
|
||||||
# Reauth flow already in progress for this entry
|
# Reauth flow already in progress for this entry
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -685,6 +681,18 @@ class ConfigEntry:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_active_flows(
|
||||||
|
self, hass: HomeAssistant, sources: set[str]
|
||||||
|
) -> Generator[FlowResult, None, None]:
|
||||||
|
"""Get any active flows of certain sources for this entry."""
|
||||||
|
return (
|
||||||
|
flow
|
||||||
|
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
|
||||||
|
if flow["context"].get("source") in sources
|
||||||
|
and flow["context"].get("entry_id") == self.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_task(
|
def async_create_task(
|
||||||
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R]
|
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R]
|
||||||
|
@ -3429,3 +3429,35 @@ async def test_wait_for_loading_timeout(hass: HomeAssistant) -> None:
|
|||||||
},
|
},
|
||||||
timeout=0.1,
|
timeout=0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_active_flows(hass):
|
||||||
|
"""Test the async_get_active_flows helper."""
|
||||||
|
entry = MockConfigEntry(title="test_title", domain="test")
|
||||||
|
mock_setup_entry = AsyncMock(return_value=True)
|
||||||
|
mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry))
|
||||||
|
mock_entity_platform(hass, "config_flow.test", None)
|
||||||
|
|
||||||
|
await entry.async_setup(hass)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
flow = hass.config_entries.flow
|
||||||
|
with patch.object(flow, "async_init", wraps=flow.async_init):
|
||||||
|
entry.async_start_reauth(
|
||||||
|
hass,
|
||||||
|
context={"extra_context": "some_extra_context"},
|
||||||
|
data={"extra_data": 1234},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Check that there's an active reauth flow:
|
||||||
|
active_reauth_flow = next(
|
||||||
|
iter(entry.async_get_active_flows(hass, {config_entries.SOURCE_REAUTH})), None
|
||||||
|
)
|
||||||
|
assert active_reauth_flow is not None
|
||||||
|
|
||||||
|
# Check that there isn't any other flow (in this case, a user flow):
|
||||||
|
active_user_flow = next(
|
||||||
|
iter(entry.async_get_active_flows(hass, {config_entries.SOURCE_USER})), None
|
||||||
|
)
|
||||||
|
assert active_user_flow is None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user