diff --git a/homeassistant/components/husqvarna_automower/api.py b/homeassistant/components/husqvarna_automower/api.py index f1d3e1ef4fa..8a9a31b926a 100644 --- a/homeassistant/components/husqvarna_automower/api.py +++ b/homeassistant/components/husqvarna_automower/api.py @@ -7,6 +7,7 @@ from aioautomower.auth import AbstractAuth from aioautomower.const import API_BASE_URL from aiohttp import ClientSession +from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.helpers import config_entry_oauth2_flow _LOGGER = logging.getLogger(__name__) @@ -28,3 +29,16 @@ class AsyncConfigEntryAuth(AbstractAuth): """Return a valid access token.""" await self._oauth_session.async_ensure_token_valid() return cast(str, self._oauth_session.token["access_token"]) + + +class AsyncConfigFlowAuth(AbstractAuth): + """Provide Automower AbstractAuth for the config flow.""" + + def __init__(self, websession: ClientSession, token: dict) -> None: + """Initialize Husqvarna Automower auth.""" + super().__init__(websession, API_BASE_URL) + self.token: dict = token + + async def async_get_access_token(self) -> str: + """Return a valid access token.""" + return cast(str, self.token[CONF_ACCESS_TOKEN]) diff --git a/homeassistant/components/husqvarna_automower/config_flow.py b/homeassistant/components/husqvarna_automower/config_flow.py index 3e76b9ac812..4da3bd14089 100644 --- a/homeassistant/components/husqvarna_automower/config_flow.py +++ b/homeassistant/components/husqvarna_automower/config_flow.py @@ -4,12 +4,15 @@ from collections.abc import Mapping import logging from typing import Any +from aioautomower.session import AutomowerSession from aioautomower.utils import structure_token from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, CONF_TOKEN -from homeassistant.helpers import config_entry_oauth2_flow +from homeassistant.helpers import aiohttp_client, config_entry_oauth2_flow +from homeassistant.util import dt as dt_util +from .api import AsyncConfigFlowAuth from .const import DOMAIN, NAME _LOGGER = logging.getLogger(__name__) @@ -46,9 +49,20 @@ class HusqvarnaConfigFlowHandler( self._abort_if_unique_id_configured() + websession = aiohttp_client.async_get_clientsession(self.hass) + tz = await dt_util.async_get_time_zone(str(dt_util.DEFAULT_TIME_ZONE)) + automower_api = AutomowerSession(AsyncConfigFlowAuth(websession, token), tz) + try: + data = await automower_api.get_status() + except Exception: # noqa: BLE001 + return self.async_abort(reason="unknown") + if data == {}: + return self.async_abort(reason="no_mower_connected") + structured_token = structure_token(token[CONF_ACCESS_TOKEN]) first_name = structured_token.user.first_name last_name = structured_token.user.last_name + return self.async_create_entry( title=f"{NAME} of {first_name} {last_name}", data=data, diff --git a/homeassistant/components/husqvarna_automower/strings.json b/homeassistant/components/husqvarna_automower/strings.json index 149d53f8783..d4c91e29f7d 100644 --- a/homeassistant/components/husqvarna_automower/strings.json +++ b/homeassistant/components/husqvarna_automower/strings.json @@ -27,7 +27,9 @@ "oauth_failed": "[%key:common::config_flow::abort::oauth2_failed%]", "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]", "wrong_account": "You can only reauthenticate this entry with the same Husqvarna account.", - "missing_amc_scope": "The `Authentication API` and the `Automower Connect API` are not connected to your application in the Husqvarna Developer Portal." + "no_mower_connected": "No mowers connected to this account.", + "missing_amc_scope": "The `Authentication API` and the `Automower Connect API` are not connected to your application in the Husqvarna Developer Portal.", + "unknown": "[%key:common::config_flow::error::unknown%]" }, "create_entry": { "default": "[%key:common::config_flow::create_entry::authenticated%]" diff --git a/tests/components/husqvarna_automower/fixtures/empty.json b/tests/components/husqvarna_automower/fixtures/empty.json new file mode 100644 index 00000000000..22f4a272fc1 --- /dev/null +++ b/tests/components/husqvarna_automower/fixtures/empty.json @@ -0,0 +1 @@ +{ "data": [] } diff --git a/tests/components/husqvarna_automower/test_config_flow.py b/tests/components/husqvarna_automower/test_config_flow.py index 31e8a9afcbd..d91078d80a2 100644 --- a/tests/components/husqvarna_automower/test_config_flow.py +++ b/tests/components/husqvarna_automower/test_config_flow.py @@ -2,6 +2,8 @@ from unittest.mock import AsyncMock, patch +from aioautomower.const import API_BASE_URL +from aioautomower.session import AutomowerEndpoint import pytest from homeassistant import config_entries @@ -18,16 +20,18 @@ from homeassistant.helpers import config_entry_oauth2_flow from . import setup_integration from .const import CLIENT_ID, USER_ID -from tests.common import MockConfigEntry +from tests.common import MockConfigEntry, load_fixture from tests.test_util.aiohttp import AiohttpClientMocker from tests.typing import ClientSessionGenerator @pytest.mark.parametrize( - ("new_scope", "amount"), + ("new_scope", "fixture", "exception", "amount"), [ - ("iam:read amc:api", 1), - ("iam:read", 0), + ("iam:read amc:api", "mower.json", None, 1), + ("iam:read amc:api", "mower.json", Exception, 0), + ("iam:read", "mower.json", None, 0), + ("iam:read amc:api", "empty.json", None, 0), ], ) @pytest.mark.usefixtures("current_request_with_host") @@ -38,6 +42,8 @@ async def test_full_flow( jwt: str, new_scope: str, amount: int, + fixture: str, + exception: Exception | None, ) -> None: """Check full flow.""" result = await hass.config_entries.flow.async_init( @@ -76,11 +82,17 @@ async def test_full_flow( "expires_at": 1697753347, }, ) - - with patch( - "homeassistant.components.husqvarna_automower.async_setup_entry", - return_value=True, - ) as mock_setup: + aioclient_mock.get( + f"{API_BASE_URL}/{AutomowerEndpoint.mowers}", + text=load_fixture(fixture, DOMAIN), + exc=exception, + ) + with ( + patch( + "homeassistant.components.husqvarna_automower.async_setup_entry", + return_value=True, + ) as mock_setup, + ): await hass.config_entries.flow.async_configure(result["flow_id"]) assert len(hass.config_entries.async_entries(DOMAIN)) == amount