From acc32eea3ecb0d182fc341a8dcd56b88df004116 Mon Sep 17 00:00:00 2001 From: Robin Lintermann Date: Tue, 20 May 2025 09:16:48 +0000 Subject: [PATCH] Refactor tests --- .../components/smarla/config_flow.py | 8 +-- tests/components/smarla/__init__.py | 2 + tests/components/smarla/conftest.py | 12 ++-- tests/components/smarla/test_config_flow.py | 58 +++++++++++-------- 4 files changed, 47 insertions(+), 33 deletions(-) diff --git a/homeassistant/components/smarla/config_flow.py b/homeassistant/components/smarla/config_flow.py index 5b210783b43..011311991ae 100644 --- a/homeassistant/components/smarla/config_flow.py +++ b/homeassistant/components/smarla/config_flow.py @@ -20,7 +20,7 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN): VERSION = 1 - async def _handle_token(self, token: str) -> tuple[dict[str, str], str]: + async def _handle_token(self, token: str) -> tuple[dict[str, str], str | None]: """Handle the token input.""" errors: dict[str, str] = {} @@ -28,11 +28,11 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN): conn = Connection(url=HOST, token_b64=token) except ValueError: errors["base"] = "malformed_token" - return (errors, "") + return (errors, None) if not await conn.refresh_token(): errors["base"] = "invalid_auth" - return (errors, "") + return (errors, None) return (errors, conn.token.serialNumber) @@ -46,7 +46,7 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN): raw_token = user_input[CONF_ACCESS_TOKEN] errors, serial_number = await self._handle_token(token=raw_token) - if not errors: + if not errors and serial_number is not None: await self.async_set_unique_id(serial_number) self._abort_if_unique_id_configured() diff --git a/tests/components/smarla/__init__.py b/tests/components/smarla/__init__.py index 5bb8b70f030..5232b70a991 100644 --- a/tests/components/smarla/__init__.py +++ b/tests/components/smarla/__init__.py @@ -18,3 +18,5 @@ MOCK_ACCESS_TOKEN = base64.b64encode( ).decode() MOCK_USER_INPUT = {CONF_ACCESS_TOKEN: MOCK_ACCESS_TOKEN} + +MOCK_URL = "https://someurl.net" diff --git a/tests/components/smarla/conftest.py b/tests/components/smarla/conftest.py index 5d69d234c91..87e64b90329 100644 --- a/tests/components/smarla/conftest.py +++ b/tests/components/smarla/conftest.py @@ -2,7 +2,7 @@ from __future__ import annotations -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import patch from pysmarlaapi.classes import AuthToken import pytest @@ -10,7 +10,7 @@ import pytest from homeassistant.components.smarla.const import DOMAIN from homeassistant.config_entries import SOURCE_USER -from . import MOCK_ACCESS_TOKEN_JSON, MOCK_SERIAL_NUMBER, MOCK_USER_INPUT +from . import MOCK_ACCESS_TOKEN_JSON, MOCK_SERIAL_NUMBER, MOCK_URL, MOCK_USER_INPUT from tests.common import MockConfigEntry @@ -31,15 +31,15 @@ def mock_connection(): """Patch Connection object.""" with ( patch( - "homeassistant.components.smarla.config_flow.Connection" + "homeassistant.components.smarla.config_flow.Connection", autospec=True ) as mock_connection, patch( "homeassistant.components.smarla.Connection", mock_connection, ), ): - connection = MagicMock() + connection = mock_connection.return_value + connection.url = MOCK_URL connection.token = AuthToken.from_json(MOCK_ACCESS_TOKEN_JSON) - connection.refresh_token = AsyncMock(return_value=True) - mock_connection.return_value = connection + connection.refresh_token.return_value = True yield connection diff --git a/tests/components/smarla/test_config_flow.py b/tests/components/smarla/test_config_flow.py index 780b4cc7e01..73b2634e971 100644 --- a/tests/components/smarla/test_config_flow.py +++ b/tests/components/smarla/test_config_flow.py @@ -2,8 +2,6 @@ from unittest.mock import AsyncMock, patch -import pytest - from homeassistant.components.smarla.const import DOMAIN from homeassistant.config_entries import SOURCE_USER from homeassistant.core import HomeAssistant @@ -17,14 +15,16 @@ from tests.common import MockConfigEntry async def test_config_flow(hass: HomeAssistant, mock_connection) -> None: """Test creating a config entry.""" result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER} + DOMAIN, + context={"source": SOURCE_USER}, ) assert result["type"] is FlowResultType.FORM assert result["step_id"] == "user" result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input=MOCK_USER_INPUT + result["flow_id"], + user_input=MOCK_USER_INPUT, ) assert result["type"] is FlowResultType.CREATE_ENTRY @@ -33,23 +33,11 @@ async def test_config_flow(hass: HomeAssistant, mock_connection) -> None: assert result["result"].unique_id == MOCK_SERIAL_NUMBER -@pytest.mark.parametrize("error", ["malformed_token", "invalid_auth"]) -async def test_form_error(hass: HomeAssistant, error: str, mock_connection) -> None: - """Test we show user form on invalid auth.""" - match error: - case "malformed_token": - error_patch = patch( - "homeassistant.components.smarla.config_flow.Connection", - side_effect=ValueError, - ) - case "invalid_auth": - error_patch = patch.object( - mock_connection, - "refresh_token", - new=AsyncMock(return_value=False), - ) - - with error_patch: +async def test_malformed_token(hass: HomeAssistant, mock_connection) -> None: + """Test we show user form on malformed token input.""" + with patch( + "homeassistant.components.smarla.config_flow.Connection", side_effect=ValueError + ): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, @@ -58,10 +46,34 @@ async def test_form_error(hass: HomeAssistant, error: str, mock_connection) -> N assert result["type"] is FlowResultType.FORM assert result["step_id"] == "user" - assert result["errors"] == {"base": error} + assert result["errors"] == {"base": "malformed_token"} result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input=MOCK_USER_INPUT + result["flow_id"], + user_input=MOCK_USER_INPUT, + ) + + assert result["type"] is FlowResultType.CREATE_ENTRY + + +async def test_invalid_auth(hass: HomeAssistant, mock_connection) -> None: + """Test we show user form on invalid auth.""" + with patch.object( + mock_connection, "refresh_token", new=AsyncMock(return_value=False) + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + data=MOCK_USER_INPUT, + ) + + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "user" + assert result["errors"] == {"base": "invalid_auth"} + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=MOCK_USER_INPUT, ) assert result["type"] is FlowResultType.CREATE_ENTRY