Catch Hydrawise authorization errors in the correct place (#132727)

This commit is contained in:
David Knowles 2024-12-10 02:38:34 -05:00 committed by Franck Nijhof
parent e4765c40fe
commit 60e8a38ba3
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
3 changed files with 41 additions and 14 deletions

View File

@ -6,7 +6,7 @@ from collections.abc import Callable, Mapping
from typing import Any from typing import Any
from aiohttp import ClientError from aiohttp import ClientError
from pydrawise import auth, client from pydrawise import auth as pydrawise_auth, client
from pydrawise.exceptions import NotAuthorizedError from pydrawise.exceptions import NotAuthorizedError
import voluptuous as vol import voluptuous as vol
@ -29,16 +29,21 @@ class HydrawiseConfigFlow(ConfigFlow, domain=DOMAIN):
on_failure: Callable[[str], ConfigFlowResult], on_failure: Callable[[str], ConfigFlowResult],
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Create the config entry.""" """Create the config entry."""
# Verify that the provided credentials work.""" # Verify that the provided credentials work."""
api = client.Hydrawise(auth.Auth(username, password)) auth = pydrawise_auth.Auth(username, password)
try: try:
# Don't fetch zones because we don't need them yet. await auth.token()
user = await api.get_user(fetch_zones=False)
except NotAuthorizedError: except NotAuthorizedError:
return on_failure("invalid_auth") return on_failure("invalid_auth")
except TimeoutError: except TimeoutError:
return on_failure("timeout_connect") return on_failure("timeout_connect")
try:
api = client.Hydrawise(auth)
# Don't fetch zones because we don't need them yet.
user = await api.get_user(fetch_zones=False)
except TimeoutError:
return on_failure("timeout_connect")
except ClientError as ex: except ClientError as ex:
LOGGER.error("Unable to connect to Hydrawise cloud service: %s", ex) LOGGER.error("Unable to connect to Hydrawise cloud service: %s", ex)
return on_failure("cannot_connect") return on_failure("cannot_connect")

View File

@ -56,7 +56,6 @@ def mock_legacy_pydrawise(
@pytest.fixture @pytest.fixture
def mock_pydrawise( def mock_pydrawise(
mock_auth: AsyncMock,
user: User, user: User,
controller: Controller, controller: Controller,
zones: list[Zone], zones: list[Zone],

View File

@ -21,6 +21,7 @@ pytestmark = pytest.mark.usefixtures("mock_setup_entry")
async def test_form( async def test_form(
hass: HomeAssistant, hass: HomeAssistant,
mock_setup_entry: AsyncMock, mock_setup_entry: AsyncMock,
mock_auth: AsyncMock,
mock_pydrawise: AsyncMock, mock_pydrawise: AsyncMock,
user: User, user: User,
) -> None: ) -> None:
@ -46,11 +47,12 @@ async def test_form(
CONF_PASSWORD: "__password__", CONF_PASSWORD: "__password__",
} }
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
mock_pydrawise.get_user.assert_called_once_with(fetch_zones=False) mock_auth.token.assert_awaited_once_with()
mock_pydrawise.get_user.assert_awaited_once_with(fetch_zones=False)
async def test_form_api_error( async def test_form_api_error(
hass: HomeAssistant, mock_pydrawise: AsyncMock, user: User hass: HomeAssistant, mock_auth: AsyncMock, mock_pydrawise: AsyncMock, user: User
) -> None: ) -> None:
"""Test we handle API errors.""" """Test we handle API errors."""
mock_pydrawise.get_user.side_effect = ClientError("XXX") mock_pydrawise.get_user.side_effect = ClientError("XXX")
@ -71,8 +73,29 @@ async def test_form_api_error(
assert result2["type"] is FlowResultType.CREATE_ENTRY assert result2["type"] is FlowResultType.CREATE_ENTRY
async def test_form_connect_timeout( async def test_form_auth_connect_timeout(
hass: HomeAssistant, mock_pydrawise: AsyncMock, user: User hass: HomeAssistant, mock_auth: AsyncMock, mock_pydrawise: AsyncMock
) -> None:
"""Test we handle API errors."""
mock_auth.token.side_effect = TimeoutError
init_result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
data = {CONF_USERNAME: "asdf@asdf.com", CONF_PASSWORD: "__password__"}
result = await hass.config_entries.flow.async_configure(
init_result["flow_id"], data
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {"base": "timeout_connect"}
mock_auth.token.reset_mock(side_effect=True)
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], data)
assert result2["type"] is FlowResultType.CREATE_ENTRY
async def test_form_client_connect_timeout(
hass: HomeAssistant, mock_auth: AsyncMock, mock_pydrawise: AsyncMock, user: User
) -> None: ) -> None:
"""Test we handle API errors.""" """Test we handle API errors."""
mock_pydrawise.get_user.side_effect = TimeoutError mock_pydrawise.get_user.side_effect = TimeoutError
@ -94,10 +117,10 @@ async def test_form_connect_timeout(
async def test_form_not_authorized_error( async def test_form_not_authorized_error(
hass: HomeAssistant, mock_pydrawise: AsyncMock, user: User hass: HomeAssistant, mock_auth: AsyncMock, mock_pydrawise: AsyncMock
) -> None: ) -> None:
"""Test we handle API errors.""" """Test we handle API errors."""
mock_pydrawise.get_user.side_effect = NotAuthorizedError mock_auth.token.side_effect = NotAuthorizedError
init_result = await hass.config_entries.flow.async_init( init_result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER} DOMAIN, context={"source": config_entries.SOURCE_USER}
@ -109,8 +132,7 @@ async def test_form_not_authorized_error(
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["errors"] == {"base": "invalid_auth"} assert result["errors"] == {"base": "invalid_auth"}
mock_pydrawise.get_user.reset_mock(side_effect=True) mock_auth.token.reset_mock(side_effect=True)
mock_pydrawise.get_user.return_value = user
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], data) result2 = await hass.config_entries.flow.async_configure(result["flow_id"], data)
assert result2["type"] is FlowResultType.CREATE_ENTRY assert result2["type"] is FlowResultType.CREATE_ENTRY
@ -118,6 +140,7 @@ async def test_form_not_authorized_error(
async def test_reauth( async def test_reauth(
hass: HomeAssistant, hass: HomeAssistant,
user: User, user: User,
mock_auth: AsyncMock,
mock_pydrawise: AsyncMock, mock_pydrawise: AsyncMock,
) -> None: ) -> None:
"""Test that re-authorization works.""" """Test that re-authorization works."""