Start reauth for SmartThings if token expired (#141082)

This commit is contained in:
Joost Lekkerkerker 2025-03-22 13:12:24 +01:00 committed by GitHub
parent b7d300b49f
commit 5961a46fc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 3 deletions

View File

@ -4,10 +4,11 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus
import logging import logging
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
from aiohttp import ClientError from aiohttp import ClientResponseError
from pysmartthings import ( from pysmartthings import (
Attribute, Attribute,
Capability, Capability,
@ -102,7 +103,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: SmartThingsConfigEntry)
try: try:
await session.async_ensure_token_valid() await session.async_ensure_token_valid()
except ClientError as err: except ClientResponseError as err:
if err.status == HTTPStatus.BAD_REQUEST:
raise ConfigEntryAuthFailed("Token not valid, trigger renewal") from err
raise ConfigEntryNotReady from err raise ConfigEntryNotReady from err
client = SmartThings(session=async_get_clientsession(hass)) client = SmartThings(session=async_get_clientsession(hass))

View File

@ -1,7 +1,8 @@
"""Tests for the SmartThings component init module.""" """Tests for the SmartThings component init module."""
from unittest.mock import AsyncMock from unittest.mock import AsyncMock, patch
from aiohttp import ClientResponseError, RequestInfo
from pysmartthings import ( from pysmartthings import (
Attribute, Attribute,
Capability, Capability,
@ -264,6 +265,57 @@ async def test_removing_stale_devices(
assert not device_registry.async_get_device({(DOMAIN, "aaa-bbb-ccc")}) assert not device_registry.async_get_device({(DOMAIN, "aaa-bbb-ccc")})
@pytest.mark.parametrize("device_fixture", ["da_ac_rac_000001"])
async def test_refreshing_expired_token(
hass: HomeAssistant,
devices: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test removing stale devices."""
with patch(
"homeassistant.components.smartthings.OAuth2Session.async_ensure_token_valid",
side_effect=ClientResponseError(
request_info=RequestInfo(
url="http://example.com",
method="GET",
headers={},
real_url="http://example.com",
),
status=400,
history=(),
),
):
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
assert len(hass.config_entries.flow.async_progress()) == 1
@pytest.mark.parametrize("device_fixture", ["da_ac_rac_000001"])
async def test_error_refreshing_token(
hass: HomeAssistant,
devices: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test removing stale devices."""
with patch(
"homeassistant.components.smartthings.OAuth2Session.async_ensure_token_valid",
side_effect=ClientResponseError(
request_info=RequestInfo(
url="http://example.com",
method="GET",
headers={},
real_url="http://example.com",
),
status=500,
history=(),
),
):
await setup_integration(hass, mock_config_entry)
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
async def test_hub_via_device( async def test_hub_via_device(
hass: HomeAssistant, hass: HomeAssistant,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,