Index in-progress flows to avoid linear search (#58146)

Co-authored-by: Steven Looman <steven.looman@gmail.com>
This commit is contained in:
J. Nick Koston 2021-10-22 07:19:49 -10:00 committed by GitHub
parent fa56be7cc0
commit 3b7dce8b95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 190 additions and 64 deletions

View File

@ -231,14 +231,9 @@ class LoginFlowResourceView(HomeAssistantView):
try: try:
# do not allow change ip during login flow # do not allow change ip during login flow
for flow in self._flow_mgr.async_progress(): flow = self._flow_mgr.async_get(flow_id)
if flow["flow_id"] == flow_id and flow["context"][ if flow["context"]["ip_address"] != ip_address(request.remote):
"ip_address" return self.json_message("IP address changed", HTTPStatus.BAD_REQUEST)
] != ip_address(request.remote):
return self.json_message(
"IP address changed", HTTPStatus.BAD_REQUEST
)
result = await self._flow_mgr.async_configure(flow_id, data) result = await self._flow_mgr.async_configure(flow_id, data)
except data_entry_flow.UnknownFlow: except data_entry_flow.UnknownFlow:
return self.json_message("Invalid flow specified", HTTPStatus.NOT_FOUND) return self.json_message("Invalid flow specified", HTTPStatus.NOT_FOUND)

View File

@ -131,7 +131,7 @@ class PointFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
_LOGGER.debug( _LOGGER.debug(
"Should close all flows below %s", "Should close all flows below %s",
self.hass.config_entries.flow.async_progress(), self._async_in_progress(),
) )
# Remove notification if no other discovery config entries in progress # Remove notification if no other discovery config entries in progress

View File

@ -73,8 +73,7 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Remove the entry which will invoke the callback to delete the app. # Remove the entry which will invoke the callback to delete the app.
hass.async_create_task(hass.config_entries.async_remove(entry.entry_id)) hass.async_create_task(hass.config_entries.async_remove(entry.entry_id))
# only create new flow if there isn't a pending one for SmartThings. # only create new flow if there isn't a pending one for SmartThings.
flows = hass.config_entries.flow.async_progress() if not hass.config_entries.flow.async_progress_by_handler(DOMAIN):
if not [flow for flow in flows if flow["handler"] == DOMAIN]:
hass.async_create_task( hass.async_create_task(
hass.config_entries.flow.async_init( hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT} DOMAIN, context={"source": SOURCE_IMPORT}
@ -181,8 +180,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if remove_entry: if remove_entry:
hass.async_create_task(hass.config_entries.async_remove(entry.entry_id)) hass.async_create_task(hass.config_entries.async_remove(entry.entry_id))
# only create new flow if there isn't a pending one for SmartThings. # only create new flow if there isn't a pending one for SmartThings.
flows = hass.config_entries.flow.async_progress() if not hass.config_entries.flow.async_progress_by_handler(DOMAIN):
if not [flow for flow in flows if flow["handler"] == DOMAIN]:
hass.async_create_task( hass.async_create_task(
hass.config_entries.flow.async_init( hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT} DOMAIN, context={"source": SOURCE_IMPORT}

View File

@ -406,8 +406,8 @@ async def _continue_flow(
flow = next( flow = next(
( (
flow flow
for flow in hass.config_entries.flow.async_progress() for flow in hass.config_entries.flow.async_progress_by_handler(DOMAIN)
if flow["handler"] == DOMAIN and flow["context"]["unique_id"] == unique_id if flow["context"]["unique_id"] == unique_id
), ),
None, None,
) )

View File

@ -745,7 +745,9 @@ class DataManager:
flow = next( flow = next(
iter( iter(
flow flow
for flow in self._hass.config_entries.flow.async_progress() for flow in self._hass.config_entries.flow.async_progress_by_handler(
const.DOMAIN
)
if flow.context == context if flow.context == context
), ),
None, None,

View File

@ -120,8 +120,7 @@ class ZhaFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
# If they already have a discovery for deconz # If they already have a discovery for deconz
# we ignore the usb discovery as they probably # we ignore the usb discovery as they probably
# want to use it there instead # want to use it there instead
for flow in self.hass.config_entries.flow.async_progress(): if self.hass.config_entries.flow.async_progress_by_handler(DECONZ_DOMAIN):
if flow["handler"] == DECONZ_DOMAIN:
return self.async_abort(reason="not_zha_device") return self.async_abort(reason="not_zha_device")
for entry in self.hass.config_entries.async_entries(DECONZ_DOMAIN): for entry in self.hass.config_entries.async_entries(DECONZ_DOMAIN):
if entry.source != config_entries.SOURCE_IGNORE: if entry.source != config_entries.SOURCE_IGNORE:

View File

@ -586,7 +586,7 @@ class ConfigEntry:
"unique_id": self.unique_id, "unique_id": self.unique_id,
} }
for flow in hass.config_entries.flow.async_progress(): for flow in hass.config_entries.flow.async_progress_by_handler(self.domain):
if flow["context"] == flow_context: if flow["context"] == flow_context:
return return
@ -618,6 +618,14 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
self.config_entries = config_entries self.config_entries = config_entries
self._hass_config = hass_config self._hass_config = hass_config
@callback
def _async_has_other_discovery_flows(self, flow_id: str) -> bool:
"""Check if there are any other discovery flows in progress."""
return any(
flow.context["source"] in DISCOVERY_SOURCES and flow.flow_id != flow_id
for flow in self._progress.values()
)
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
@ -625,11 +633,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
flow = cast(ConfigFlow, flow) flow = cast(ConfigFlow, flow)
# Remove notification if no other discovery config entries in progress # Remove notification if no other discovery config entries in progress
if not any( if not self._async_has_other_discovery_flows(flow.flow_id):
ent["context"]["source"] in DISCOVERY_SOURCES
for ent in self.hass.config_entries.flow.async_progress()
if ent["flow_id"] != flow.flow_id
):
self.hass.components.persistent_notification.async_dismiss( self.hass.components.persistent_notification.async_dismiss(
DISCOVERY_NOTIFICATION_ID DISCOVERY_NOTIFICATION_ID
) )
@ -642,15 +646,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
# Abort all flows in progress with same unique ID # Abort all flows in progress with same unique ID
# or the default discovery ID # or the default discovery ID
for progress_flow in self.async_progress(): for progress_flow in self.async_progress_by_handler(flow.handler):
progress_unique_id = progress_flow["context"].get("unique_id") progress_unique_id = progress_flow["context"].get("unique_id")
if ( if progress_flow["flow_id"] != flow.flow_id and (
progress_flow["handler"] == flow.handler
and progress_flow["flow_id"] != flow.flow_id
and (
(flow.unique_id and progress_unique_id == flow.unique_id) (flow.unique_id and progress_unique_id == flow.unique_id)
or progress_unique_id == DEFAULT_DISCOVERY_UNIQUE_ID or progress_unique_id == DEFAULT_DISCOVERY_UNIQUE_ID
)
): ):
self.async_abort(progress_flow["flow_id"]) self.async_abort(progress_flow["flow_id"])
@ -837,7 +837,9 @@ class ConfigEntries:
# If the configuration entry is removed during reauth, it should # If the configuration entry is removed during reauth, it should
# abort any reauth flow that is active for the removed entry. # abort any reauth flow that is active for the removed entry.
for progress_flow in self.hass.config_entries.flow.async_progress(): for progress_flow in self.hass.config_entries.flow.async_progress_by_handler(
entry.domain
):
context = progress_flow.get("context") context = progress_flow.get("context")
if ( if (
context context
@ -1265,10 +1267,10 @@ class ConfigFlow(data_entry_flow.FlowHandler):
"""Return other in progress flows for current domain.""" """Return other in progress flows for current domain."""
return [ return [
flw flw
for flw in self.hass.config_entries.flow.async_progress( for flw in self.hass.config_entries.flow.async_progress_by_handler(
include_uninitialized=include_uninitialized self.handler, include_uninitialized=include_uninitialized
) )
if flw["handler"] == self.handler and flw["flow_id"] != self.flow_id if flw["flow_id"] != self.flow_id
] ]
async def async_step_ignore( async def async_step_ignore(
@ -1329,7 +1331,9 @@ class ConfigFlow(data_entry_flow.FlowHandler):
# Remove reauth notification if no reauth flows are in progress # Remove reauth notification if no reauth flows are in progress
if self.source == SOURCE_REAUTH and not any( if self.source == SOURCE_REAUTH and not any(
ent["context"]["source"] == SOURCE_REAUTH ent["context"]["source"] == SOURCE_REAUTH
for ent in self.hass.config_entries.flow.async_progress() for ent in self.hass.config_entries.flow.async_progress_by_handler(
self.handler
)
if ent["flow_id"] != self.flow_id if ent["flow_id"] != self.flow_id
): ):
self.hass.components.persistent_notification.async_dismiss( self.hass.components.persistent_notification.async_dismiss(

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import abc import abc
import asyncio import asyncio
from collections.abc import Mapping from collections.abc import Iterable, Mapping
from types import MappingProxyType from types import MappingProxyType
from typing import Any, TypedDict from typing import Any, TypedDict
import uuid import uuid
@ -78,6 +78,23 @@ class FlowResult(TypedDict, total=False):
options: Mapping[str, Any] options: Mapping[str, Any]
@callback
def _async_flow_handler_to_flow_result(
flows: Iterable[FlowHandler], include_uninitialized: bool
) -> list[FlowResult]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
return [
{
"flow_id": flow.flow_id,
"handler": flow.handler,
"context": flow.context,
"step_id": flow.cur_step["step_id"] if flow.cur_step else None,
}
for flow in flows
if include_uninitialized or flow.cur_step is not None
]
class FlowManager(abc.ABC): class FlowManager(abc.ABC):
"""Manage all the flows that are in progress.""" """Manage all the flows that are in progress."""
@ -89,7 +106,8 @@ class FlowManager(abc.ABC):
self.hass = hass self.hass = hass
self._initializing: dict[str, list[asyncio.Future]] = {} self._initializing: dict[str, list[asyncio.Future]] = {}
self._initialize_tasks: dict[str, list[asyncio.Task]] = {} self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
self._progress: dict[str, Any] = {} self._progress: dict[str, FlowHandler] = {}
self._handler_progress_index: dict[str, set[str]] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None: async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized.""" """Wait till all flows in progress are initialized."""
@ -127,24 +145,39 @@ class FlowManager(abc.ABC):
"""Check if an existing matching flow is in progress with the same handler, context, and data.""" """Check if an existing matching flow is in progress with the same handler, context, and data."""
return any( return any(
flow flow
for flow in self._progress.values() for flow in self._async_progress_by_handler(handler)
if flow.handler == handler if flow.context["source"] == context["source"] and flow.init_data == data
and flow.context["source"] == context["source"]
and flow.init_data == data
) )
@callback
def async_get(self, flow_id: str) -> FlowResult | None:
"""Return a flow in progress as a partial FlowResult."""
if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow
return _async_flow_handler_to_flow_result([flow], False)[0]
@callback @callback
def async_progress(self, include_uninitialized: bool = False) -> list[FlowResult]: def async_progress(self, include_uninitialized: bool = False) -> list[FlowResult]:
"""Return the flows in progress.""" """Return the flows in progress as a partial FlowResult."""
return _async_flow_handler_to_flow_result(
self._progress.values(), include_uninitialized
)
@callback
def async_progress_by_handler(
self, handler: str, include_uninitialized: bool = False
) -> list[FlowResult]:
"""Return the flows in progress by handler as a partial FlowResult."""
return _async_flow_handler_to_flow_result(
self._async_progress_by_handler(handler), include_uninitialized
)
@callback
def _async_progress_by_handler(self, handler: str) -> list[FlowHandler]:
"""Return the flows in progress by handler."""
return [ return [
{ self._progress[flow_id]
"flow_id": flow.flow_id, for flow_id in self._handler_progress_index.get(handler, {})
"handler": flow.handler,
"context": flow.context,
"step_id": flow.cur_step["step_id"] if flow.cur_step else None,
}
for flow in self._progress.values()
if include_uninitialized or flow.cur_step is not None
] ]
async def async_init( async def async_init(
@ -187,7 +220,7 @@ class FlowManager(abc.ABC):
flow.flow_id = uuid.uuid4().hex flow.flow_id = uuid.uuid4().hex
flow.context = context flow.context = context
flow.init_data = data flow.init_data = data
self._progress[flow.flow_id] = flow self._async_add_flow_progress(flow)
result = await self._async_handle_step(flow, flow.init_step, data, init_done) result = await self._async_handle_step(flow, flow.init_step, data, init_done)
return flow, result return flow, result
@ -205,6 +238,7 @@ class FlowManager(abc.ABC):
raise UnknownFlow raise UnknownFlow
cur_step = flow.cur_step cur_step = flow.cur_step
assert cur_step is not None
if cur_step.get("data_schema") is not None and user_input is not None: if cur_step.get("data_schema") is not None and user_input is not None:
user_input = cur_step["data_schema"](user_input) user_input = cur_step["data_schema"](user_input)
@ -245,8 +279,24 @@ class FlowManager(abc.ABC):
@callback @callback
def async_abort(self, flow_id: str) -> None: def async_abort(self, flow_id: str) -> None:
"""Abort a flow.""" """Abort a flow."""
if self._progress.pop(flow_id, None) is None: self._async_remove_flow_progress(flow_id)
@callback
def _async_add_flow_progress(self, flow: FlowHandler) -> None:
"""Add a flow to in progress."""
self._progress[flow.flow_id] = flow
self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id)
@callback
def _async_remove_flow_progress(self, flow_id: str) -> None:
"""Remove a flow from in progress."""
flow = self._progress.pop(flow_id, None)
if flow is None:
raise UnknownFlow raise UnknownFlow
handler = flow.handler
self._handler_progress_index[handler].remove(flow.flow_id)
if not self._handler_progress_index[handler]:
del self._handler_progress_index[handler]
async def _async_handle_step( async def _async_handle_step(
self, self,
@ -259,7 +309,7 @@ class FlowManager(abc.ABC):
method = f"async_step_{step_id}" method = f"async_step_{step_id}"
if not hasattr(flow, method): if not hasattr(flow, method):
self._progress.pop(flow.flow_id) self._async_remove_flow_progress(flow.flow_id)
if step_done: if step_done:
step_done.set_result(None) step_done.set_result(None)
raise UnknownStep( raise UnknownStep(
@ -310,7 +360,7 @@ class FlowManager(abc.ABC):
return result return result
# Abort and Success results both finish the flow # Abort and Success results both finish the flow
self._progress.pop(flow.flow_id) self._async_remove_flow_progress(flow.flow_id)
return result return result
@ -319,7 +369,7 @@ class FlowHandler:
"""Handle the configuration flow of a component.""" """Handle the configuration flow of a component."""
# Set by flow manager # Set by flow manager
cur_step: dict[str, str] | None = None cur_step: dict[str, Any] | None = None
# While not purely typed, it makes typehinting more useful for us # While not purely typed, it makes typehinting more useful for us
# and removes the need for constant None checks or asserts. # and removes the need for constant None checks or asserts.

View File

@ -114,3 +114,43 @@ async def test_login_exist_user(hass, aiohttp_client):
step = await resp.json() step = await resp.json()
assert step["type"] == "create_entry" assert step["type"] == "create_entry"
assert len(step["result"]) > 1 assert len(step["result"]) > 1
async def test_login_exist_user_ip_changes(hass, aiohttp_client):
"""Test logging in and the ip address changes results in an rejection."""
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
cred = await hass.auth.auth_providers[0].async_get_or_create_credentials(
{"username": "test-user"}
)
await hass.auth.async_get_or_create_user(cred)
resp = await client.post(
"/auth/login_flow",
json={
"client_id": CLIENT_ID,
"handler": ["insecure_example", None],
"redirect_uri": CLIENT_REDIRECT_URI,
},
)
assert resp.status == 200
step = await resp.json()
#
# Here we modify the ip_address in the context to make sure
# when ip address changes in the middle of the login flow we prevent logins.
#
# This method was chosen because it seemed less likely to break
# vs patching aiohttp internals to fake the ip address
#
for flow_id, flow in hass.auth.login_flow._progress.items():
assert flow_id == step["flow_id"]
flow.context["ip_address"] = "10.2.3.1"
resp = await client.post(
f"/auth/login_flow/{step['flow_id']}",
json={"client_id": CLIENT_ID, "username": "test-user", "password": "test-pass"},
)
assert resp.status == 400
response = await resp.json()
assert response == {"message": "IP address changed"}

View File

@ -349,7 +349,7 @@ async def test_remove_entry_cancels_reauth(hass, manager):
await entry.async_setup(hass) await entry.async_setup(hass)
await hass.async_block_till_done() await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress() flows = hass.config_entries.flow.async_progress_by_handler("test")
assert len(flows) == 1 assert len(flows) == 1
assert flows[0]["context"]["entry_id"] == entry.entry_id assert flows[0]["context"]["entry_id"] == entry.entry_id
assert flows[0]["context"]["source"] == config_entries.SOURCE_REAUTH assert flows[0]["context"]["source"] == config_entries.SOURCE_REAUTH
@ -357,7 +357,7 @@ async def test_remove_entry_cancels_reauth(hass, manager):
await manager.async_remove(entry.entry_id) await manager.async_remove(entry.entry_id)
flows = hass.config_entries.flow.async_progress() flows = hass.config_entries.flow.async_progress_by_handler("test")
assert len(flows) == 0 assert len(flows) == 0
@ -2100,11 +2100,11 @@ async def test_unignore_step_form(hass, manager):
# Right after removal there shouldn't be an entry or active flows # Right after removal there shouldn't be an entry or active flows
assert len(hass.config_entries.async_entries("comp")) == 0 assert len(hass.config_entries.async_entries("comp")) == 0
assert len(hass.config_entries.flow.async_progress()) == 0 assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 0
# But after a 'tick' the unignore step has run and we can see an active flow again. # But after a 'tick' the unignore step has run and we can see an active flow again.
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(hass.config_entries.flow.async_progress()) == 1 assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 1
# and still not config entries # and still not config entries
assert len(hass.config_entries.async_entries("comp")) == 0 assert len(hass.config_entries.async_entries("comp")) == 0
@ -2144,7 +2144,7 @@ async def test_unignore_create_entry(hass, manager):
await manager.async_remove(entry.entry_id) await manager.async_remove(entry.entry_id)
# Right after removal there shouldn't be an entry or flow # Right after removal there shouldn't be an entry or flow
assert len(hass.config_entries.flow.async_progress()) == 0 assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 0
assert len(hass.config_entries.async_entries("comp")) == 0 assert len(hass.config_entries.async_entries("comp")) == 0
# But after a 'tick' the unignore step has run and we can see a config entry. # But after a 'tick' the unignore step has run and we can see a config entry.
@ -2155,7 +2155,7 @@ async def test_unignore_create_entry(hass, manager):
assert entry.title == "yo" assert entry.title == "yo"
# And still no active flow # And still no active flow
assert len(hass.config_entries.flow.async_progress()) == 0 assert len(hass.config_entries.flow.async_progress_by_handler("comp")) == 0
async def test_unignore_default_impl(hass, manager): async def test_unignore_default_impl(hass, manager):

View File

@ -271,6 +271,8 @@ async def test_external_step(hass, manager):
result = await manager.async_init("test") result = await manager.async_init("test")
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert len(manager.async_progress()) == 1 assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Mimic external step # Mimic external step
# Called by integrations: `hass.config_entries.flow.async_configure(…)` # Called by integrations: `hass.config_entries.flow.async_configure(…)`
@ -327,6 +329,8 @@ async def test_show_progress(hass, manager):
assert result["type"] == data_entry_flow.RESULT_TYPE_SHOW_PROGRESS assert result["type"] == data_entry_flow.RESULT_TYPE_SHOW_PROGRESS
assert result["progress_action"] == "task_one" assert result["progress_action"] == "task_one"
assert len(manager.async_progress()) == 1 assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Mimic task one done and moving to task two # Mimic task one done and moving to task two
# Called by integrations: `hass.config_entries.flow.async_configure(…)` # Called by integrations: `hass.config_entries.flow.async_configure(…)`
@ -400,6 +404,13 @@ async def test_init_unknown_flow(manager):
await manager.async_init("test") await manager.async_init("test")
async def test_async_get_unknown_flow(manager):
"""Test that UnknownFlow is raised when async_get is called with a flow_id that does not exist."""
with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_get("does_not_exist")
async def test_async_has_matching_flow( async def test_async_has_matching_flow(
hass: HomeAssistant, manager: data_entry_flow.FlowManager hass: HomeAssistant, manager: data_entry_flow.FlowManager
): ):
@ -424,6 +435,8 @@ async def test_async_has_matching_flow(
assert result["type"] == data_entry_flow.RESULT_TYPE_SHOW_PROGRESS assert result["type"] == data_entry_flow.RESULT_TYPE_SHOW_PROGRESS
assert result["progress_action"] == "task_one" assert result["progress_action"] == "task_one"
assert len(manager.async_progress()) == 1 assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
assert ( assert (
manager.async_has_matching_flow( manager.async_has_matching_flow(
@ -449,3 +462,28 @@ async def test_async_has_matching_flow(
) )
is False is False
) )
async def test_move_to_unknown_step_raises_and_removes_from_in_progress(manager):
"""Test that moving to an unknown step raises and removes the flow from in progress."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 1
with pytest.raises(data_entry_flow.UnknownStep):
await manager.async_init("test", context={"init_step": "does_not_exist"})
assert manager.async_progress() == []
async def test_configure_raises_unknown_flow_if_not_in_progress(manager):
"""Test configure raises UnknownFlow if the flow is not in progress."""
with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_configure("wrong_flow_id")
async def test_abort_raises_unknown_flow_if_not_in_progress(manager):
"""Test abort raises UnknownFlow if the flow is not in progress."""
with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_abort("wrong_flow_id")