mirror of
https://github.com/home-assistant/core.git
synced 2025-07-13 08:17:08 +00:00
Fix unexpected exception in Google Calendar OAuth exchange (#73963)
This commit is contained in:
parent
1d185388a9
commit
02d1676301
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
@ -19,9 +18,12 @@ from oauth2client.client import (
|
|||||||
|
|
||||||
from homeassistant.components.application_credentials import AuthImplementation
|
from homeassistant.components.application_credentials import AuthImplementation
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||||
from homeassistant.helpers import config_entry_oauth2_flow
|
from homeassistant.helpers import config_entry_oauth2_flow
|
||||||
from homeassistant.helpers.event import async_track_time_interval
|
from homeassistant.helpers.event import (
|
||||||
|
async_track_point_in_utc_time,
|
||||||
|
async_track_time_interval,
|
||||||
|
)
|
||||||
from homeassistant.util import dt
|
from homeassistant.util import dt
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
@ -76,6 +78,9 @@ class DeviceFlow:
|
|||||||
self._oauth_flow = oauth_flow
|
self._oauth_flow = oauth_flow
|
||||||
self._device_flow_info: DeviceFlowInfo = device_flow_info
|
self._device_flow_info: DeviceFlowInfo = device_flow_info
|
||||||
self._exchange_task_unsub: CALLBACK_TYPE | None = None
|
self._exchange_task_unsub: CALLBACK_TYPE | None = None
|
||||||
|
self._timeout_unsub: CALLBACK_TYPE | None = None
|
||||||
|
self._listener: CALLBACK_TYPE | None = None
|
||||||
|
self._creds: Credentials | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def verification_url(self) -> str:
|
def verification_url(self) -> str:
|
||||||
@ -87,15 +92,22 @@ class DeviceFlow:
|
|||||||
"""Return the code that the user should enter at the verification url."""
|
"""Return the code that the user should enter at the verification url."""
|
||||||
return self._device_flow_info.user_code # type: ignore[no-any-return]
|
return self._device_flow_info.user_code # type: ignore[no-any-return]
|
||||||
|
|
||||||
async def start_exchange_task(
|
@callback
|
||||||
self, finished_cb: Callable[[Credentials | None], Awaitable[None]]
|
def async_set_listener(
|
||||||
|
self,
|
||||||
|
update_callback: CALLBACK_TYPE,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start the device auth exchange flow polling.
|
"""Invoke the update callback when the exchange finishes or on timeout."""
|
||||||
|
self._listener = update_callback
|
||||||
|
|
||||||
The callback is invoked with the valid credentials or with None on timeout.
|
@property
|
||||||
"""
|
def creds(self) -> Credentials | None:
|
||||||
|
"""Return result of exchange step or None on timeout."""
|
||||||
|
return self._creds
|
||||||
|
|
||||||
|
def async_start_exchange(self) -> None:
|
||||||
|
"""Start the device auth exchange flow polling."""
|
||||||
_LOGGER.debug("Starting exchange flow")
|
_LOGGER.debug("Starting exchange flow")
|
||||||
assert not self._exchange_task_unsub
|
|
||||||
max_timeout = dt.utcnow() + datetime.timedelta(seconds=EXCHANGE_TIMEOUT_SECONDS)
|
max_timeout = dt.utcnow() + datetime.timedelta(seconds=EXCHANGE_TIMEOUT_SECONDS)
|
||||||
# For some reason, oauth.step1_get_device_and_user_codes() returns a datetime
|
# For some reason, oauth.step1_get_device_and_user_codes() returns a datetime
|
||||||
# object without tzinfo. For the comparison below to work, it needs one.
|
# object without tzinfo. For the comparison below to work, it needs one.
|
||||||
@ -104,31 +116,40 @@ class DeviceFlow:
|
|||||||
)
|
)
|
||||||
expiration_time = min(user_code_expiry, max_timeout)
|
expiration_time = min(user_code_expiry, max_timeout)
|
||||||
|
|
||||||
def _exchange() -> Credentials:
|
|
||||||
return self._oauth_flow.step2_exchange(
|
|
||||||
device_flow_info=self._device_flow_info
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _poll_attempt(now: datetime.datetime) -> None:
|
|
||||||
assert self._exchange_task_unsub
|
|
||||||
_LOGGER.debug("Attempting OAuth code exchange")
|
|
||||||
# Note: The callback is invoked with None when the device code has expired
|
|
||||||
creds: Credentials | None = None
|
|
||||||
if now < expiration_time:
|
|
||||||
try:
|
|
||||||
creds = await self._hass.async_add_executor_job(_exchange)
|
|
||||||
except FlowExchangeError:
|
|
||||||
_LOGGER.debug("Token not yet ready; trying again later")
|
|
||||||
return
|
|
||||||
self._exchange_task_unsub()
|
|
||||||
self._exchange_task_unsub = None
|
|
||||||
await finished_cb(creds)
|
|
||||||
|
|
||||||
self._exchange_task_unsub = async_track_time_interval(
|
self._exchange_task_unsub = async_track_time_interval(
|
||||||
self._hass,
|
self._hass,
|
||||||
_poll_attempt,
|
self._async_poll_attempt,
|
||||||
datetime.timedelta(seconds=self._device_flow_info.interval),
|
datetime.timedelta(seconds=self._device_flow_info.interval),
|
||||||
)
|
)
|
||||||
|
self._timeout_unsub = async_track_point_in_utc_time(
|
||||||
|
self._hass, self._async_timeout, expiration_time
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _async_poll_attempt(self, now: datetime.datetime) -> None:
|
||||||
|
_LOGGER.debug("Attempting OAuth code exchange")
|
||||||
|
try:
|
||||||
|
self._creds = await self._hass.async_add_executor_job(self._exchange)
|
||||||
|
except FlowExchangeError:
|
||||||
|
_LOGGER.debug("Token not yet ready; trying again later")
|
||||||
|
return
|
||||||
|
self._finish()
|
||||||
|
|
||||||
|
def _exchange(self) -> Credentials:
|
||||||
|
return self._oauth_flow.step2_exchange(device_flow_info=self._device_flow_info)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_timeout(self, now: datetime.datetime) -> None:
|
||||||
|
_LOGGER.debug("OAuth token exchange timeout")
|
||||||
|
self._finish()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _finish(self) -> None:
|
||||||
|
if self._exchange_task_unsub:
|
||||||
|
self._exchange_task_unsub()
|
||||||
|
if self._timeout_unsub:
|
||||||
|
self._timeout_unsub()
|
||||||
|
if self._listener:
|
||||||
|
self._listener()
|
||||||
|
|
||||||
|
|
||||||
def get_feature_access(
|
def get_feature_access(
|
||||||
|
@ -7,7 +7,6 @@ from typing import Any
|
|||||||
|
|
||||||
from gcal_sync.api import GoogleCalendarService
|
from gcal_sync.api import GoogleCalendarService
|
||||||
from gcal_sync.exceptions import ApiException
|
from gcal_sync.exceptions import ApiException
|
||||||
from oauth2client.client import Credentials
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
@ -97,9 +96,9 @@ class OAuth2FlowHandler(
|
|||||||
return self.async_abort(reason="oauth_error")
|
return self.async_abort(reason="oauth_error")
|
||||||
self._device_flow = device_flow
|
self._device_flow = device_flow
|
||||||
|
|
||||||
async def _exchange_finished(creds: Credentials | None) -> None:
|
def _exchange_finished() -> None:
|
||||||
self.external_data = {
|
self.external_data = {
|
||||||
DEVICE_AUTH_CREDS: creds
|
DEVICE_AUTH_CREDS: device_flow.creds
|
||||||
} # is None on timeout/expiration
|
} # is None on timeout/expiration
|
||||||
self.hass.async_create_task(
|
self.hass.async_create_task(
|
||||||
self.hass.config_entries.flow.async_configure(
|
self.hass.config_entries.flow.async_configure(
|
||||||
@ -107,7 +106,8 @@ class OAuth2FlowHandler(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await device_flow.start_exchange_task(_exchange_finished)
|
device_flow.async_set_listener(_exchange_finished)
|
||||||
|
device_flow.async_start_exchange()
|
||||||
|
|
||||||
return self.async_show_progress(
|
return self.async_show_progress(
|
||||||
step_id="auth",
|
step_id="auth",
|
||||||
|
@ -10,6 +10,7 @@ from unittest.mock import Mock, patch
|
|||||||
from aiohttp.client_exceptions import ClientError
|
from aiohttp.client_exceptions import ClientError
|
||||||
from freezegun.api import FrozenDateTimeFactory
|
from freezegun.api import FrozenDateTimeFactory
|
||||||
from oauth2client.client import (
|
from oauth2client.client import (
|
||||||
|
DeviceFlowInfo,
|
||||||
FlowExchangeError,
|
FlowExchangeError,
|
||||||
OAuth2Credentials,
|
OAuth2Credentials,
|
||||||
OAuth2DeviceCodeError,
|
OAuth2DeviceCodeError,
|
||||||
@ -59,10 +60,17 @@ async def mock_code_flow(
|
|||||||
) -> YieldFixture[Mock]:
|
) -> YieldFixture[Mock]:
|
||||||
"""Fixture for initiating OAuth flow."""
|
"""Fixture for initiating OAuth flow."""
|
||||||
with patch(
|
with patch(
|
||||||
"oauth2client.client.OAuth2WebServerFlow.step1_get_device_and_user_codes",
|
"homeassistant.components.google.api.OAuth2WebServerFlow.step1_get_device_and_user_codes",
|
||||||
) as mock_flow:
|
) as mock_flow:
|
||||||
mock_flow.return_value.user_code_expiry = utcnow() + code_expiration_delta
|
mock_flow.return_value = DeviceFlowInfo.FromResponse(
|
||||||
mock_flow.return_value.interval = CODE_CHECK_INTERVAL
|
{
|
||||||
|
"device_code": "4/4-GMMhmHCXhWEzkobqIHGG_EnNYYsAkukHspeYUk9E8",
|
||||||
|
"user_code": "GQVQ-JKEC",
|
||||||
|
"verification_url": "https://www.google.com/device",
|
||||||
|
"expires_in": code_expiration_delta.total_seconds(),
|
||||||
|
"interval": CODE_CHECK_INTERVAL,
|
||||||
|
}
|
||||||
|
)
|
||||||
yield mock_flow
|
yield mock_flow
|
||||||
|
|
||||||
|
|
||||||
@ -70,7 +78,8 @@ async def mock_code_flow(
|
|||||||
async def mock_exchange(creds: OAuth2Credentials) -> YieldFixture[Mock]:
|
async def mock_exchange(creds: OAuth2Credentials) -> YieldFixture[Mock]:
|
||||||
"""Fixture for mocking out the exchange for credentials."""
|
"""Fixture for mocking out the exchange for credentials."""
|
||||||
with patch(
|
with patch(
|
||||||
"oauth2client.client.OAuth2WebServerFlow.step2_exchange", return_value=creds
|
"homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange",
|
||||||
|
return_value=creds,
|
||||||
) as mock:
|
) as mock:
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
@ -108,7 +117,6 @@ async def fire_alarm(hass, point_in_time):
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.freeze_time("2022-06-03 15:19:59-00:00")
|
|
||||||
async def test_full_flow_yaml_creds(
|
async def test_full_flow_yaml_creds(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_code_flow: Mock,
|
mock_code_flow: Mock,
|
||||||
@ -131,9 +139,8 @@ async def test_full_flow_yaml_creds(
|
|||||||
"homeassistant.components.google.async_setup_entry", return_value=True
|
"homeassistant.components.google.async_setup_entry", return_value=True
|
||||||
) as mock_setup:
|
) as mock_setup:
|
||||||
# Run one tick to invoke the credential exchange check
|
# Run one tick to invoke the credential exchange check
|
||||||
freezer.tick(CODE_CHECK_ALARM_TIMEDELTA)
|
now = utcnow()
|
||||||
await fire_alarm(hass, datetime.datetime.utcnow())
|
await fire_alarm(hass, now + CODE_CHECK_ALARM_TIMEDELTA)
|
||||||
await hass.async_block_till_done()
|
|
||||||
result = await hass.config_entries.flow.async_configure(
|
result = await hass.config_entries.flow.async_configure(
|
||||||
flow_id=result["flow_id"]
|
flow_id=result["flow_id"]
|
||||||
)
|
)
|
||||||
@ -143,11 +150,12 @@ async def test_full_flow_yaml_creds(
|
|||||||
assert "data" in result
|
assert "data" in result
|
||||||
data = result["data"]
|
data = result["data"]
|
||||||
assert "token" in data
|
assert "token" in data
|
||||||
|
assert 0 < data["token"]["expires_in"] <= 60 * 60
|
||||||
assert (
|
assert (
|
||||||
data["token"]["expires_in"]
|
datetime.datetime.now().timestamp()
|
||||||
== 60 * 60 - CODE_CHECK_ALARM_TIMEDELTA.total_seconds()
|
<= data["token"]["expires_at"]
|
||||||
|
< (datetime.datetime.now() + datetime.timedelta(days=8)).timestamp()
|
||||||
)
|
)
|
||||||
assert data["token"]["expires_at"] == 1654273199.0
|
|
||||||
data["token"].pop("expires_at")
|
data["token"].pop("expires_at")
|
||||||
data["token"].pop("expires_in")
|
data["token"].pop("expires_in")
|
||||||
assert data == {
|
assert data == {
|
||||||
@ -238,7 +246,7 @@ async def test_code_error(
|
|||||||
assert await component_setup()
|
assert await component_setup()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"oauth2client.client.OAuth2WebServerFlow.step1_get_device_and_user_codes",
|
"homeassistant.components.google.api.OAuth2WebServerFlow.step1_get_device_and_user_codes",
|
||||||
side_effect=OAuth2DeviceCodeError("Test Failure"),
|
side_effect=OAuth2DeviceCodeError("Test Failure"),
|
||||||
):
|
):
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
@ -248,13 +256,13 @@ async def test_code_error(
|
|||||||
assert result.get("reason") == "oauth_error"
|
assert result.get("reason") == "oauth_error"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("code_expiration_delta", [datetime.timedelta(minutes=-5)])
|
@pytest.mark.parametrize("code_expiration_delta", [datetime.timedelta(seconds=50)])
|
||||||
async def test_expired_after_exchange(
|
async def test_expired_after_exchange(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_code_flow: Mock,
|
mock_code_flow: Mock,
|
||||||
component_setup: ComponentSetup,
|
component_setup: ComponentSetup,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test successful creds setup."""
|
"""Test credential exchange expires."""
|
||||||
assert await component_setup()
|
assert await component_setup()
|
||||||
|
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
@ -265,10 +273,14 @@ async def test_expired_after_exchange(
|
|||||||
assert "description_placeholders" in result
|
assert "description_placeholders" in result
|
||||||
assert "url" in result["description_placeholders"]
|
assert "url" in result["description_placeholders"]
|
||||||
|
|
||||||
# Run one tick to invoke the credential exchange check
|
# Fail first attempt then advance clock past exchange timeout
|
||||||
now = utcnow()
|
with patch(
|
||||||
await fire_alarm(hass, now + CODE_CHECK_ALARM_TIMEDELTA)
|
"homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange",
|
||||||
await hass.async_block_till_done()
|
side_effect=FlowExchangeError(),
|
||||||
|
):
|
||||||
|
now = utcnow()
|
||||||
|
await fire_alarm(hass, now + datetime.timedelta(seconds=65))
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
result = await hass.config_entries.flow.async_configure(flow_id=result["flow_id"])
|
result = await hass.config_entries.flow.async_configure(flow_id=result["flow_id"])
|
||||||
assert result.get("type") == "abort"
|
assert result.get("type") == "abort"
|
||||||
@ -295,7 +307,7 @@ async def test_exchange_error(
|
|||||||
# Run one tick to invoke the credential exchange check
|
# Run one tick to invoke the credential exchange check
|
||||||
now = utcnow()
|
now = utcnow()
|
||||||
with patch(
|
with patch(
|
||||||
"oauth2client.client.OAuth2WebServerFlow.step2_exchange",
|
"homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange",
|
||||||
side_effect=FlowExchangeError(),
|
side_effect=FlowExchangeError(),
|
||||||
):
|
):
|
||||||
now += CODE_CHECK_ALARM_TIMEDELTA
|
now += CODE_CHECK_ALARM_TIMEDELTA
|
||||||
|
Loading…
x
Reference in New Issue
Block a user