Improve calls to async_show_progress in google (#107788)

This commit is contained in:
Erik Montnemery 2024-01-11 16:47:53 +01:00 committed by GitHub
parent fbb6c1d0f0
commit ddf3a36061
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 13 deletions

View File

@ -1,6 +1,7 @@
"""Config flow for Google integration.""" """Config flow for Google integration."""
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Mapping from collections.abc import Mapping
import logging import logging
from typing import Any from typing import Any
@ -68,6 +69,8 @@ class OAuth2FlowHandler(
DOMAIN = DOMAIN DOMAIN = DOMAIN
_exchange_finished_task: asyncio.Task[bool] | None = None
def __init__(self) -> None: def __init__(self) -> None:
"""Set up instance.""" """Set up instance."""
super().__init__() super().__init__()
@ -115,7 +118,7 @@ class OAuth2FlowHandler(
if self._web_auth: if self._web_auth:
return await super().async_step_auth(user_input) return await super().async_step_auth(user_input)
if user_input is not None: if self._exchange_finished_task and self._exchange_finished_task.done():
return self.async_show_progress_done(next_step_id="creation") return self.async_show_progress_done(next_step_id="creation")
if not self._device_flow: if not self._device_flow:
@ -150,15 +153,16 @@ class OAuth2FlowHandler(
return self.async_abort(reason="oauth_error") return self.async_abort(reason="oauth_error")
self._device_flow = device_flow self._device_flow = device_flow
exchange_finished_evt = asyncio.Event()
self._exchange_finished_task = self.hass.async_create_task(
exchange_finished_evt.wait()
)
def _exchange_finished() -> None: def _exchange_finished() -> None:
self.external_data = { self.external_data = {
DEVICE_AUTH_CREDS: device_flow.creds DEVICE_AUTH_CREDS: device_flow.creds
} # is None on timeout/expiration } # is None on timeout/expiration
self.hass.async_create_task( exchange_finished_evt.set()
self.hass.config_entries.flow.async_configure(
flow_id=self.flow_id, user_input={}
)
)
device_flow.async_set_listener(_exchange_finished) device_flow.async_set_listener(_exchange_finished)
device_flow.async_start_exchange() device_flow.async_start_exchange()
@ -170,6 +174,7 @@ class OAuth2FlowHandler(
"user_code": self._device_flow.user_code, "user_code": self._device_flow.user_code,
}, },
progress_action="exchange", progress_action="exchange",
progress_task=self._exchange_finished_task,
) )
async def async_step_creation( async def async_step_creation(

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Callable from collections.abc import Callable
import datetime import datetime
from http import HTTPStatus from http import HTTPStatus
@ -10,6 +11,7 @@ from unittest.mock import Mock, patch
from aiohttp.client_exceptions import ClientError from aiohttp.client_exceptions import ClientError
from freezegun import freeze_time from freezegun import freeze_time
from freezegun.api import FrozenDateTimeFactory
from oauth2client.client import ( from oauth2client.client import (
DeviceFlowInfo, DeviceFlowInfo,
FlowExchangeError, FlowExchangeError,
@ -273,6 +275,7 @@ async def test_exchange_error(
hass: HomeAssistant, hass: HomeAssistant,
mock_code_flow: Mock, mock_code_flow: Mock,
mock_exchange: Mock, mock_exchange: Mock,
freezer: FrozenDateTimeFactory,
) -> None: ) -> None:
"""Test an error while exchanging the code for credentials.""" """Test an error while exchanging the code for credentials."""
await async_import_client_credential( await async_import_client_credential(
@ -290,14 +293,19 @@ async def test_exchange_error(
assert "url" in result["description_placeholders"] assert "url" in result["description_placeholders"]
# Run one tick to invoke the credential exchange check # Run one tick to invoke the credential exchange check
now = utcnow() step2_exchange_called = asyncio.Event()
def step2_exchange(*args, **kwargs):
hass.loop.call_soon_threadsafe(step2_exchange_called.set)
raise FlowExchangeError
with patch( with patch(
"homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange", "homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange",
side_effect=FlowExchangeError(), side_effect=step2_exchange,
): ):
now += CODE_CHECK_ALARM_TIMEDELTA freezer.tick(CODE_CHECK_ALARM_TIMEDELTA)
await fire_alarm(hass, now) async_fire_time_changed(hass, utcnow())
await hass.async_block_till_done() await step2_exchange_called.wait()
# Status has not updated, will retry # Status has not updated, will retry
result = await hass.config_entries.flow.async_configure(flow_id=result["flow_id"]) result = await hass.config_entries.flow.async_configure(flow_id=result["flow_id"])
@ -308,8 +316,8 @@ async def test_exchange_error(
with patch( with patch(
"homeassistant.components.google.async_setup_entry", return_value=True "homeassistant.components.google.async_setup_entry", return_value=True
) as mock_setup: ) as mock_setup:
now += CODE_CHECK_ALARM_TIMEDELTA freezer.tick(CODE_CHECK_ALARM_TIMEDELTA)
await fire_alarm(hass, now) async_fire_time_changed(hass, utcnow())
await hass.async_block_till_done() await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
flow_id=result["flow_id"] flow_id=result["flow_id"]