diff --git a/homeassistant/components/smarla/config_flow.py b/homeassistant/components/smarla/config_flow.py index 3ff31ccbee1..5b210783b43 100644 --- a/homeassistant/components/smarla/config_flow.py +++ b/homeassistant/components/smarla/config_flow.py @@ -5,7 +5,6 @@ from __future__ import annotations from typing import Any from pysmarlaapi import Connection -from pysmarlaapi.classes import AuthToken import voluptuous as vol from homeassistant.config_entries import ConfigFlow, ConfigFlowResult @@ -21,7 +20,7 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN): VERSION = 1 - async def _handle_token(self, token: str) -> tuple[dict[str, str], AuthToken]: + async def _handle_token(self, token: str) -> tuple[dict[str, str], str]: """Handle the token input.""" errors: dict[str, str] = {} @@ -29,13 +28,13 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN): conn = Connection(url=HOST, token_b64=token) except ValueError: errors["base"] = "malformed_token" - return (errors, None) + return (errors, "") if not await conn.refresh_token(): errors["base"] = "invalid_auth" - return (errors, None) + return (errors, "") - return (errors, conn.token) + return (errors, conn.token.serialNumber) async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -45,14 +44,14 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN): if user_input is not None: raw_token = user_input[CONF_ACCESS_TOKEN] - errors, token = await self._handle_token(token=raw_token) + errors, serial_number = await self._handle_token(token=raw_token) if not errors: - await self.async_set_unique_id(token.serialNumber) + await self.async_set_unique_id(serial_number) self._abort_if_unique_id_configured() return self.async_create_entry( - title=token.serialNumber, + title=serial_number, data={CONF_ACCESS_TOKEN: raw_token}, ) diff --git a/tests/components/smarla/conftest.py b/tests/components/smarla/conftest.py index c6b3d0961e4..3ecbd751834 100644 --- a/tests/components/smarla/conftest.py +++ b/tests/components/smarla/conftest.py @@ -2,15 +2,15 @@ from __future__ import annotations -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch +from pysmarlaapi.classes import AuthToken import pytest -from homeassistant.components.smarla.config_flow import Connection from homeassistant.components.smarla.const import DOMAIN from homeassistant.config_entries import SOURCE_USER -from . import MOCK_SERIAL_NUMBER +from . import MOCK_ACCESS_TOKEN_JSON, MOCK_SERIAL_NUMBER, MOCK_USER_INPUT from tests.common import MockConfigEntry @@ -22,23 +22,18 @@ def mock_config_entry() -> MockConfigEntry: domain=DOMAIN, unique_id=MOCK_SERIAL_NUMBER, source=SOURCE_USER, + data=MOCK_USER_INPUT, ) @pytest.fixture -def mock_refresh_token_success(): - """Patch Connection.refresh_token to return True.""" - with patch.object(Connection, "refresh_token", new=AsyncMock(return_value=True)): - yield - - -@pytest.fixture -def malformed_token_patch(): - """Patch Connection to raise exception.""" - return patch.object(Connection, "__init__", side_effect=ValueError) - - -@pytest.fixture -def invalid_auth_patch(): - """Patch Connection.refresh_token to return False.""" - return patch.object(Connection, "refresh_token", new=AsyncMock(return_value=False)) +def mock_cf_connection(): + """Patch config_flow Connection object.""" + with patch( + "homeassistant.components.smarla.config_flow.Connection" + ) as mock_connection: + connection = MagicMock() + connection.token = AuthToken.from_json(MOCK_ACCESS_TOKEN_JSON) + connection.refresh_token = AsyncMock(return_value=True) + mock_connection.return_value = connection + yield connection diff --git a/tests/components/smarla/test_config_flow.py b/tests/components/smarla/test_config_flow.py index 7fef73eebae..ff80a603ee3 100644 --- a/tests/components/smarla/test_config_flow.py +++ b/tests/components/smarla/test_config_flow.py @@ -1,5 +1,7 @@ """Test config flow for Swing2Sleep Smarla integration.""" +from unittest.mock import AsyncMock, patch + import pytest from homeassistant.components.smarla.const import DOMAIN @@ -12,7 +14,7 @@ from . import MOCK_SERIAL_NUMBER, MOCK_USER_INPUT from tests.common import MockConfigEntry -async def test_config_flow(hass: HomeAssistant, mock_refresh_token_success) -> None: +async def test_config_flow(hass: HomeAssistant, mock_cf_connection) -> None: """Test creating a config entry.""" result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER} @@ -21,8 +23,8 @@ async def test_config_flow(hass: HomeAssistant, mock_refresh_token_success) -> N assert result["type"] is FlowResultType.FORM assert result["step_id"] == "user" - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER}, data=MOCK_USER_INPUT + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input=MOCK_USER_INPUT ) assert result["type"] is FlowResultType.CREATE_ENTRY @@ -32,11 +34,21 @@ async def test_config_flow(hass: HomeAssistant, mock_refresh_token_success) -> N @pytest.mark.parametrize("error", ["malformed_token", "invalid_auth"]) -async def test_form_error( - hass: HomeAssistant, request: pytest.FixtureRequest, error: str -) -> None: +async def test_form_error(hass: HomeAssistant, error: str, mock_cf_connection) -> None: """Test we show user form on invalid auth.""" - error_patch = request.getfixturevalue(f"{error}_patch") + match error: + case "malformed_token": + error_patch = patch( + "homeassistant.components.smarla.config_flow.Connection", + side_effect=ValueError, + ) + case "invalid_auth": + error_patch = patch.object( + mock_cf_connection, + "refresh_token", + new=AsyncMock(return_value=False), + ) + with error_patch: result = await hass.config_entries.flow.async_init( DOMAIN, @@ -45,18 +57,18 @@ async def test_form_error( ) assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "user" assert result["errors"] == {"base": error} - request.getfixturevalue("mock_refresh_token_success") - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER}, data=MOCK_USER_INPUT + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input=MOCK_USER_INPUT ) assert result["type"] is FlowResultType.CREATE_ENTRY async def test_device_exists_abort( - hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_refresh_token_success + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_cf_connection ) -> None: """Test we abort config flow if Smarla device already configured.""" mock_config_entry.add_to_hass(hass)