Improve calls to async_show_progress in octoprint (#107792)

This commit is contained in:
Erik Montnemery 2024-01-14 11:06:35 +01:00 committed by GitHub
parent 1c9764bc44
commit 7fc3f8e473
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 34 deletions

View File

@ -53,12 +53,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1
api_key_task: asyncio.Task[None] | None = None
discovery_schema: vol.Schema | None = None
_reauth_data: dict[str, Any] | None = None
_user_input: dict[str, Any] | None = None
def __init__(self) -> None:
"""Handle a config flow for OctoPrint."""
self.discovery_schema = None
self._user_input = None
self._sessions: list[aiohttp.ClientSession] = []
async def async_step_user(self, user_input=None):
@ -97,17 +97,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
),
)
self.api_key_task = None
return await self.async_step_get_api_key(user_input)
self._user_input = user_input
return await self.async_step_get_api_key()
async def async_step_get_api_key(self, user_input):
async def async_step_get_api_key(self, user_input=None):
"""Get an Application Api Key."""
if not self.api_key_task:
self.api_key_task = self.hass.async_create_task(
self._async_get_auth_key(user_input)
)
self.api_key_task = self.hass.async_create_task(self._async_get_auth_key())
if not self.api_key_task.done():
return self.async_show_progress(
step_id="get_api_key", progress_action="get_api_key"
step_id="get_api_key",
progress_action="get_api_key",
progress_task=self.api_key_task,
)
try:
@ -118,9 +119,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Failed to get an application key : %s", err)
return self.async_show_progress_done(next_step_id="auth_failed")
finally:
self.api_key_task = None
# store this off here to pick back up in the user step
self._user_input = user_input
return self.async_show_progress_done(next_step_id="user")
async def _finish_config(self, user_input: dict):
@ -238,26 +239,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
),
)
self.api_key_task = None
self._reauth_data[CONF_USERNAME] = user_input[CONF_USERNAME]
return await self.async_step_get_api_key(self._reauth_data)
self._user_input = self._reauth_data
return await self.async_step_get_api_key()
async def _async_get_auth_key(self, user_input: dict):
async def _async_get_auth_key(self):
"""Get application api key."""
octoprint = self._get_octoprint_client(user_input)
octoprint = self._get_octoprint_client(self._user_input)
try:
user_input[CONF_API_KEY] = await octoprint.request_app_key(
"Home Assistant", user_input[CONF_USERNAME], 300
)
finally:
# Continue the flow after show progress when the task is done.
self.hass.async_create_task(
self.hass.config_entries.flow.async_configure(
flow_id=self.flow_id, user_input=user_input
)
)
self._user_input[CONF_API_KEY] = await octoprint.request_app_key(
"Home Assistant", self._user_input[CONF_USERNAME], 300
)
def _get_octoprint_client(self, user_input: dict) -> OctoprintClient:
"""Build an octoprint client from the user_input."""

View File

@ -95,8 +95,9 @@ async def test_form_cannot_connect(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
)
await hass.async_block_till_done()
assert result["type"] == "progress"
assert result["type"] == "progress_done"
with patch(
"pyoctoprintapi.OctoprintClient.get_discovery_info",
side_effect=ApiError,
@ -144,8 +145,9 @@ async def test_form_unknown_exception(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
)
await hass.async_block_till_done()
assert result["type"] == "progress"
assert result["type"] == "progress_done"
with patch(
"pyoctoprintapi.OctoprintClient.get_discovery_info",
side_effect=Exception,
@ -203,7 +205,7 @@ async def test_show_zerconf_form(hass: HomeAssistant) -> None:
)
await hass.async_block_till_done()
assert result["type"] == "progress_done"
assert result["type"] == "progress"
with patch(
"pyoctoprintapi.OctoprintClient.get_server_info",
@ -269,7 +271,7 @@ async def test_show_ssdp_form(hass: HomeAssistant) -> None:
)
await hass.async_block_till_done()
assert result["type"] == "progress_done"
assert result["type"] == "progress"
with patch(
"pyoctoprintapi.OctoprintClient.get_server_info",
@ -390,10 +392,11 @@ async def test_failed_auth(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
)
await hass.async_block_till_done()
assert result["type"] == "progress"
assert result["type"] == "progress_done"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == "abort"
assert result["reason"] == "auth_failed"
@ -421,10 +424,11 @@ async def test_failed_auth_unexpected_error(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
)
await hass.async_block_till_done()
assert result["type"] == "progress"
assert result["type"] == "progress_done"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == "abort"
assert result["reason"] == "auth_failed"