From b1f3068b41dc06583ac14371277f85f29bf3a736 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 9 Feb 2025 09:31:18 -0800 Subject: [PATCH] Refresh the nest authentication token on integration start before invoking the pub/sub subsciber (#138003) * Refresh the nest authentication token on integration start before invoking the pub/sub subscriber * Apply suggestions from code review --------- Co-authored-by: Paulus Schoutsen --- homeassistant/components/nest/__init__.py | 11 +++- homeassistant/components/nest/api.py | 18 ++++-- tests/components/nest/test_api.py | 77 ----------------------- 3 files changed, 22 insertions(+), 84 deletions(-) diff --git a/homeassistant/components/nest/__init__.py b/homeassistant/components/nest/__init__.py index 8adc0e4f714..67c14bbf544 100644 --- a/homeassistant/components/nest/__init__.py +++ b/homeassistant/components/nest/__init__.py @@ -198,7 +198,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: NestConfigEntry) -> bool entry, unique_id=entry.data[CONF_PROJECT_ID] ) - subscriber = await api.new_subscriber(hass, entry) + auth = await api.new_auth(hass, entry) + try: + await auth.async_get_access_token() + except AuthException as err: + raise ConfigEntryAuthFailed(f"Authentication error: {err!s}") from err + except ConfigurationException as err: + _LOGGER.error("Configuration error: %s", err) + return False + + subscriber = await api.new_subscriber(hass, entry, auth) if not subscriber: return False # Keep media for last N events in memory diff --git a/homeassistant/components/nest/api.py b/homeassistant/components/nest/api.py index e86e326b1c2..727b126dda4 100644 --- a/homeassistant/components/nest/api.py +++ b/homeassistant/components/nest/api.py @@ -101,9 +101,7 @@ class AccessTokenAuthImpl(AbstractAuth): ) -async def new_subscriber( - hass: HomeAssistant, entry: NestConfigEntry -) -> GoogleNestSubscriber | None: +async def new_auth(hass: HomeAssistant, entry: NestConfigEntry) -> AbstractAuth: """Create a GoogleNestSubscriber.""" implementation = ( await config_entry_oauth2_flow.async_get_config_entry_implementation( @@ -114,14 +112,22 @@ async def new_subscriber( implementation, config_entry_oauth2_flow.LocalOAuth2Implementation ): raise TypeError(f"Unexpected auth implementation {implementation}") - if (subscription_name := entry.data.get(CONF_SUBSCRIPTION_NAME)) is None: - subscription_name = entry.data[CONF_SUBSCRIBER_ID] - auth = AsyncConfigEntryAuth( + return AsyncConfigEntryAuth( aiohttp_client.async_get_clientsession(hass), config_entry_oauth2_flow.OAuth2Session(hass, entry, implementation), implementation.client_id, implementation.client_secret, ) + + +async def new_subscriber( + hass: HomeAssistant, + entry: NestConfigEntry, + auth: AbstractAuth, +) -> GoogleNestSubscriber: + """Create a GoogleNestSubscriber.""" + if (subscription_name := entry.data.get(CONF_SUBSCRIPTION_NAME)) is None: + subscription_name = entry.data[CONF_SUBSCRIBER_ID] return GoogleNestSubscriber(auth, entry.data[CONF_PROJECT_ID], subscription_name) diff --git a/tests/components/nest/test_api.py b/tests/components/nest/test_api.py index 98c3e06cfb8..1a5c4d63dba 100644 --- a/tests/components/nest/test_api.py +++ b/tests/components/nest/test_api.py @@ -89,80 +89,3 @@ async def test_auth( assert creds.client_id == CLIENT_ID assert creds.client_secret == CLIENT_SECRET assert creds.scopes == SDM_SCOPES - - -# This tests needs to be adjusted to remove lingering tasks -@pytest.mark.parametrize("expected_lingering_tasks", [True]) -@pytest.mark.parametrize( - "token_expiration_time", - [time.time() - 7 * 86400], - ids=["expires-in-past"], -) -async def test_auth_expired_token( - hass: HomeAssistant, - aioclient_mock: AiohttpClientMocker, - setup_platform: PlatformSetup, - token_expiration_time: float, -) -> None: - """Verify behavior of an expired token.""" - # Prepare a token refresh response - aioclient_mock.post( - OAUTH2_TOKEN, - json={ - "access_token": FAKE_UPDATED_TOKEN, - "expires_at": time.time() + 86400, - "expires_in": 86400, - }, - ) - # Prepare to capture credentials in API request. Empty payloads just mean - # no devices or structures are loaded. - aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/structures", json={}) - aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/devices", json={}) - - # Prepare to capture credentials for Subscriber - captured_creds = None - - def async_new_subscriber( - credentials: Credentials, - ) -> Mock: - """Capture credentials for tests.""" - nonlocal captured_creds - captured_creds = credentials - return AsyncMock() - - with patch( - "google_nest_sdm.subscriber_client.pubsub_v1.SubscriberAsyncClient", - side_effect=async_new_subscriber, - ) as new_subscriber_mock: - await setup_platform() - - calls = aioclient_mock.mock_calls - assert len(calls) == 3 - # Verify refresh token call to get an updated token - (method, url, data, headers) = calls[0] - assert data == { - "client_id": CLIENT_ID, - "client_secret": CLIENT_SECRET, - "grant_type": "refresh_token", - "refresh_token": FAKE_REFRESH_TOKEN, - } - # Verify API requests are made with the new token - (method, url, data, headers) = calls[1] - assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"} - (method, url, data, headers) = calls[2] - assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"} - - # The subscriber is created with a token that is expired. Verify that the - # credential is expired so the subscriber knows it needs to refresh it. - assert len(new_subscriber_mock.mock_calls) == 1 - assert captured_creds - creds = captured_creds - assert creds.token == FAKE_TOKEN - assert creds.refresh_token == FAKE_REFRESH_TOKEN - assert int(dt_util.as_timestamp(creds.expiry)) == int(token_expiration_time) - assert not creds.valid - assert creds.expired - assert creds.token_uri == OAUTH2_TOKEN - assert creds.client_id == CLIENT_ID - assert creds.client_secret == CLIENT_SECRET - assert creds.scopes == SDM_SCOPES