mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Improve calls to async_show_progress in google (#107788)
This commit is contained in:
parent
fbb6c1d0f0
commit
ddf3a36061
@ -1,6 +1,7 @@
|
|||||||
"""Config flow for Google integration."""
|
"""Config flow for Google integration."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -68,6 +69,8 @@ class OAuth2FlowHandler(
|
|||||||
|
|
||||||
DOMAIN = DOMAIN
|
DOMAIN = DOMAIN
|
||||||
|
|
||||||
|
_exchange_finished_task: asyncio.Task[bool] | None = None
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Set up instance."""
|
"""Set up instance."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -115,7 +118,7 @@ class OAuth2FlowHandler(
|
|||||||
if self._web_auth:
|
if self._web_auth:
|
||||||
return await super().async_step_auth(user_input)
|
return await super().async_step_auth(user_input)
|
||||||
|
|
||||||
if user_input is not None:
|
if self._exchange_finished_task and self._exchange_finished_task.done():
|
||||||
return self.async_show_progress_done(next_step_id="creation")
|
return self.async_show_progress_done(next_step_id="creation")
|
||||||
|
|
||||||
if not self._device_flow:
|
if not self._device_flow:
|
||||||
@ -150,15 +153,16 @@ class OAuth2FlowHandler(
|
|||||||
return self.async_abort(reason="oauth_error")
|
return self.async_abort(reason="oauth_error")
|
||||||
self._device_flow = device_flow
|
self._device_flow = device_flow
|
||||||
|
|
||||||
|
exchange_finished_evt = asyncio.Event()
|
||||||
|
self._exchange_finished_task = self.hass.async_create_task(
|
||||||
|
exchange_finished_evt.wait()
|
||||||
|
)
|
||||||
|
|
||||||
def _exchange_finished() -> None:
|
def _exchange_finished() -> None:
|
||||||
self.external_data = {
|
self.external_data = {
|
||||||
DEVICE_AUTH_CREDS: device_flow.creds
|
DEVICE_AUTH_CREDS: device_flow.creds
|
||||||
} # is None on timeout/expiration
|
} # is None on timeout/expiration
|
||||||
self.hass.async_create_task(
|
exchange_finished_evt.set()
|
||||||
self.hass.config_entries.flow.async_configure(
|
|
||||||
flow_id=self.flow_id, user_input={}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
device_flow.async_set_listener(_exchange_finished)
|
device_flow.async_set_listener(_exchange_finished)
|
||||||
device_flow.async_start_exchange()
|
device_flow.async_start_exchange()
|
||||||
@ -170,6 +174,7 @@ class OAuth2FlowHandler(
|
|||||||
"user_code": self._device_flow.user_code,
|
"user_code": self._device_flow.user_code,
|
||||||
},
|
},
|
||||||
progress_action="exchange",
|
progress_action="exchange",
|
||||||
|
progress_task=self._exchange_finished_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_step_creation(
|
async def async_step_creation(
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
import datetime
|
import datetime
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
@ -10,6 +11,7 @@ from unittest.mock import Mock, patch
|
|||||||
|
|
||||||
from aiohttp.client_exceptions import ClientError
|
from aiohttp.client_exceptions import ClientError
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
from freezegun.api import FrozenDateTimeFactory
|
||||||
from oauth2client.client import (
|
from oauth2client.client import (
|
||||||
DeviceFlowInfo,
|
DeviceFlowInfo,
|
||||||
FlowExchangeError,
|
FlowExchangeError,
|
||||||
@ -273,6 +275,7 @@ async def test_exchange_error(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_code_flow: Mock,
|
mock_code_flow: Mock,
|
||||||
mock_exchange: Mock,
|
mock_exchange: Mock,
|
||||||
|
freezer: FrozenDateTimeFactory,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test an error while exchanging the code for credentials."""
|
"""Test an error while exchanging the code for credentials."""
|
||||||
await async_import_client_credential(
|
await async_import_client_credential(
|
||||||
@ -290,14 +293,19 @@ async def test_exchange_error(
|
|||||||
assert "url" in result["description_placeholders"]
|
assert "url" in result["description_placeholders"]
|
||||||
|
|
||||||
# Run one tick to invoke the credential exchange check
|
# Run one tick to invoke the credential exchange check
|
||||||
now = utcnow()
|
step2_exchange_called = asyncio.Event()
|
||||||
|
|
||||||
|
def step2_exchange(*args, **kwargs):
|
||||||
|
hass.loop.call_soon_threadsafe(step2_exchange_called.set)
|
||||||
|
raise FlowExchangeError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange",
|
"homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange",
|
||||||
side_effect=FlowExchangeError(),
|
side_effect=step2_exchange,
|
||||||
):
|
):
|
||||||
now += CODE_CHECK_ALARM_TIMEDELTA
|
freezer.tick(CODE_CHECK_ALARM_TIMEDELTA)
|
||||||
await fire_alarm(hass, now)
|
async_fire_time_changed(hass, utcnow())
|
||||||
await hass.async_block_till_done()
|
await step2_exchange_called.wait()
|
||||||
|
|
||||||
# Status has not updated, will retry
|
# Status has not updated, will retry
|
||||||
result = await hass.config_entries.flow.async_configure(flow_id=result["flow_id"])
|
result = await hass.config_entries.flow.async_configure(flow_id=result["flow_id"])
|
||||||
@ -308,8 +316,8 @@ async def test_exchange_error(
|
|||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.google.async_setup_entry", return_value=True
|
"homeassistant.components.google.async_setup_entry", return_value=True
|
||||||
) as mock_setup:
|
) as mock_setup:
|
||||||
now += CODE_CHECK_ALARM_TIMEDELTA
|
freezer.tick(CODE_CHECK_ALARM_TIMEDELTA)
|
||||||
await fire_alarm(hass, now)
|
async_fire_time_changed(hass, utcnow())
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
result = await hass.config_entries.flow.async_configure(
|
result = await hass.config_entries.flow.async_configure(
|
||||||
flow_id=result["flow_id"]
|
flow_id=result["flow_id"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user