diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index adbb2f80f64..14d69e278fa 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -745,9 +745,10 @@ class ConfigEntry: """Get any active flows of certain sources for this entry.""" return ( flow - for flow in hass.config_entries.flow.async_progress_by_handler(self.domain) + for flow in hass.config_entries.flow.async_progress_by_handler( + self.domain, match_context={"entry_id": self.entry_id} + ) if flow["context"].get("source") in sources - and flow["context"].get("entry_id") == self.entry_id ) @callback @@ -1086,16 +1087,9 @@ class ConfigEntries: # If the configuration entry is removed during reauth, it should # abort any reauth flow that is active for the removed entry. for progress_flow in self.hass.config_entries.flow.async_progress_by_handler( - entry.domain + entry.domain, match_context={"entry_id": entry_id, "source": SOURCE_REAUTH} ): - context = progress_flow.get("context") - if ( - context - and context["source"] == SOURCE_REAUTH - and "entry_id" in context - and context["entry_id"] == entry_id - and "flow_id" in progress_flow - ): + if "flow_id" in progress_flow: self.hass.config_entries.flow.async_abort(progress_flow["flow_id"]) # After we have fully removed an "ignore" config entry we can try and rediscover @@ -1577,17 +1571,20 @@ class ConfigFlow(data_entry_flow.FlowHandler): return None if raise_on_progress: - for progress in self._async_in_progress(include_uninitialized=True): - if progress["context"].get("unique_id") == unique_id: - raise data_entry_flow.AbortFlow("already_in_progress") + if self._async_in_progress( + include_uninitialized=True, match_context={"unique_id": unique_id} + ): + raise data_entry_flow.AbortFlow("already_in_progress") self.context["unique_id"] = unique_id # Abort discoveries done using the default discovery unique id if unique_id != DEFAULT_DISCOVERY_UNIQUE_ID: - for progress in self._async_in_progress(include_uninitialized=True): - if progress["context"].get("unique_id") == DEFAULT_DISCOVERY_UNIQUE_ID: - self.hass.config_entries.flow.async_abort(progress["flow_id"]) + for progress in self._async_in_progress( + include_uninitialized=True, + match_context={"unique_id": DEFAULT_DISCOVERY_UNIQUE_ID}, + ): + self.hass.config_entries.flow.async_abort(progress["flow_id"]) for entry in self._async_current_entries(include_ignore=True): if entry.unique_id == unique_id: @@ -1633,13 +1630,17 @@ class ConfigFlow(data_entry_flow.FlowHandler): @callback def _async_in_progress( - self, include_uninitialized: bool = False + self, + include_uninitialized: bool = False, + match_context: dict[str, Any] | None = None, ) -> list[data_entry_flow.FlowResult]: """Return other in progress flows for current domain.""" return [ flw for flw in self.hass.config_entries.flow.async_progress_by_handler( - self.handler, include_uninitialized=include_uninitialized + self.handler, + include_uninitialized=include_uninitialized, + match_context=match_context, ) if flw["flow_id"] != self.flow_id ] @@ -1713,11 +1714,10 @@ class ConfigFlow(data_entry_flow.FlowHandler): """Abort the config flow.""" # Remove reauth notification if no reauth flows are in progress if self.source == SOURCE_REAUTH and not any( - ent["context"]["source"] == SOURCE_REAUTH + ent["flow_id"] != self.flow_id for ent in self.hass.config_entries.flow.async_progress_by_handler( - self.handler + self.handler, match_context={"source": SOURCE_REAUTH} ) - if ent["flow_id"] != self.flow_id ): persistent_notification.async_dismiss( self.hass, RECONFIGURE_NOTIFICATION_ID diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index e213814f52c..6f125ce359a 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -164,16 +164,19 @@ class FlowManager(abc.ABC): @callback def async_has_matching_flow( - self, handler: str, context: dict[str, Any], data: Any + self, handler: str, match_context: dict[str, Any], data: Any ) -> bool: """Check if an existing matching flow is in progress. A flow with the same handler, context, and data. + + If match_context is passed, only return flows with a context that is a + superset of match_context. """ return any( flow - for flow in self._async_progress_by_handler(handler) - if flow.context["source"] == context["source"] and flow.init_data == data + for flow in self._async_progress_by_handler(handler, match_context) + if flow.init_data == data ) @callback @@ -192,11 +195,19 @@ class FlowManager(abc.ABC): @callback def async_progress_by_handler( - self, handler: str, include_uninitialized: bool = False + self, + handler: str, + include_uninitialized: bool = False, + match_context: dict[str, Any] | None = None, ) -> list[FlowResult]: - """Return the flows in progress by handler as a partial FlowResult.""" + """Return the flows in progress by handler as a partial FlowResult. + + If match_context is specified, only return flows with a context that + is a superset of match_context. + """ return _async_flow_handler_to_flow_result( - self._async_progress_by_handler(handler), include_uninitialized + self._async_progress_by_handler(handler, match_context), + include_uninitialized, ) @callback @@ -217,11 +228,26 @@ class FlowManager(abc.ABC): ) @callback - def _async_progress_by_handler(self, handler: str) -> list[FlowHandler]: - """Return the flows in progress by handler.""" + def _async_progress_by_handler( + self, handler: str, match_context: dict[str, Any] | None + ) -> list[FlowHandler]: + """Return the flows in progress by handler. + + If match_context is specified, only return flows with a context that + is a superset of match_context. + """ + match_context_items = match_context.items() if match_context else None return [ - self._progress[flow_id] + progress for flow_id in self._handler_progress_index.get(handler, {}) + if (progress := self._progress[flow_id]) + and ( + not match_context_items + or ( + (context := progress.context) + and match_context_items <= context.items() + ) + ) ] async def async_init( diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index c3afc3bc8ba..168f97ba779 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -462,6 +462,22 @@ async def test_async_has_matching_flow( assert result["progress_action"] == "task_one" assert len(manager.async_progress()) == 1 assert len(manager.async_progress_by_handler("test")) == 1 + assert ( + len( + manager.async_progress_by_handler( + "test", match_context={"source": config_entries.SOURCE_HOMEKIT} + ) + ) + == 1 + ) + assert ( + len( + manager.async_progress_by_handler( + "test", match_context={"source": config_entries.SOURCE_BLUETOOTH} + ) + ) + == 0 + ) assert manager.async_get(result["flow_id"])["handler"] == "test" assert (