Refactoring

This commit is contained in:
Robin Lintermann 2025-05-15 14:20:18 +00:00
parent f14874bc3c
commit 6bbd357547
3 changed files with 44 additions and 38 deletions

View File

@ -5,7 +5,6 @@ from __future__ import annotations
from typing import Any from typing import Any
from pysmarlaapi import Connection from pysmarlaapi import Connection
from pysmarlaapi.classes import AuthToken
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
@ -21,7 +20,7 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
async def _handle_token(self, token: str) -> tuple[dict[str, str], AuthToken]: async def _handle_token(self, token: str) -> tuple[dict[str, str], str]:
"""Handle the token input.""" """Handle the token input."""
errors: dict[str, str] = {} errors: dict[str, str] = {}
@ -29,13 +28,13 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN):
conn = Connection(url=HOST, token_b64=token) conn = Connection(url=HOST, token_b64=token)
except ValueError: except ValueError:
errors["base"] = "malformed_token" errors["base"] = "malformed_token"
return (errors, None) return (errors, "")
if not await conn.refresh_token(): if not await conn.refresh_token():
errors["base"] = "invalid_auth" errors["base"] = "invalid_auth"
return (errors, None) return (errors, "")
return (errors, conn.token) return (errors, conn.token.serialNumber)
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
@ -45,14 +44,14 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN):
if user_input is not None: if user_input is not None:
raw_token = user_input[CONF_ACCESS_TOKEN] raw_token = user_input[CONF_ACCESS_TOKEN]
errors, token = await self._handle_token(token=raw_token) errors, serial_number = await self._handle_token(token=raw_token)
if not errors: if not errors:
await self.async_set_unique_id(token.serialNumber) await self.async_set_unique_id(serial_number)
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()
return self.async_create_entry( return self.async_create_entry(
title=token.serialNumber, title=serial_number,
data={CONF_ACCESS_TOKEN: raw_token}, data={CONF_ACCESS_TOKEN: raw_token},
) )

View File

@ -2,15 +2,15 @@
from __future__ import annotations from __future__ import annotations
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from pysmarlaapi.classes import AuthToken
import pytest import pytest
from homeassistant.components.smarla.config_flow import Connection
from homeassistant.components.smarla.const import DOMAIN from homeassistant.components.smarla.const import DOMAIN
from homeassistant.config_entries import SOURCE_USER from homeassistant.config_entries import SOURCE_USER
from . import MOCK_SERIAL_NUMBER from . import MOCK_ACCESS_TOKEN_JSON, MOCK_SERIAL_NUMBER, MOCK_USER_INPUT
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -22,23 +22,18 @@ def mock_config_entry() -> MockConfigEntry:
domain=DOMAIN, domain=DOMAIN,
unique_id=MOCK_SERIAL_NUMBER, unique_id=MOCK_SERIAL_NUMBER,
source=SOURCE_USER, source=SOURCE_USER,
data=MOCK_USER_INPUT,
) )
@pytest.fixture @pytest.fixture
def mock_refresh_token_success(): def mock_cf_connection():
"""Patch Connection.refresh_token to return True.""" """Patch config_flow Connection object."""
with patch.object(Connection, "refresh_token", new=AsyncMock(return_value=True)): with patch(
yield "homeassistant.components.smarla.config_flow.Connection"
) as mock_connection:
connection = MagicMock()
@pytest.fixture connection.token = AuthToken.from_json(MOCK_ACCESS_TOKEN_JSON)
def malformed_token_patch(): connection.refresh_token = AsyncMock(return_value=True)
"""Patch Connection to raise exception.""" mock_connection.return_value = connection
return patch.object(Connection, "__init__", side_effect=ValueError) yield connection
@pytest.fixture
def invalid_auth_patch():
"""Patch Connection.refresh_token to return False."""
return patch.object(Connection, "refresh_token", new=AsyncMock(return_value=False))

View File

@ -1,5 +1,7 @@
"""Test config flow for Swing2Sleep Smarla integration.""" """Test config flow for Swing2Sleep Smarla integration."""
from unittest.mock import AsyncMock, patch
import pytest import pytest
from homeassistant.components.smarla.const import DOMAIN from homeassistant.components.smarla.const import DOMAIN
@ -12,7 +14,7 @@ from . import MOCK_SERIAL_NUMBER, MOCK_USER_INPUT
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
async def test_config_flow(hass: HomeAssistant, mock_refresh_token_success) -> None: async def test_config_flow(hass: HomeAssistant, mock_cf_connection) -> None:
"""Test creating a config entry.""" """Test creating a config entry."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER} DOMAIN, context={"source": SOURCE_USER}
@ -21,8 +23,8 @@ async def test_config_flow(hass: HomeAssistant, mock_refresh_token_success) -> N
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_configure(
DOMAIN, context={"source": SOURCE_USER}, data=MOCK_USER_INPUT result["flow_id"], user_input=MOCK_USER_INPUT
) )
assert result["type"] is FlowResultType.CREATE_ENTRY assert result["type"] is FlowResultType.CREATE_ENTRY
@ -32,11 +34,21 @@ async def test_config_flow(hass: HomeAssistant, mock_refresh_token_success) -> N
@pytest.mark.parametrize("error", ["malformed_token", "invalid_auth"]) @pytest.mark.parametrize("error", ["malformed_token", "invalid_auth"])
async def test_form_error( async def test_form_error(hass: HomeAssistant, error: str, mock_cf_connection) -> None:
hass: HomeAssistant, request: pytest.FixtureRequest, error: str
) -> None:
"""Test we show user form on invalid auth.""" """Test we show user form on invalid auth."""
error_patch = request.getfixturevalue(f"{error}_patch") match error:
case "malformed_token":
error_patch = patch(
"homeassistant.components.smarla.config_flow.Connection",
side_effect=ValueError,
)
case "invalid_auth":
error_patch = patch.object(
mock_cf_connection,
"refresh_token",
new=AsyncMock(return_value=False),
)
with error_patch: with error_patch:
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
@ -45,18 +57,18 @@ async def test_form_error(
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user"
assert result["errors"] == {"base": error} assert result["errors"] == {"base": error}
request.getfixturevalue("mock_refresh_token_success") result = await hass.config_entries.flow.async_configure(
result = await hass.config_entries.flow.async_init( result["flow_id"], user_input=MOCK_USER_INPUT
DOMAIN, context={"source": SOURCE_USER}, data=MOCK_USER_INPUT
) )
assert result["type"] is FlowResultType.CREATE_ENTRY assert result["type"] is FlowResultType.CREATE_ENTRY
async def test_device_exists_abort( async def test_device_exists_abort(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_refresh_token_success hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_cf_connection
) -> None: ) -> None:
"""Test we abort config flow if Smarla device already configured.""" """Test we abort config flow if Smarla device already configured."""
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)