Change how subscription information is fetched (#148337)

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Joakim Sørensen 2025-07-08 05:52:41 +01:00 committed by GitHub
parent dcf8d7f74d
commit 7a7e16bbb6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 27 additions and 26 deletions

View File

@ -3,8 +3,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from typing import Any
from hass_nabucasa.payments_api import SubscriptionInfo
import voluptuous as vol import voluptuous as vol
from homeassistant.components.repairs import ( from homeassistant.components.repairs import (
@ -26,7 +26,7 @@ MAX_RETRIES = 60 # This allows for 10 minutes of retries
@callback @callback
def async_manage_legacy_subscription_issue( def async_manage_legacy_subscription_issue(
hass: HomeAssistant, hass: HomeAssistant,
subscription_info: dict[str, Any], subscription_info: SubscriptionInfo,
) -> None: ) -> None:
"""Manage the legacy subscription issue. """Manage the legacy subscription issue.
@ -50,7 +50,7 @@ class LegacySubscriptionRepairFlow(RepairsFlow):
"""Handler for an issue fixing flow.""" """Handler for an issue fixing flow."""
wait_task: asyncio.Task | None = None wait_task: asyncio.Task | None = None
_data: dict[str, Any] | None = None _data: SubscriptionInfo | None = None
async def async_step_init(self, _: None = None) -> FlowResult: async def async_step_init(self, _: None = None) -> FlowResult:
"""Handle the first step of a fix flow.""" """Handle the first step of a fix flow."""

View File

@ -8,6 +8,7 @@ from typing import Any
from aiohttp.client_exceptions import ClientError from aiohttp.client_exceptions import ClientError
from hass_nabucasa import Cloud, cloud_api from hass_nabucasa import Cloud, cloud_api
from hass_nabucasa.payments_api import PaymentsApiError, SubscriptionInfo
from .client import CloudClient from .client import CloudClient
from .const import REQUEST_TIMEOUT from .const import REQUEST_TIMEOUT
@ -15,21 +16,13 @@ from .const import REQUEST_TIMEOUT
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_subscription_info(cloud: Cloud[CloudClient]) -> dict[str, Any] | None: async def async_subscription_info(cloud: Cloud[CloudClient]) -> SubscriptionInfo | None:
"""Fetch the subscription info.""" """Fetch the subscription info."""
try: try:
async with asyncio.timeout(REQUEST_TIMEOUT): async with asyncio.timeout(REQUEST_TIMEOUT):
return await cloud_api.async_subscription_info(cloud) return await cloud.payments.subscription_info()
except TimeoutError: except PaymentsApiError as exception:
_LOGGER.error( _LOGGER.error("Failed to fetch subscription information - %s", exception)
(
"A timeout of %s was reached while trying to fetch subscription"
" information"
),
REQUEST_TIMEOUT,
)
except ClientError:
_LOGGER.error("Failed to fetch subscription information")
return None return None

View File

@ -5,7 +5,7 @@ from pathlib import Path
from typing import Any from typing import Any
from unittest.mock import DEFAULT, AsyncMock, MagicMock, PropertyMock, patch from unittest.mock import DEFAULT, AsyncMock, MagicMock, PropertyMock, patch
from hass_nabucasa import Cloud from hass_nabucasa import Cloud, payments_api
from hass_nabucasa.auth import CognitoAuth from hass_nabucasa.auth import CognitoAuth
from hass_nabucasa.cloudhooks import Cloudhooks from hass_nabucasa.cloudhooks import Cloudhooks
from hass_nabucasa.const import DEFAULT_SERVERS, DEFAULT_VALUES, STATE_CONNECTED from hass_nabucasa.const import DEFAULT_SERVERS, DEFAULT_VALUES, STATE_CONNECTED
@ -71,6 +71,10 @@ async def cloud_fixture() -> AsyncGenerator[MagicMock]:
mock_cloud.voice = MagicMock(spec=Voice) mock_cloud.voice = MagicMock(spec=Voice)
mock_cloud.files = MagicMock(spec=Files) mock_cloud.files = MagicMock(spec=Files)
mock_cloud.started = None mock_cloud.started = None
mock_cloud.payments = MagicMock(
spec=payments_api.PaymentsApi,
subscription_info=AsyncMock(),
)
mock_cloud.ice_servers = MagicMock( mock_cloud.ice_servers = MagicMock(
spec=IceServers, spec=IceServers,
async_register_ice_servers_listener=AsyncMock( async_register_ice_servers_listener=AsyncMock(

View File

@ -18,6 +18,7 @@ from hass_nabucasa.auth import (
UnknownError, UnknownError,
) )
from hass_nabucasa.const import STATE_CONNECTED from hass_nabucasa.const import STATE_CONNECTED
from hass_nabucasa.payments_api import PaymentsApiError
from hass_nabucasa.remote import CertificateStatus from hass_nabucasa.remote import CertificateStatus
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
@ -1008,16 +1009,14 @@ async def test_websocket_subscription_info(
cloud: MagicMock, cloud: MagicMock,
setup_cloud: None, setup_cloud: None,
) -> None: ) -> None:
"""Test subscription info and connecting because valid account.""" """Test subscription info."""
aioclient_mock.get(SUBSCRIPTION_INFO_URL, json={"provider": "stripe"}) cloud.payments.subscription_info.return_value = {"provider": "stripe"}
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
mock_renew = cloud.auth.async_renew_access_token
await client.send_json({"id": 5, "type": "cloud/subscription"}) await client.send_json({"id": 5, "type": "cloud/subscription"})
response = await client.receive_json() response = await client.receive_json()
assert response["result"] == {"provider": "stripe"} assert response["result"] == {"provider": "stripe"}
assert mock_renew.call_count == 1
async def test_websocket_subscription_fail( async def test_websocket_subscription_fail(
@ -1028,7 +1027,9 @@ async def test_websocket_subscription_fail(
setup_cloud: None, setup_cloud: None,
) -> None: ) -> None:
"""Test subscription info fail.""" """Test subscription info fail."""
aioclient_mock.get(SUBSCRIPTION_INFO_URL, status=HTTPStatus.INTERNAL_SERVER_ERROR) cloud.payments.subscription_info.side_effect = PaymentsApiError(
"Failed to fetch subscription information"
)
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json({"id": 5, "type": "cloud/subscription"}) await client.send_json({"id": 5, "type": "cloud/subscription"})

View File

@ -2,7 +2,7 @@
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
from hass_nabucasa import Cloud from hass_nabucasa import Cloud, payments_api
import pytest import pytest
from homeassistant.components.cloud.subscription import ( from homeassistant.components.cloud.subscription import (
@ -22,6 +22,10 @@ async def mocked_cloud_object(hass: HomeAssistant) -> Cloud:
accounts_server="accounts.nabucasa.com", accounts_server="accounts.nabucasa.com",
auth=Mock(async_check_token=AsyncMock()), auth=Mock(async_check_token=AsyncMock()),
websession=async_get_clientsession(hass), websession=async_get_clientsession(hass),
payments=Mock(
spec=payments_api.PaymentsApi,
subscription_info=AsyncMock(),
),
) )
@ -31,14 +35,13 @@ async def test_fetching_subscription_with_timeout_error(
mocked_cloud: Cloud, mocked_cloud: Cloud,
) -> None: ) -> None:
"""Test that we handle timeout error.""" """Test that we handle timeout error."""
aioclient_mock.get( mocked_cloud.payments.subscription_info.side_effect = payments_api.PaymentsApiError(
"https://accounts.nabucasa.com/payments/subscription_info", "Timeout reached while calling API"
exc=TimeoutError(),
) )
assert await async_subscription_info(mocked_cloud) is None assert await async_subscription_info(mocked_cloud) is None
assert ( assert (
"A timeout of 10 was reached while trying to fetch subscription information" "Failed to fetch subscription information - Timeout reached while calling API"
in caplog.text in caplog.text
) )