mirror of
https://github.com/home-assistant/core.git
synced 2025-07-20 19:57:07 +00:00
Improve calls to async_show_progress in google (#107788)
This commit is contained in:
parent
fbb6c1d0f0
commit
ddf3a36061
@ -1,6 +1,7 @@
|
||||
"""Config flow for Google integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
from typing import Any
|
||||
@ -68,6 +69,8 @@ class OAuth2FlowHandler(
|
||||
|
||||
DOMAIN = DOMAIN
|
||||
|
||||
_exchange_finished_task: asyncio.Task[bool] | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Set up instance."""
|
||||
super().__init__()
|
||||
@ -115,7 +118,7 @@ class OAuth2FlowHandler(
|
||||
if self._web_auth:
|
||||
return await super().async_step_auth(user_input)
|
||||
|
||||
if user_input is not None:
|
||||
if self._exchange_finished_task and self._exchange_finished_task.done():
|
||||
return self.async_show_progress_done(next_step_id="creation")
|
||||
|
||||
if not self._device_flow:
|
||||
@ -150,15 +153,16 @@ class OAuth2FlowHandler(
|
||||
return self.async_abort(reason="oauth_error")
|
||||
self._device_flow = device_flow
|
||||
|
||||
exchange_finished_evt = asyncio.Event()
|
||||
self._exchange_finished_task = self.hass.async_create_task(
|
||||
exchange_finished_evt.wait()
|
||||
)
|
||||
|
||||
def _exchange_finished() -> None:
|
||||
self.external_data = {
|
||||
DEVICE_AUTH_CREDS: device_flow.creds
|
||||
} # is None on timeout/expiration
|
||||
self.hass.async_create_task(
|
||||
self.hass.config_entries.flow.async_configure(
|
||||
flow_id=self.flow_id, user_input={}
|
||||
)
|
||||
)
|
||||
exchange_finished_evt.set()
|
||||
|
||||
device_flow.async_set_listener(_exchange_finished)
|
||||
device_flow.async_start_exchange()
|
||||
@ -170,6 +174,7 @@ class OAuth2FlowHandler(
|
||||
"user_code": self._device_flow.user_code,
|
||||
},
|
||||
progress_action="exchange",
|
||||
progress_task=self._exchange_finished_task,
|
||||
)
|
||||
|
||||
async def async_step_creation(
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import datetime
|
||||
from http import HTTPStatus
|
||||
@ -10,6 +11,7 @@ from unittest.mock import Mock, patch
|
||||
|
||||
from aiohttp.client_exceptions import ClientError
|
||||
from freezegun import freeze_time
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
from oauth2client.client import (
|
||||
DeviceFlowInfo,
|
||||
FlowExchangeError,
|
||||
@ -273,6 +275,7 @@ async def test_exchange_error(
|
||||
hass: HomeAssistant,
|
||||
mock_code_flow: Mock,
|
||||
mock_exchange: Mock,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
) -> None:
|
||||
"""Test an error while exchanging the code for credentials."""
|
||||
await async_import_client_credential(
|
||||
@ -290,14 +293,19 @@ async def test_exchange_error(
|
||||
assert "url" in result["description_placeholders"]
|
||||
|
||||
# Run one tick to invoke the credential exchange check
|
||||
now = utcnow()
|
||||
step2_exchange_called = asyncio.Event()
|
||||
|
||||
def step2_exchange(*args, **kwargs):
|
||||
hass.loop.call_soon_threadsafe(step2_exchange_called.set)
|
||||
raise FlowExchangeError
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange",
|
||||
side_effect=FlowExchangeError(),
|
||||
side_effect=step2_exchange,
|
||||
):
|
||||
now += CODE_CHECK_ALARM_TIMEDELTA
|
||||
await fire_alarm(hass, now)
|
||||
await hass.async_block_till_done()
|
||||
freezer.tick(CODE_CHECK_ALARM_TIMEDELTA)
|
||||
async_fire_time_changed(hass, utcnow())
|
||||
await step2_exchange_called.wait()
|
||||
|
||||
# Status has not updated, will retry
|
||||
result = await hass.config_entries.flow.async_configure(flow_id=result["flow_id"])
|
||||
@ -308,8 +316,8 @@ async def test_exchange_error(
|
||||
with patch(
|
||||
"homeassistant.components.google.async_setup_entry", return_value=True
|
||||
) as mock_setup:
|
||||
now += CODE_CHECK_ALARM_TIMEDELTA
|
||||
await fire_alarm(hass, now)
|
||||
freezer.tick(CODE_CHECK_ALARM_TIMEDELTA)
|
||||
async_fire_time_changed(hass, utcnow())
|
||||
await hass.async_block_till_done()
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
flow_id=result["flow_id"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user