Improve calls to async_show_progress in octoprint (#107792)

This commit is contained in:
Erik Montnemery 2024-01-14 11:06:35 +01:00 committed by GitHub
parent 1c9764bc44
commit 7fc3f8e473
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 34 deletions

View File

@ -53,12 +53,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
api_key_task: asyncio.Task[None] | None = None api_key_task: asyncio.Task[None] | None = None
discovery_schema: vol.Schema | None = None
_reauth_data: dict[str, Any] | None = None _reauth_data: dict[str, Any] | None = None
_user_input: dict[str, Any] | None = None
def __init__(self) -> None: def __init__(self) -> None:
"""Handle a config flow for OctoPrint.""" """Handle a config flow for OctoPrint."""
self.discovery_schema = None
self._user_input = None
self._sessions: list[aiohttp.ClientSession] = [] self._sessions: list[aiohttp.ClientSession] = []
async def async_step_user(self, user_input=None): async def async_step_user(self, user_input=None):
@ -97,17 +97,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
), ),
) )
self.api_key_task = None self._user_input = user_input
return await self.async_step_get_api_key(user_input) return await self.async_step_get_api_key()
async def async_step_get_api_key(self, user_input): async def async_step_get_api_key(self, user_input=None):
"""Get an Application Api Key.""" """Get an Application Api Key."""
if not self.api_key_task: if not self.api_key_task:
self.api_key_task = self.hass.async_create_task( self.api_key_task = self.hass.async_create_task(self._async_get_auth_key())
self._async_get_auth_key(user_input) if not self.api_key_task.done():
)
return self.async_show_progress( return self.async_show_progress(
step_id="get_api_key", progress_action="get_api_key" step_id="get_api_key",
progress_action="get_api_key",
progress_task=self.api_key_task,
) )
try: try:
@ -118,9 +119,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Failed to get an application key : %s", err) _LOGGER.exception("Failed to get an application key : %s", err)
return self.async_show_progress_done(next_step_id="auth_failed") return self.async_show_progress_done(next_step_id="auth_failed")
finally:
self.api_key_task = None
# store this off here to pick back up in the user step
self._user_input = user_input
return self.async_show_progress_done(next_step_id="user") return self.async_show_progress_done(next_step_id="user")
async def _finish_config(self, user_input: dict): async def _finish_config(self, user_input: dict):
@ -238,25 +239,17 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
), ),
) )
self.api_key_task = None
self._reauth_data[CONF_USERNAME] = user_input[CONF_USERNAME] self._reauth_data[CONF_USERNAME] = user_input[CONF_USERNAME]
return await self.async_step_get_api_key(self._reauth_data) self._user_input = self._reauth_data
return await self.async_step_get_api_key()
async def _async_get_auth_key(self, user_input: dict): async def _async_get_auth_key(self):
"""Get application api key.""" """Get application api key."""
octoprint = self._get_octoprint_client(user_input) octoprint = self._get_octoprint_client(self._user_input)
try: self._user_input[CONF_API_KEY] = await octoprint.request_app_key(
user_input[CONF_API_KEY] = await octoprint.request_app_key( "Home Assistant", self._user_input[CONF_USERNAME], 300
"Home Assistant", user_input[CONF_USERNAME], 300
)
finally:
# Continue the flow after show progress when the task is done.
self.hass.async_create_task(
self.hass.config_entries.flow.async_configure(
flow_id=self.flow_id, user_input=user_input
)
) )
def _get_octoprint_client(self, user_input: dict) -> OctoprintClient: def _get_octoprint_client(self, user_input: dict) -> OctoprintClient:

View File

@ -95,8 +95,9 @@ async def test_form_cannot_connect(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
) )
await hass.async_block_till_done()
assert result["type"] == "progress"
assert result["type"] == "progress_done"
with patch( with patch(
"pyoctoprintapi.OctoprintClient.get_discovery_info", "pyoctoprintapi.OctoprintClient.get_discovery_info",
side_effect=ApiError, side_effect=ApiError,
@ -144,8 +145,9 @@ async def test_form_unknown_exception(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
) )
await hass.async_block_till_done()
assert result["type"] == "progress"
assert result["type"] == "progress_done"
with patch( with patch(
"pyoctoprintapi.OctoprintClient.get_discovery_info", "pyoctoprintapi.OctoprintClient.get_discovery_info",
side_effect=Exception, side_effect=Exception,
@ -203,7 +205,7 @@ async def test_show_zerconf_form(hass: HomeAssistant) -> None:
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result["type"] == "progress_done" assert result["type"] == "progress"
with patch( with patch(
"pyoctoprintapi.OctoprintClient.get_server_info", "pyoctoprintapi.OctoprintClient.get_server_info",
@ -269,7 +271,7 @@ async def test_show_ssdp_form(hass: HomeAssistant) -> None:
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result["type"] == "progress_done" assert result["type"] == "progress"
with patch( with patch(
"pyoctoprintapi.OctoprintClient.get_server_info", "pyoctoprintapi.OctoprintClient.get_server_info",
@ -390,10 +392,11 @@ async def test_failed_auth(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
) )
await hass.async_block_till_done()
assert result["type"] == "progress"
assert result["type"] == "progress_done"
result = await hass.config_entries.flow.async_configure(result["flow_id"]) result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == "abort" assert result["type"] == "abort"
assert result["reason"] == "auth_failed" assert result["reason"] == "auth_failed"
@ -421,10 +424,11 @@ async def test_failed_auth_unexpected_error(hass: HomeAssistant) -> None:
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
) )
await hass.async_block_till_done()
assert result["type"] == "progress"
assert result["type"] == "progress_done"
result = await hass.config_entries.flow.async_configure(result["flow_id"]) result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == "abort" assert result["type"] == "abort"
assert result["reason"] == "auth_failed" assert result["reason"] == "auth_failed"