Improve calls to async_show_progress in google (#107788)

This commit is contained in:
Erik Montnemery 2024-01-11 16:47:53 +01:00 committed by GitHub
parent fbb6c1d0f0
commit ddf3a36061
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 13 deletions

View File

@ -1,6 +1,7 @@
"""Config flow for Google integration."""
from __future__ import annotations
import asyncio
from collections.abc import Mapping
import logging
from typing import Any
@ -68,6 +69,8 @@ class OAuth2FlowHandler(
DOMAIN = DOMAIN
_exchange_finished_task: asyncio.Task[bool] | None = None
def __init__(self) -> None:
"""Set up instance."""
super().__init__()
@ -115,7 +118,7 @@ class OAuth2FlowHandler(
if self._web_auth:
return await super().async_step_auth(user_input)
if user_input is not None:
if self._exchange_finished_task and self._exchange_finished_task.done():
return self.async_show_progress_done(next_step_id="creation")
if not self._device_flow:
@ -150,15 +153,16 @@ class OAuth2FlowHandler(
return self.async_abort(reason="oauth_error")
self._device_flow = device_flow
exchange_finished_evt = asyncio.Event()
self._exchange_finished_task = self.hass.async_create_task(
exchange_finished_evt.wait()
)
def _exchange_finished() -> None:
self.external_data = {
DEVICE_AUTH_CREDS: device_flow.creds
} # is None on timeout/expiration
self.hass.async_create_task(
self.hass.config_entries.flow.async_configure(
flow_id=self.flow_id, user_input={}
)
)
exchange_finished_evt.set()
device_flow.async_set_listener(_exchange_finished)
device_flow.async_start_exchange()
@ -170,6 +174,7 @@ class OAuth2FlowHandler(
"user_code": self._device_flow.user_code,
},
progress_action="exchange",
progress_task=self._exchange_finished_task,
)
async def async_step_creation(

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
import datetime
from http import HTTPStatus
@ -10,6 +11,7 @@ from unittest.mock import Mock, patch
from aiohttp.client_exceptions import ClientError
from freezegun import freeze_time
from freezegun.api import FrozenDateTimeFactory
from oauth2client.client import (
DeviceFlowInfo,
FlowExchangeError,
@ -273,6 +275,7 @@ async def test_exchange_error(
hass: HomeAssistant,
mock_code_flow: Mock,
mock_exchange: Mock,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test an error while exchanging the code for credentials."""
await async_import_client_credential(
@ -290,14 +293,19 @@ async def test_exchange_error(
assert "url" in result["description_placeholders"]
# Run one tick to invoke the credential exchange check
now = utcnow()
step2_exchange_called = asyncio.Event()
def step2_exchange(*args, **kwargs):
hass.loop.call_soon_threadsafe(step2_exchange_called.set)
raise FlowExchangeError
with patch(
"homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange",
side_effect=FlowExchangeError(),
side_effect=step2_exchange,
):
now += CODE_CHECK_ALARM_TIMEDELTA
await fire_alarm(hass, now)
await hass.async_block_till_done()
freezer.tick(CODE_CHECK_ALARM_TIMEDELTA)
async_fire_time_changed(hass, utcnow())
await step2_exchange_called.wait()
# Status has not updated, will retry
result = await hass.config_entries.flow.async_configure(flow_id=result["flow_id"])
@ -308,8 +316,8 @@ async def test_exchange_error(
with patch(
"homeassistant.components.google.async_setup_entry", return_value=True
) as mock_setup:
now += CODE_CHECK_ALARM_TIMEDELTA
await fire_alarm(hass, now)
freezer.tick(CODE_CHECK_ALARM_TIMEDELTA)
async_fire_time_changed(hass, utcnow())
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(
flow_id=result["flow_id"]