diff --git a/homeassistant/components/sleepiq/config_flow.py b/homeassistant/components/sleepiq/config_flow.py index dffb30f39d7..47e08fdfd5b 100644 --- a/homeassistant/components/sleepiq/config_flow.py +++ b/homeassistant/components/sleepiq/config_flow.py @@ -1,6 +1,7 @@ """Config flow to configure SleepIQ component.""" from __future__ import annotations +import logging from typing import Any from asyncsleepiq import AsyncSleepIQ, SleepIQLoginException, SleepIQTimeoutException @@ -14,6 +15,8 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession from .const import DOMAIN +_LOGGER = logging.getLogger(__name__) + class SleepIQFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Handle a SleepIQ config flow.""" @@ -28,6 +31,10 @@ class SleepIQFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): await self.async_set_unique_id(import_config[CONF_USERNAME].lower()) self._abort_if_unique_id_configured() + if error := await try_connection(self.hass, import_config): + _LOGGER.error("Could not authenticate with SleepIQ server: %s", error) + return self.async_abort(reason=error) + return self.async_create_entry( title=import_config[CONF_USERNAME], data=import_config ) @@ -43,26 +50,23 @@ class SleepIQFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): await self.async_set_unique_id(user_input[CONF_USERNAME].lower()) self._abort_if_unique_id_configured() - try: - await try_connection(self.hass, user_input) - except SleepIQLoginException: - errors["base"] = "invalid_auth" - except SleepIQTimeoutException: - errors["base"] = "cannot_connect" + if error := await try_connection(self.hass, user_input): + errors["base"] = error else: return self.async_create_entry( title=user_input[CONF_USERNAME], data=user_input ) + else: + user_input = {} + return self.async_show_form( step_id="user", data_schema=vol.Schema( { vol.Required( CONF_USERNAME, - default=user_input.get(CONF_USERNAME) - if user_input is not None - else "", + default=user_input.get(CONF_USERNAME), ): str, vol.Required(CONF_PASSWORD): str, } @@ -72,10 +76,17 @@ class SleepIQFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): ) -async def try_connection(hass: HomeAssistant, user_input: dict[str, Any]) -> None: +async def try_connection(hass: HomeAssistant, user_input: dict[str, Any]) -> str | None: """Test if the given credentials can successfully login to SleepIQ.""" client_session = async_get_clientsession(hass) gateway = AsyncSleepIQ(client_session=client_session) - await gateway.login(user_input[CONF_USERNAME], user_input[CONF_PASSWORD]) + try: + await gateway.login(user_input[CONF_USERNAME], user_input[CONF_PASSWORD]) + except SleepIQLoginException: + return "invalid_auth" + except SleepIQTimeoutException: + return "cannot_connect" + + return None diff --git a/tests/components/sleepiq/test_config_flow.py b/tests/components/sleepiq/test_config_flow.py index b2554ea968e..516a783f302 100644 --- a/tests/components/sleepiq/test_config_flow.py +++ b/tests/components/sleepiq/test_config_flow.py @@ -2,6 +2,7 @@ from unittest.mock import patch from asyncsleepiq import SleepIQLoginException, SleepIQTimeoutException +import pytest from homeassistant import config_entries, data_entry_flow, setup from homeassistant.components.sleepiq.const import DOMAIN @@ -26,6 +27,21 @@ async def test_import(hass: HomeAssistant) -> None: assert entry.data[CONF_PASSWORD] == SLEEPIQ_CONFIG[CONF_PASSWORD] +@pytest.mark.parametrize( + "side_effect", [SleepIQLoginException, SleepIQTimeoutException] +) +async def test_import_failure(hass: HomeAssistant, side_effect) -> None: + """Test that we won't import a config entry on login failure.""" + with patch( + "asyncsleepiq.AsyncSleepIQ.login", + side_effect=side_effect, + ): + assert await setup.async_setup_component(hass, DOMAIN, {DOMAIN: SLEEPIQ_CONFIG}) + await hass.async_block_till_done() + + assert len(hass.config_entries.async_entries(DOMAIN)) == 0 + + async def test_show_set_form(hass: HomeAssistant) -> None: """Test that the setup form is served.""" with patch("asyncsleepiq.AsyncSleepIQ.login"): @@ -37,11 +53,18 @@ async def test_show_set_form(hass: HomeAssistant) -> None: assert result["step_id"] == "user" -async def test_login_invalid_auth(hass: HomeAssistant) -> None: - """Test we show user form with appropriate error on login failure.""" +@pytest.mark.parametrize( + "side_effect,error", + [ + (SleepIQLoginException, "invalid_auth"), + (SleepIQTimeoutException, "cannot_connect"), + ], +) +async def test_login_failure(hass: HomeAssistant, side_effect, error) -> None: + """Test that we show user form with appropriate error on login failure.""" with patch( "asyncsleepiq.AsyncSleepIQ.login", - side_effect=SleepIQLoginException, + side_effect=side_effect, ): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER}, data=SLEEPIQ_CONFIG @@ -49,22 +72,7 @@ async def test_login_invalid_auth(hass: HomeAssistant) -> None: assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["step_id"] == "user" - assert result["errors"] == {"base": "invalid_auth"} - - -async def test_login_cannot_connect(hass: HomeAssistant) -> None: - """Test we show user form with appropriate error on login failure.""" - with patch( - "asyncsleepiq.AsyncSleepIQ.login", - side_effect=SleepIQTimeoutException, - ): - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_USER}, data=SLEEPIQ_CONFIG - ) - - assert result["type"] == data_entry_flow.RESULT_TYPE_FORM - assert result["step_id"] == "user" - assert result["errors"] == {"base": "cannot_connect"} + assert result["errors"] == {"base": error} async def test_success(hass: HomeAssistant) -> None: