diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index aa9df89de5c..65cf7eb3d36 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -74,6 +74,10 @@ FLOW_NOT_COMPLETE_STEPS = { FlowResultType.MENU, } +STEP_ID_OPTIONAL_STEPS = { + FlowResultType.SHOW_PROGRESS, +} + @dataclass(slots=True) class BaseServiceInfo: @@ -458,6 +462,10 @@ class FlowManager(abc.ABC): elif result["type"] != FlowResultType.SHOW_PROGRESS: flow.async_cancel_progress_task() + if result["type"] in STEP_ID_OPTIONAL_STEPS: + if "step_id" not in result: + result["step_id"] = step_id + if result["type"] in FLOW_NOT_COMPLETE_STEPS: self._raise_if_step_does_not_exist(flow, result["step_id"]) flow.cur_step = result @@ -654,21 +662,23 @@ class FlowHandler: def async_show_progress( self, *, - step_id: str, + step_id: str | None = None, progress_action: str, description_placeholders: Mapping[str, str] | None = None, progress_task: asyncio.Task[Any] | None = None, ) -> FlowResult: """Show a progress message to the user, without user input allowed.""" - return FlowResult( + result = FlowResult( type=FlowResultType.SHOW_PROGRESS, flow_id=self.flow_id, handler=self.handler, - step_id=step_id, progress_action=progress_action, description_placeholders=description_placeholders, progress_task=progress_task, ) + if step_id is not None: + result["step_id"] = step_id + return result @callback def async_show_progress_done(self, *, next_step_id: str) -> FlowResult: diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index aedf3e40c15..cafeaaf3ba0 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -380,7 +380,6 @@ async def test_show_progress(hass: HomeAssistant, manager) -> None: self.start_task_two = False if not task_one_evt.is_set() or not task_two_evt.is_set(): return self.async_show_progress( - step_id="init", progress_action=progress_action, progress_task=self.progress_task, ) @@ -464,7 +463,7 @@ async def test_show_progress_error(hass: HomeAssistant, manager) -> None: return self.async_show_progress_done(next_step_id="error") return self.async_show_progress_done(next_step_id="no_error") return self.async_show_progress( - step_id="init", progress_action="task", progress_task=self.progress_task + progress_action="task", progress_task=self.progress_task ) async def async_step_error(self, user_input=None):