diff --git a/homeassistant/components/smartthings/__init__.py b/homeassistant/components/smartthings/__init__.py index 5cc7b3e2c36..a5e138639de 100644 --- a/homeassistant/components/smartthings/__init__.py +++ b/homeassistant/components/smartthings/__init__.py @@ -4,10 +4,11 @@ from __future__ import annotations from collections.abc import Callable from dataclasses import dataclass +from http import HTTPStatus import logging from typing import TYPE_CHECKING, Any, cast -from aiohttp import ClientError +from aiohttp import ClientResponseError from pysmartthings import ( Attribute, Capability, @@ -102,7 +103,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: SmartThingsConfigEntry) try: await session.async_ensure_token_valid() - except ClientError as err: + except ClientResponseError as err: + if err.status == HTTPStatus.BAD_REQUEST: + raise ConfigEntryAuthFailed("Token not valid, trigger renewal") from err raise ConfigEntryNotReady from err client = SmartThings(session=async_get_clientsession(hass)) diff --git a/tests/components/smartthings/test_init.py b/tests/components/smartthings/test_init.py index 2083bb7ea24..3eaa038027d 100644 --- a/tests/components/smartthings/test_init.py +++ b/tests/components/smartthings/test_init.py @@ -1,7 +1,8 @@ """Tests for the SmartThings component init module.""" -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch +from aiohttp import ClientResponseError, RequestInfo from pysmartthings import ( Attribute, Capability, @@ -264,6 +265,57 @@ async def test_removing_stale_devices( assert not device_registry.async_get_device({(DOMAIN, "aaa-bbb-ccc")}) +@pytest.mark.parametrize("device_fixture", ["da_ac_rac_000001"]) +async def test_refreshing_expired_token( + hass: HomeAssistant, + devices: AsyncMock, + mock_config_entry: MockConfigEntry, +) -> None: + """Test removing stale devices.""" + with patch( + "homeassistant.components.smartthings.OAuth2Session.async_ensure_token_valid", + side_effect=ClientResponseError( + request_info=RequestInfo( + url="http://example.com", + method="GET", + headers={}, + real_url="http://example.com", + ), + status=400, + history=(), + ), + ): + await setup_integration(hass, mock_config_entry) + + assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR + assert len(hass.config_entries.flow.async_progress()) == 1 + + +@pytest.mark.parametrize("device_fixture", ["da_ac_rac_000001"]) +async def test_error_refreshing_token( + hass: HomeAssistant, + devices: AsyncMock, + mock_config_entry: MockConfigEntry, +) -> None: + """Test removing stale devices.""" + with patch( + "homeassistant.components.smartthings.OAuth2Session.async_ensure_token_valid", + side_effect=ClientResponseError( + request_info=RequestInfo( + url="http://example.com", + method="GET", + headers={}, + real_url="http://example.com", + ), + status=500, + history=(), + ), + ): + await setup_integration(hass, mock_config_entry) + + assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY + + async def test_hub_via_device( hass: HomeAssistant, snapshot: SnapshotAssertion,