Compare commits

...

11 Commits

Author SHA1 Message Date
Martin Hjelmare
ee3be39b21 Fix whitespace 2025-11-04 23:22:54 +01:00
Martin Hjelmare
3c4b82c93a Add comment explaining check for storing result 2025-11-04 23:21:03 +01:00
Martin Hjelmare
59d9eff803 Improve docstring of progress step data property 2025-11-04 23:20:33 +01:00
Martin Hjelmare
a3e15f9412 Remove progress step result reset 2025-11-04 23:12:29 +01:00
Martin Hjelmare
3af200c7f8 Fix first task done, second task show progress 2025-10-29 16:27:17 +01:00
Martin Hjelmare
c11c2856be Fix mutable object as default 2025-10-28 15:10:02 +01:00
Martin Hjelmare
939464836a test_progress_step_done_abort 2025-10-28 14:54:15 +01:00
Martin Hjelmare
ae72e93ce9 Add test parameter and case comments 2025-10-28 14:54:15 +01:00
Martin Hjelmare
cb54b1e4a7 Fix data_entry_flow recursion 2025-10-28 14:54:15 +01:00
Martin Hjelmare
866b1b5406 Test chaining progress steps 2025-10-28 14:54:15 +01:00
Martin Hjelmare
7e5293a699 Test progress step 2025-10-28 14:54:14 +01:00
2 changed files with 457 additions and 34 deletions

View File

@@ -645,12 +645,24 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
__progress_task: asyncio.Task[Any] | None = None
__no_progress_task_reported = False
deprecated_show_progress = False
_progress_step_data: ProgressStepData[_FlowResultT] = {
"tasks": {},
"abort_reason": "",
"abort_description_placeholders": MappingProxyType({}),
"next_step_result": None,
}
__progress_step_data: ProgressStepData[_FlowResultT] | None = None
@property
def _progress_step_data(self) -> ProgressStepData[_FlowResultT]:
"""Return progress step data.
A property is used instead of a simple attribute as derived classes
do not call super().__init__.
The property makes sure that the dict is initialized if needed.
"""
if not self.__progress_step_data:
self.__progress_step_data = {
"tasks": {},
"abort_reason": "",
"abort_description_placeholders": MappingProxyType({}),
"next_step_result": None,
}
return self.__progress_step_data
@property
def source(self) -> str | None:
@@ -777,9 +789,10 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
self, user_input: dict[str, Any] | None = None
) -> _FlowResultT:
"""Abort the flow."""
progress_step_data = self._progress_step_data
return self.async_abort(
reason=self._progress_step_data["abort_reason"],
description_placeholders=self._progress_step_data[
reason=progress_step_data["abort_reason"],
description_placeholders=progress_step_data[
"abort_description_placeholders"
],
)
@@ -795,14 +808,15 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
without using async_show_progress_done.
If no next step is set, abort the flow.
"""
if self._progress_step_data["next_step_result"] is None:
progress_step_data = self._progress_step_data
if (next_step_result := progress_step_data["next_step_result"]) is None:
return self.async_abort(
reason=self._progress_step_data["abort_reason"],
description_placeholders=self._progress_step_data[
reason=progress_step_data["abort_reason"],
description_placeholders=progress_step_data[
"abort_description_placeholders"
],
)
return self._progress_step_data["next_step_result"]
return next_step_result
@callback
def async_external_step(
@@ -1021,9 +1035,9 @@ def progress_step[
self: FlowHandler[Any, ResultT], *args: P.args, **kwargs: P.kwargs
) -> ResultT:
step_id = func.__name__.replace("async_step_", "")
progress_step_data = self._progress_step_data
# Check if we have a progress task running
progress_task = self._progress_step_data["tasks"].get(step_id)
progress_task = progress_step_data["tasks"].get(step_id)
if progress_task is None:
# First call - create and start the progress task
@@ -1031,30 +1045,30 @@ def progress_step[
func(self, *args, **kwargs), # type: ignore[arg-type]
f"Progress step {step_id}",
)
self._progress_step_data["tasks"][step_id] = progress_task
progress_step_data["tasks"][step_id] = progress_task
if not progress_task.done():
# Handle description placeholders
placeholders = None
if description_placeholders is not None:
if callable(description_placeholders):
placeholders = description_placeholders(self)
else:
placeholders = description_placeholders
if not progress_task.done():
# Handle description placeholders
placeholders = None
if description_placeholders is not None:
if callable(description_placeholders):
placeholders = description_placeholders(self)
else:
placeholders = description_placeholders
return self.async_show_progress(
step_id=step_id,
progress_action=step_id,
progress_task=progress_task,
description_placeholders=placeholders,
)
return self.async_show_progress(
step_id=step_id,
progress_action=step_id,
progress_task=progress_task,
description_placeholders=placeholders,
)
# Task is done or this is a subsequent call
try:
self._progress_step_data["next_step_result"] = await progress_task
progress_task_result = await progress_task
except AbortFlow as err:
self._progress_step_data["abort_reason"] = err.reason
self._progress_step_data["abort_description_placeholders"] = (
progress_step_data["abort_reason"] = err.reason
progress_step_data["abort_description_placeholders"] = (
err.description_placeholders or {}
)
return self.async_show_progress_done(
@@ -1062,7 +1076,14 @@ def progress_step[
)
finally:
# Clean up task reference
self._progress_step_data["tasks"].pop(step_id, None)
progress_step_data["tasks"].pop(step_id, None)
# If the result type is FlowResultType.SHOW_PROGRESS_DONE
# an earlier show progress step has already been run and stored its result.
# In this case we should not overwrite the result,
# but just use the stored one.
if progress_task_result["type"] != FlowResultType.SHOW_PROGRESS_DONE:
progress_step_data["next_step_result"] = progress_task_result
return self.async_show_progress_done(
next_step_id="_progress_step_progress_done"

View File

@@ -1,9 +1,11 @@
"""Test the flow classes."""
import asyncio
from collections.abc import Callable
import dataclasses
import logging
from unittest.mock import Mock, patch
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
import pytest
import voluptuous as vol
@@ -930,6 +932,406 @@ async def test_show_progress_fires_only_when_changed(
) # change (description placeholder)
@pytest.mark.parametrize(
("task_side_effect", "flow_result"),
[
(None, data_entry_flow.FlowResultType.CREATE_ENTRY),
(data_entry_flow.AbortFlow("fail"), data_entry_flow.FlowResultType.ABORT),
],
)
@pytest.mark.parametrize(
("description", "expected_description"),
[
(None, None),
({"title": "World"}, {"title": "World"}),
(lambda x: {"title": "World"}, {"title": "World"}),
],
)
async def test_progress_step(
hass: HomeAssistant,
manager: MockFlowManager,
description: Callable[[data_entry_flow.FlowHandler], dict[str, Any]]
| dict[str, Any]
| None,
expected_description: dict[str, Any] | None,
task_side_effect: Exception | None,
flow_result: data_entry_flow.FlowResultType,
) -> None:
"""Test progress_step decorator."""
manager.hass = hass
events = []
task_init_evt = asyncio.Event()
event_received_evt = asyncio.Event()
task_result = Mock()
task_result.side_effect = task_side_effect
@callback
def capture_events(event: Event) -> None:
events.append(event)
event_received_evt.set()
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
@data_entry_flow.progress_step(description_placeholders=description)
async def async_step_init(self, user_input=None):
await task_init_evt.wait()
task_result()
return await self.async_step_finish()
async def async_step_finish(self, user_input=None):
return self.async_create_entry(data={})
hass.bus.async_listen(
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
capture_events,
)
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "init"
description_placeholders = result["description_placeholders"]
assert description_placeholders == expected_description
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Set task one done and wait for event
task_init_evt.set()
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == 1
assert events[0].data == {
"handler": "test",
"flow_id": result["flow_id"],
"refresh": True,
}
# Frontend refreshes the flow
result = await manager.async_configure(result["flow_id"])
assert result["type"] == flow_result
@pytest.mark.parametrize(
(
"task_init_side_effect", # side effect for initial step task
"task_next_side_effect", # side effect for next step task
"flow_result_before_init", # result before init task is done
"flow_result_after_init", # result after init task is done
"flow_result_after_next", # result after next task is done
"flow_init_events", # number of events fired after init task is done
"flow_next_events", # number of events fired after next task is done
"manager_call_after_init", # lambda to continue the flow after init task
"manager_call_after_next", # lambda to continue the flow after next task
"before_init_task_side_effect", # function called before init event
"before_next_task_side_effect", # function called before next event
),
[
( # both steps show progress and complete successfully
None,
None,
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.CREATE_ENTRY,
1,
2,
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda received_event, init_task_event, next_task_event: None,
lambda received_event, init_task_event, next_task_event: None,
),
( # first step aborts
data_entry_flow.AbortFlow("fail"),
None,
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.ABORT,
data_entry_flow.FlowResultType.ABORT,
1,
1,
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda manager, result: AsyncMock(return_value=result)(),
lambda received_event, init_task_event, next_task_event: None,
lambda received_event, init_task_event, next_task_event: None,
),
( # first step shows progress, second step aborts
None,
data_entry_flow.AbortFlow("fail"),
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.ABORT,
1,
2,
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda received_event, init_task_event, next_task_event: None,
lambda received_event, init_task_event, next_task_event: None,
),
( # first step shows progress and second step task is already done
None,
None,
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.CREATE_ENTRY,
data_entry_flow.FlowResultType.CREATE_ENTRY,
1,
1,
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda manager, result: AsyncMock(return_value=result)(),
lambda received_event,
init_task_event,
next_task_event: next_task_event.set(),
lambda received_event, init_task_event, next_task_event: None,
),
( # both step tasks are already done and flow completes immediately
None,
None,
data_entry_flow.FlowResultType.SHOW_PROGRESS_DONE,
data_entry_flow.FlowResultType.CREATE_ENTRY,
data_entry_flow.FlowResultType.CREATE_ENTRY,
0,
0,
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda manager, result: AsyncMock(return_value=result)(),
lambda received_event,
init_task_event,
next_task_event: received_event.set()
or init_task_event.set()
or next_task_event.set(),
lambda received_event,
init_task_event,
next_task_event: received_event.set(),
),
( # first step task is already done, second step shows progress and completes
None,
None,
data_entry_flow.FlowResultType.SHOW_PROGRESS_DONE,
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.CREATE_ENTRY,
0,
1,
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda received_event,
init_task_event,
next_task_event: received_event.set() or init_task_event.set(),
lambda received_event, init_task_event, next_task_event: None,
),
],
)
async def test_chaining_progress_steps(
hass: HomeAssistant,
manager: MockFlowManager,
task_init_side_effect: Exception | None,
task_next_side_effect: Exception | None,
flow_result_before_init: data_entry_flow.FlowResultType,
flow_result_after_init: data_entry_flow.FlowResultType,
flow_result_after_next: data_entry_flow.FlowResultType,
flow_init_events: int,
flow_next_events: int,
manager_call_after_init: Callable[
[MockFlowManager, data_entry_flow.FlowResult], Any
],
manager_call_after_next: Callable[
[MockFlowManager, data_entry_flow.FlowResult], Any
],
before_init_task_side_effect: Callable[
[asyncio.Event, asyncio.Event, asyncio.Event], None
],
before_next_task_side_effect: Callable[
[asyncio.Event, asyncio.Event, asyncio.Event], None
],
) -> None:
"""Test chaining two steps with progress_step decorators."""
manager.hass = hass
events = []
event_received_evt = asyncio.Event()
task_init_evt = asyncio.Event()
task_next_evt = asyncio.Event()
task_init_result = Mock()
task_init_result.side_effect = task_init_side_effect
task_next_result = Mock()
task_next_result.side_effect = task_next_side_effect
@callback
def capture_events(event: Event) -> None:
events.append(event)
event_received_evt.set()
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
def async_remove(self) -> None:
# Disable event received event to allow test to finish if flow is aborted.
event_received_evt.set()
@data_entry_flow.progress_step()
async def async_step_init(self, user_input=None):
await task_init_evt.wait()
task_init_result()
return await self.async_step_next()
@data_entry_flow.progress_step()
async def async_step_next(self, user_input=None):
await task_next_evt.wait()
task_next_result()
return await self.async_step_finish()
async def async_step_finish(self, user_input=None):
return self.async_create_entry(data={})
hass.bus.async_listen(
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
capture_events,
)
# Run side effect before first event is awaited
before_init_task_side_effect(event_received_evt, task_init_evt, task_next_evt)
result = await manager.async_init("test")
assert result["type"] == flow_result_before_init
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Set task init done and wait for event
task_init_evt.set()
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == flow_init_events
# Run side effect before second event is awaited
before_next_task_side_effect(event_received_evt, task_init_evt, task_next_evt)
# Continue the flow if needed.
result = await manager_call_after_init(manager, result)
assert result["type"] == flow_result_after_init
# Set task next done and wait for event
task_next_evt.set()
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == flow_next_events
# Continue the flow if needed.
result = await manager_call_after_next(manager, result)
assert result["type"] == flow_result_after_next
async def test_progress_step_done_abort(
hass: HomeAssistant,
manager: MockFlowManager,
) -> None:
"""Test progress_step decorator without done result set."""
manager.hass = hass
events = []
@callback
def capture_events(event: Event) -> None:
events.append(event)
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
@data_entry_flow.progress_step()
async def async_step_init(self, user_input=None):
# async_show_progress_done
return data_entry_flow.FlowResult(
flow_id=self.flow_id,
handler=self.handler,
type=data_entry_flow.FlowResultType.SHOW_PROGRESS_DONE,
)
hass.bus.async_listen(
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
capture_events,
)
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS_DONE
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
result = await manager.async_configure(result["flow_id"])
assert result["type"] == data_entry_flow.FlowResultType.ABORT
assert result["reason"] == ""
assert len(manager.async_progress()) == 0
assert len(manager.async_progress_by_handler("test")) == 0
assert not events
async def test_progress_step_result_reset(
hass: HomeAssistant,
manager: MockFlowManager,
) -> None:
"""Test progress_step decorator with reset result."""
manager.hass = hass
events = []
task_init_evt = asyncio.Event()
event_received_evt = asyncio.Event()
@callback
def capture_events(event: Event) -> None:
events.append(event)
event_received_evt.set()
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
@data_entry_flow.progress_step()
async def async_step_init(self, user_input=None):
await task_init_evt.wait()
return await self.async_step_finish()
async def async_step_finish(self, user_input=None):
if user_input is None:
return self.async_show_form(step_id="finish")
return self.async_create_entry(data={})
hass.bus.async_listen(
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
capture_events,
)
first_result = await manager.async_init("test")
assert first_result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert first_result["progress_action"] == "init"
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(first_result["flow_id"])["handler"] == "test"
# Set task one done and wait for event
task_init_evt.set()
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == 1
assert events[0].data == {
"handler": "test",
"flow_id": first_result["flow_id"],
"refresh": True,
}
# Frontend refreshes the flow
result = await manager.async_configure(first_result["flow_id"])
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "finish"
# Continue the flow again from the first result to test idempotency.
result = await manager.async_configure(first_result["flow_id"])
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "finish"
# Finish the flow
result = await manager.async_configure(first_result["flow_id"], {})
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
async def test_abort_flow_exception_step(manager: MockFlowManager) -> None:
"""Test that the AbortFlow exception works in a step."""