mirror of
https://github.com/home-assistant/core.git
synced 2025-07-13 08:17:08 +00:00
Add async_get_active_reauth_flows
helper for config entries (#81881)
* Add `async_get_active_reauth_flows` helper for config entries * Code review * Code review + tests
This commit is contained in:
parent
0941ed076c
commit
adf84b0c62
@ -7,13 +7,13 @@ from typing import Any, cast
|
||||
|
||||
from pyopenuv.errors import InvalidApiKeyError, OpenUvError
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.helpers.debounce import Debouncer
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
||||
|
||||
from .const import DOMAIN, LOGGER
|
||||
from .const import LOGGER
|
||||
|
||||
DEFAULT_DEBOUNCER_COOLDOWN_SECONDS = 15 * 60
|
||||
|
||||
@ -56,19 +56,10 @@ class ReauthFlowManager:
|
||||
@callback
|
||||
def _get_active_reauth_flow(self) -> FlowResult | None:
|
||||
"""Get an active reauth flow (if it exists)."""
|
||||
try:
|
||||
[reauth_flow] = [
|
||||
flow
|
||||
for flow in self.hass.config_entries.flow.async_progress_by_handler(
|
||||
DOMAIN
|
||||
return next(
|
||||
iter(self.entry.async_get_active_flows(self.hass, {SOURCE_REAUTH})),
|
||||
None,
|
||||
)
|
||||
if flow["context"]["source"] == "reauth"
|
||||
and flow["context"]["entry_id"] == self.entry.entry_id
|
||||
]
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return reauth_flow
|
||||
|
||||
@callback
|
||||
def cancel_reauth(self) -> None:
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import ChainMap
|
||||
from collections.abc import Callable, Coroutine, Iterable, Mapping
|
||||
from collections.abc import Callable, Coroutine, Generator, Iterable, Mapping
|
||||
from contextvars import ContextVar
|
||||
from enum import Enum
|
||||
import functools
|
||||
@ -19,6 +19,7 @@ from .backports.enum import StrEnum
|
||||
from .components import persistent_notification
|
||||
from .const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, Platform
|
||||
from .core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
|
||||
from .data_entry_flow import FlowResult
|
||||
from .exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady, HomeAssistantError
|
||||
from .helpers import device_registry, entity_registry, storage
|
||||
from .helpers.dispatcher import async_dispatcher_connect, async_dispatcher_send
|
||||
@ -662,12 +663,7 @@ class ConfigEntry:
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Start a reauth flow."""
|
||||
if any(
|
||||
flow
|
||||
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
|
||||
if flow["context"].get("source") == SOURCE_REAUTH
|
||||
and flow["context"].get("entry_id") == self.entry_id
|
||||
):
|
||||
if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})):
|
||||
# Reauth flow already in progress for this entry
|
||||
return
|
||||
|
||||
@ -685,6 +681,18 @@ class ConfigEntry:
|
||||
)
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_get_active_flows(
|
||||
self, hass: HomeAssistant, sources: set[str]
|
||||
) -> Generator[FlowResult, None, None]:
|
||||
"""Get any active flows of certain sources for this entry."""
|
||||
return (
|
||||
flow
|
||||
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
|
||||
if flow["context"].get("source") in sources
|
||||
and flow["context"].get("entry_id") == self.entry_id
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_create_task(
|
||||
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R]
|
||||
|
@ -3429,3 +3429,35 @@ async def test_wait_for_loading_timeout(hass: HomeAssistant) -> None:
|
||||
},
|
||||
timeout=0.1,
|
||||
)
|
||||
|
||||
|
||||
async def test_get_active_flows(hass):
|
||||
"""Test the async_get_active_flows helper."""
|
||||
entry = MockConfigEntry(title="test_title", domain="test")
|
||||
mock_setup_entry = AsyncMock(return_value=True)
|
||||
mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry))
|
||||
mock_entity_platform(hass, "config_flow.test", None)
|
||||
|
||||
await entry.async_setup(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
flow = hass.config_entries.flow
|
||||
with patch.object(flow, "async_init", wraps=flow.async_init):
|
||||
entry.async_start_reauth(
|
||||
hass,
|
||||
context={"extra_context": "some_extra_context"},
|
||||
data={"extra_data": 1234},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Check that there's an active reauth flow:
|
||||
active_reauth_flow = next(
|
||||
iter(entry.async_get_active_flows(hass, {config_entries.SOURCE_REAUTH})), None
|
||||
)
|
||||
assert active_reauth_flow is not None
|
||||
|
||||
# Check that there isn't any other flow (in this case, a user flow):
|
||||
active_user_flow = next(
|
||||
iter(entry.async_get_active_flows(hass, {config_entries.SOURCE_USER})), None
|
||||
)
|
||||
assert active_user_flow is None
|
||||
|
Loading…
x
Reference in New Issue
Block a user