Add async_get_active_reauth_flows helper for config entries (#81881)

* Add `async_get_active_reauth_flows` helper for config entries

* Code review

* Code review + tests
This commit is contained in:
Aaron Bach 2022-11-09 15:36:50 -07:00 committed by GitHub
parent 0941ed076c
commit adf84b0c62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 22 deletions

View File

@ -7,13 +7,13 @@ from typing import Any, cast
from pyopenuv.errors import InvalidApiKeyError, OpenUvError
from homeassistant.config_entries import ConfigEntry
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.debounce import Debouncer
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .const import DOMAIN, LOGGER
from .const import LOGGER
DEFAULT_DEBOUNCER_COOLDOWN_SECONDS = 15 * 60
@ -56,19 +56,10 @@ class ReauthFlowManager:
@callback
def _get_active_reauth_flow(self) -> FlowResult | None:
"""Get an active reauth flow (if it exists)."""
try:
[reauth_flow] = [
flow
for flow in self.hass.config_entries.flow.async_progress_by_handler(
DOMAIN
return next(
iter(self.entry.async_get_active_flows(self.hass, {SOURCE_REAUTH})),
None,
)
if flow["context"]["source"] == "reauth"
and flow["context"]["entry_id"] == self.entry.entry_id
]
except ValueError:
return None
return reauth_flow
@callback
def cancel_reauth(self) -> None:

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
from collections import ChainMap
from collections.abc import Callable, Coroutine, Iterable, Mapping
from collections.abc import Callable, Coroutine, Generator, Iterable, Mapping
from contextvars import ContextVar
from enum import Enum
import functools
@ -19,6 +19,7 @@ from .backports.enum import StrEnum
from .components import persistent_notification
from .const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, Platform
from .core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
from .data_entry_flow import FlowResult
from .exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady, HomeAssistantError
from .helpers import device_registry, entity_registry, storage
from .helpers.dispatcher import async_dispatcher_connect, async_dispatcher_send
@ -662,12 +663,7 @@ class ConfigEntry:
data: dict[str, Any] | None = None,
) -> None:
"""Start a reauth flow."""
if any(
flow
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
if flow["context"].get("source") == SOURCE_REAUTH
and flow["context"].get("entry_id") == self.entry_id
):
if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})):
# Reauth flow already in progress for this entry
return
@ -685,6 +681,18 @@ class ConfigEntry:
)
)
@callback
def async_get_active_flows(
self, hass: HomeAssistant, sources: set[str]
) -> Generator[FlowResult, None, None]:
"""Get any active flows of certain sources for this entry."""
return (
flow
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
if flow["context"].get("source") in sources
and flow["context"].get("entry_id") == self.entry_id
)
@callback
def async_create_task(
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R]

View File

@ -3429,3 +3429,35 @@ async def test_wait_for_loading_timeout(hass: HomeAssistant) -> None:
},
timeout=0.1,
)
async def test_get_active_flows(hass):
"""Test the async_get_active_flows helper."""
entry = MockConfigEntry(title="test_title", domain="test")
mock_setup_entry = AsyncMock(return_value=True)
mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry))
mock_entity_platform(hass, "config_flow.test", None)
await entry.async_setup(hass)
await hass.async_block_till_done()
flow = hass.config_entries.flow
with patch.object(flow, "async_init", wraps=flow.async_init):
entry.async_start_reauth(
hass,
context={"extra_context": "some_extra_context"},
data={"extra_data": 1234},
)
await hass.async_block_till_done()
# Check that there's an active reauth flow:
active_reauth_flow = next(
iter(entry.async_get_active_flows(hass, {config_entries.SOURCE_REAUTH})), None
)
assert active_reauth_flow is not None
# Check that there isn't any other flow (in this case, a user flow):
active_user_flow = next(
iter(entry.async_get_active_flows(hass, {config_entries.SOURCE_USER})), None
)
assert active_user_flow is None