From adf84b0c62c4d6472d589b086162563f4f27a854 Mon Sep 17 00:00:00 2001 From: Aaron Bach Date: Wed, 9 Nov 2022 15:36:50 -0700 Subject: [PATCH] Add `async_get_active_reauth_flows` helper for config entries (#81881) * Add `async_get_active_reauth_flows` helper for config entries * Code review * Code review + tests --- .../components/openuv/coordinator.py | 21 ++++-------- homeassistant/config_entries.py | 22 +++++++++---- tests/test_config_entries.py | 32 +++++++++++++++++++ 3 files changed, 53 insertions(+), 22 deletions(-) diff --git a/homeassistant/components/openuv/coordinator.py b/homeassistant/components/openuv/coordinator.py index 36267972f80..1df6a91b398 100644 --- a/homeassistant/components/openuv/coordinator.py +++ b/homeassistant/components/openuv/coordinator.py @@ -7,13 +7,13 @@ from typing import Any, cast from pyopenuv.errors import InvalidApiKeyError, OpenUvError -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers.debounce import Debouncer from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed -from .const import DOMAIN, LOGGER +from .const import LOGGER DEFAULT_DEBOUNCER_COOLDOWN_SECONDS = 15 * 60 @@ -56,19 +56,10 @@ class ReauthFlowManager: @callback def _get_active_reauth_flow(self) -> FlowResult | None: """Get an active reauth flow (if it exists).""" - try: - [reauth_flow] = [ - flow - for flow in self.hass.config_entries.flow.async_progress_by_handler( - DOMAIN - ) - if flow["context"]["source"] == "reauth" - and flow["context"]["entry_id"] == self.entry.entry_id - ] - except ValueError: - return None - - return reauth_flow + return next( + iter(self.entry.async_get_active_flows(self.hass, {SOURCE_REAUTH})), + None, + ) @callback def cancel_reauth(self) -> None: diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 902fa0d03f2..ddef5d7f226 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio from collections import ChainMap -from collections.abc import Callable, Coroutine, Iterable, Mapping +from collections.abc import Callable, Coroutine, Generator, Iterable, Mapping from contextvars import ContextVar from enum import Enum import functools @@ -19,6 +19,7 @@ from .backports.enum import StrEnum from .components import persistent_notification from .const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, Platform from .core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback +from .data_entry_flow import FlowResult from .exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady, HomeAssistantError from .helpers import device_registry, entity_registry, storage from .helpers.dispatcher import async_dispatcher_connect, async_dispatcher_send @@ -662,12 +663,7 @@ class ConfigEntry: data: dict[str, Any] | None = None, ) -> None: """Start a reauth flow.""" - if any( - flow - for flow in hass.config_entries.flow.async_progress_by_handler(self.domain) - if flow["context"].get("source") == SOURCE_REAUTH - and flow["context"].get("entry_id") == self.entry_id - ): + if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})): # Reauth flow already in progress for this entry return @@ -685,6 +681,18 @@ class ConfigEntry: ) ) + @callback + def async_get_active_flows( + self, hass: HomeAssistant, sources: set[str] + ) -> Generator[FlowResult, None, None]: + """Get any active flows of certain sources for this entry.""" + return ( + flow + for flow in hass.config_entries.flow.async_progress_by_handler(self.domain) + if flow["context"].get("source") in sources + and flow["context"].get("entry_id") == self.entry_id + ) + @callback def async_create_task( self, hass: HomeAssistant, target: Coroutine[Any, Any, _R] diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 96d032f771e..28c3f9c2803 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -3429,3 +3429,35 @@ async def test_wait_for_loading_timeout(hass: HomeAssistant) -> None: }, timeout=0.1, ) + + +async def test_get_active_flows(hass): + """Test the async_get_active_flows helper.""" + entry = MockConfigEntry(title="test_title", domain="test") + mock_setup_entry = AsyncMock(return_value=True) + mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry)) + mock_entity_platform(hass, "config_flow.test", None) + + await entry.async_setup(hass) + await hass.async_block_till_done() + + flow = hass.config_entries.flow + with patch.object(flow, "async_init", wraps=flow.async_init): + entry.async_start_reauth( + hass, + context={"extra_context": "some_extra_context"}, + data={"extra_data": 1234}, + ) + await hass.async_block_till_done() + + # Check that there's an active reauth flow: + active_reauth_flow = next( + iter(entry.async_get_active_flows(hass, {config_entries.SOURCE_REAUTH})), None + ) + assert active_reauth_flow is not None + + # Check that there isn't any other flow (in this case, a user flow): + active_user_flow = next( + iter(entry.async_get_active_flows(hass, {config_entries.SOURCE_USER})), None + ) + assert active_user_flow is None