Refactor tests

This commit is contained in:
Robin Lintermann 2025-05-20 09:16:48 +00:00
parent f0a3ecd27b
commit acc32eea3e
4 changed files with 47 additions and 33 deletions

View File

@ -20,7 +20,7 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
async def _handle_token(self, token: str) -> tuple[dict[str, str], str]: async def _handle_token(self, token: str) -> tuple[dict[str, str], str | None]:
"""Handle the token input.""" """Handle the token input."""
errors: dict[str, str] = {} errors: dict[str, str] = {}
@ -28,11 +28,11 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN):
conn = Connection(url=HOST, token_b64=token) conn = Connection(url=HOST, token_b64=token)
except ValueError: except ValueError:
errors["base"] = "malformed_token" errors["base"] = "malformed_token"
return (errors, "") return (errors, None)
if not await conn.refresh_token(): if not await conn.refresh_token():
errors["base"] = "invalid_auth" errors["base"] = "invalid_auth"
return (errors, "") return (errors, None)
return (errors, conn.token.serialNumber) return (errors, conn.token.serialNumber)
@ -46,7 +46,7 @@ class SmarlaConfigFlow(ConfigFlow, domain=DOMAIN):
raw_token = user_input[CONF_ACCESS_TOKEN] raw_token = user_input[CONF_ACCESS_TOKEN]
errors, serial_number = await self._handle_token(token=raw_token) errors, serial_number = await self._handle_token(token=raw_token)
if not errors: if not errors and serial_number is not None:
await self.async_set_unique_id(serial_number) await self.async_set_unique_id(serial_number)
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()

View File

@ -18,3 +18,5 @@ MOCK_ACCESS_TOKEN = base64.b64encode(
).decode() ).decode()
MOCK_USER_INPUT = {CONF_ACCESS_TOKEN: MOCK_ACCESS_TOKEN} MOCK_USER_INPUT = {CONF_ACCESS_TOKEN: MOCK_ACCESS_TOKEN}
MOCK_URL = "https://someurl.net"

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import patch
from pysmarlaapi.classes import AuthToken from pysmarlaapi.classes import AuthToken
import pytest import pytest
@ -10,7 +10,7 @@ import pytest
from homeassistant.components.smarla.const import DOMAIN from homeassistant.components.smarla.const import DOMAIN
from homeassistant.config_entries import SOURCE_USER from homeassistant.config_entries import SOURCE_USER
from . import MOCK_ACCESS_TOKEN_JSON, MOCK_SERIAL_NUMBER, MOCK_USER_INPUT from . import MOCK_ACCESS_TOKEN_JSON, MOCK_SERIAL_NUMBER, MOCK_URL, MOCK_USER_INPUT
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -31,15 +31,15 @@ def mock_connection():
"""Patch Connection object.""" """Patch Connection object."""
with ( with (
patch( patch(
"homeassistant.components.smarla.config_flow.Connection" "homeassistant.components.smarla.config_flow.Connection", autospec=True
) as mock_connection, ) as mock_connection,
patch( patch(
"homeassistant.components.smarla.Connection", "homeassistant.components.smarla.Connection",
mock_connection, mock_connection,
), ),
): ):
connection = MagicMock() connection = mock_connection.return_value
connection.url = MOCK_URL
connection.token = AuthToken.from_json(MOCK_ACCESS_TOKEN_JSON) connection.token = AuthToken.from_json(MOCK_ACCESS_TOKEN_JSON)
connection.refresh_token = AsyncMock(return_value=True) connection.refresh_token.return_value = True
mock_connection.return_value = connection
yield connection yield connection

View File

@ -2,8 +2,6 @@
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.components.smarla.const import DOMAIN from homeassistant.components.smarla.const import DOMAIN
from homeassistant.config_entries import SOURCE_USER from homeassistant.config_entries import SOURCE_USER
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -17,14 +15,16 @@ from tests.common import MockConfigEntry
async def test_config_flow(hass: HomeAssistant, mock_connection) -> None: async def test_config_flow(hass: HomeAssistant, mock_connection) -> None:
"""Test creating a config entry.""" """Test creating a config entry."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER} DOMAIN,
context={"source": SOURCE_USER},
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=MOCK_USER_INPUT result["flow_id"],
user_input=MOCK_USER_INPUT,
) )
assert result["type"] is FlowResultType.CREATE_ENTRY assert result["type"] is FlowResultType.CREATE_ENTRY
@ -33,23 +33,11 @@ async def test_config_flow(hass: HomeAssistant, mock_connection) -> None:
assert result["result"].unique_id == MOCK_SERIAL_NUMBER assert result["result"].unique_id == MOCK_SERIAL_NUMBER
@pytest.mark.parametrize("error", ["malformed_token", "invalid_auth"]) async def test_malformed_token(hass: HomeAssistant, mock_connection) -> None:
async def test_form_error(hass: HomeAssistant, error: str, mock_connection) -> None: """Test we show user form on malformed token input."""
"""Test we show user form on invalid auth.""" with patch(
match error: "homeassistant.components.smarla.config_flow.Connection", side_effect=ValueError
case "malformed_token": ):
error_patch = patch(
"homeassistant.components.smarla.config_flow.Connection",
side_effect=ValueError,
)
case "invalid_auth":
error_patch = patch.object(
mock_connection,
"refresh_token",
new=AsyncMock(return_value=False),
)
with error_patch:
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
context={"source": SOURCE_USER}, context={"source": SOURCE_USER},
@ -58,10 +46,34 @@ async def test_form_error(hass: HomeAssistant, error: str, mock_connection) -> N
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
assert result["errors"] == {"base": error} assert result["errors"] == {"base": "malformed_token"}
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=MOCK_USER_INPUT result["flow_id"],
user_input=MOCK_USER_INPUT,
)
assert result["type"] is FlowResultType.CREATE_ENTRY
async def test_invalid_auth(hass: HomeAssistant, mock_connection) -> None:
"""Test we show user form on invalid auth."""
with patch.object(
mock_connection, "refresh_token", new=AsyncMock(return_value=False)
):
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER},
data=MOCK_USER_INPUT,
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user"
assert result["errors"] == {"base": "invalid_auth"}
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=MOCK_USER_INPUT,
) )
assert result["type"] is FlowResultType.CREATE_ENTRY assert result["type"] is FlowResultType.CREATE_ENTRY