Add async_get_active_reauth_flows helper for config entries (#81881)

* Add `async_get_active_reauth_flows` helper for config entries

* Code review

* Code review + tests
This commit is contained in:
Aaron Bach 2022-11-09 15:36:50 -07:00 committed by GitHub
parent 0941ed076c
commit adf84b0c62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 22 deletions

View File

@ -7,13 +7,13 @@ from typing import Any, cast
from pyopenuv.errors import InvalidApiKeyError, OpenUvError from pyopenuv.errors import InvalidApiKeyError, OpenUvError
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.debounce import Debouncer from homeassistant.helpers.debounce import Debouncer
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .const import DOMAIN, LOGGER from .const import LOGGER
DEFAULT_DEBOUNCER_COOLDOWN_SECONDS = 15 * 60 DEFAULT_DEBOUNCER_COOLDOWN_SECONDS = 15 * 60
@ -56,19 +56,10 @@ class ReauthFlowManager:
@callback @callback
def _get_active_reauth_flow(self) -> FlowResult | None: def _get_active_reauth_flow(self) -> FlowResult | None:
"""Get an active reauth flow (if it exists).""" """Get an active reauth flow (if it exists)."""
try: return next(
[reauth_flow] = [ iter(self.entry.async_get_active_flows(self.hass, {SOURCE_REAUTH})),
flow None,
for flow in self.hass.config_entries.flow.async_progress_by_handler( )
DOMAIN
)
if flow["context"]["source"] == "reauth"
and flow["context"]["entry_id"] == self.entry.entry_id
]
except ValueError:
return None
return reauth_flow
@callback @callback
def cancel_reauth(self) -> None: def cancel_reauth(self) -> None:

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import ChainMap from collections import ChainMap
from collections.abc import Callable, Coroutine, Iterable, Mapping from collections.abc import Callable, Coroutine, Generator, Iterable, Mapping
from contextvars import ContextVar from contextvars import ContextVar
from enum import Enum from enum import Enum
import functools import functools
@ -19,6 +19,7 @@ from .backports.enum import StrEnum
from .components import persistent_notification from .components import persistent_notification
from .const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, Platform from .const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, Platform
from .core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback from .core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
from .data_entry_flow import FlowResult
from .exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady, HomeAssistantError from .exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady, HomeAssistantError
from .helpers import device_registry, entity_registry, storage from .helpers import device_registry, entity_registry, storage
from .helpers.dispatcher import async_dispatcher_connect, async_dispatcher_send from .helpers.dispatcher import async_dispatcher_connect, async_dispatcher_send
@ -662,12 +663,7 @@ class ConfigEntry:
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Start a reauth flow.""" """Start a reauth flow."""
if any( if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})):
flow
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
if flow["context"].get("source") == SOURCE_REAUTH
and flow["context"].get("entry_id") == self.entry_id
):
# Reauth flow already in progress for this entry # Reauth flow already in progress for this entry
return return
@ -685,6 +681,18 @@ class ConfigEntry:
) )
) )
@callback
def async_get_active_flows(
self, hass: HomeAssistant, sources: set[str]
) -> Generator[FlowResult, None, None]:
"""Get any active flows of certain sources for this entry."""
return (
flow
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
if flow["context"].get("source") in sources
and flow["context"].get("entry_id") == self.entry_id
)
@callback @callback
def async_create_task( def async_create_task(
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R] self, hass: HomeAssistant, target: Coroutine[Any, Any, _R]

View File

@ -3429,3 +3429,35 @@ async def test_wait_for_loading_timeout(hass: HomeAssistant) -> None:
}, },
timeout=0.1, timeout=0.1,
) )
async def test_get_active_flows(hass):
"""Test the async_get_active_flows helper."""
entry = MockConfigEntry(title="test_title", domain="test")
mock_setup_entry = AsyncMock(return_value=True)
mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry))
mock_entity_platform(hass, "config_flow.test", None)
await entry.async_setup(hass)
await hass.async_block_till_done()
flow = hass.config_entries.flow
with patch.object(flow, "async_init", wraps=flow.async_init):
entry.async_start_reauth(
hass,
context={"extra_context": "some_extra_context"},
data={"extra_data": 1234},
)
await hass.async_block_till_done()
# Check that there's an active reauth flow:
active_reauth_flow = next(
iter(entry.async_get_active_flows(hass, {config_entries.SOURCE_REAUTH})), None
)
assert active_reauth_flow is not None
# Check that there isn't any other flow (in this case, a user flow):
active_user_flow = next(
iter(entry.async_get_active_flows(hass, {config_entries.SOURCE_USER})), None
)
assert active_user_flow is None