Improve calls to async_show_progress in improv_ble (#107790)

This commit is contained in:
Erik Montnemery 2024-01-14 09:37:54 +01:00 committed by GitHub
parent 93d363ea57
commit d4cb055d75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 35 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable, Coroutine from collections.abc import Callable, Coroutine
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
from typing import Any, TypeVar from typing import Any, TypeVar
@ -325,14 +325,15 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return return
if not self._provision_task: if not self._provision_task:
self._provision_task = self.hass.async_create_task( self._provision_task = self.hass.async_create_task(_do_provision())
self._resume_flow_when_done(_do_provision())
) if not self._provision_task.done():
return self.async_show_progress( return self.async_show_progress(
step_id="do_provision", progress_action="provisioning" step_id="do_provision",
progress_action="provisioning",
progress_task=self._provision_task,
) )
await self._provision_task
self._provision_task = None self._provision_task = None
return self.async_show_progress_done(next_step_id="provision_done") return self.async_show_progress_done(next_step_id="provision_done")
@ -347,14 +348,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self._provision_result = None self._provision_result = None
return result return result
async def _resume_flow_when_done(self, awaitable: Awaitable) -> None:
try:
await awaitable
finally:
self.hass.async_create_task(
self.hass.config_entries.flow.async_configure(flow_id=self.flow_id)
)
async def async_step_authorize( async def async_step_authorize(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
@ -378,14 +371,15 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
except AbortFlow as err: except AbortFlow as err:
return self.async_abort(reason=err.reason) return self.async_abort(reason=err.reason)
self._authorize_task = self.hass.async_create_task( self._authorize_task = self.hass.async_create_task(authorized_event.wait())
self._resume_flow_when_done(authorized_event.wait())
) if not self._authorize_task.done():
return self.async_show_progress( return self.async_show_progress(
step_id="authorize", progress_action="authorize" step_id="authorize",
progress_action="authorize",
progress_task=self._authorize_task,
) )
await self._authorize_task
self._authorize_task = None self._authorize_task = None
if self._unsub: if self._unsub:
self._unsub() self._unsub()

View File

@ -265,10 +265,7 @@ async def _test_common_success(
assert result["type"] == FlowResultType.SHOW_PROGRESS assert result["type"] == FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "provisioning" assert result["progress_action"] == "provisioning"
assert result["step_id"] == "do_provision" assert result["step_id"] == "do_provision"
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE
assert result["step_id"] == "provision_done"
result = await hass.config_entries.flow.async_configure(result["flow_id"]) result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result.get("description_placeholders") == placeholders assert result.get("description_placeholders") == placeholders
@ -321,10 +318,7 @@ async def _test_common_success_w_authorize(
assert result["progress_action"] == "authorize" assert result["progress_action"] == "authorize"
assert result["step_id"] == "authorize" assert result["step_id"] == "authorize"
mock_subscribe_state_updates.assert_awaited_once() mock_subscribe_state_updates.assert_awaited_once()
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE
assert result["step_id"] == "provision"
with patch( with patch(
f"{IMPROV_BLE}.config_flow.ImprovBLEClient.need_authorization", f"{IMPROV_BLE}.config_flow.ImprovBLEClient.need_authorization",
@ -337,10 +331,7 @@ async def _test_common_success_w_authorize(
assert result["type"] == FlowResultType.SHOW_PROGRESS assert result["type"] == FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "provisioning" assert result["progress_action"] == "provisioning"
assert result["step_id"] == "do_provision" assert result["step_id"] == "do_provision"
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE
assert result["step_id"] == "provision_done"
result = await hass.config_entries.flow.async_configure(result["flow_id"]) result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["description_placeholders"] == {"url": "http://blabla.local"} assert result["description_placeholders"] == {"url": "http://blabla.local"}
@ -578,10 +569,7 @@ async def _test_provision_error(hass: HomeAssistant, exc) -> None:
assert result["type"] == FlowResultType.SHOW_PROGRESS assert result["type"] == FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "provisioning" assert result["progress_action"] == "provisioning"
assert result["step_id"] == "do_provision" assert result["step_id"] == "do_provision"
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE
assert result["step_id"] == "provision_done"
return result["flow_id"] return result["flow_id"]