Add check_connection parameter to cloud login methods and handle AlreadyConnectedError (#138699)

This commit is contained in:
Joakim Sørensen 2025-02-19 11:19:03 +01:00 committed by GitHub
parent 38efe94def
commit 618bdba4d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 65 additions and 8 deletions

View File

@ -8,14 +8,15 @@ from contextlib import suppress
import dataclasses
from functools import wraps
from http import HTTPStatus
import json
import logging
import time
from typing import Any, Concatenate
from typing import Any, Concatenate, cast
import aiohttp
from aiohttp import web
import attr
from hass_nabucasa import Cloud, auth, thingtalk
from hass_nabucasa import AlreadyConnectedError, Cloud, auth, thingtalk
from hass_nabucasa.const import STATE_DISCONNECTED
from hass_nabucasa.voice import TTS_VOICES
import voluptuous as vol
@ -64,7 +65,9 @@ from .subscription import async_subscription_info
_LOGGER = logging.getLogger(__name__)
_CLOUD_ERRORS: dict[type[Exception], tuple[HTTPStatus, str]] = {
_CLOUD_ERRORS: dict[
type[Exception], tuple[HTTPStatus, Callable[[Exception], str] | str]
] = {
TimeoutError: (
HTTPStatus.BAD_GATEWAY,
"Unable to reach the Home Assistant cloud.",
@ -133,6 +136,10 @@ def async_setup(hass: HomeAssistant) -> None:
HTTPStatus.BAD_REQUEST,
"Multi-factor authentication expired, or not started. Please try again.",
),
AlreadyConnectedError: (
HTTPStatus.CONFLICT,
lambda x: json.dumps(cast(AlreadyConnectedError, x).details),
),
}
)
@ -197,7 +204,11 @@ def _process_cloud_exception(exc: Exception, where: str) -> tuple[HTTPStatus, st
for err, value_info in _CLOUD_ERRORS.items():
if isinstance(exc, err):
err_info = value_info
status, content = value_info
err_info = (
status,
content if isinstance(content, str) else content(exc),
)
break
if err_info is None:
@ -240,6 +251,7 @@ class CloudLoginView(HomeAssistantView):
vol.All(
{
vol.Required("email"): str,
vol.Optional("check_connection", default=False): bool,
vol.Exclusive("password", "login"): str,
vol.Exclusive("code", "login"): str,
},
@ -258,7 +270,11 @@ class CloudLoginView(HomeAssistantView):
code = data.get("code")
if email and password:
await cloud.login(email, password)
await cloud.login(
email,
password,
check_connection=data["check_connection"],
)
else:
if (
@ -270,7 +286,12 @@ class CloudLoginView(HomeAssistantView):
# Voluptuous should ensure that code is not None because password is
assert code is not None
await cloud.login_verify_totp(email, code, self._mfa_tokens)
await cloud.login_verify_totp(
email,
code,
self._mfa_tokens,
check_connection=data["check_connection"],
)
self._mfa_tokens = {}
self._mfa_tokens_set_time = 0

View File

@ -145,7 +145,12 @@ async def cloud_fixture() -> AsyncGenerator[MagicMock]:
# Methods that we mock with a custom side effect.
async def mock_login(email: str, password: str) -> None:
async def mock_login(
email: str,
password: str,
*,
check_connection: bool = False,
) -> None:
"""Mock login.
When called, it should call the on_start callback.

View File

@ -11,7 +11,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import aiohttp
from freezegun.api import FrozenDateTimeFactory
from hass_nabucasa import thingtalk
from hass_nabucasa import AlreadyConnectedError, thingtalk
from hass_nabucasa.auth import (
InvalidTotpCode,
MFARequired,
@ -373,9 +373,40 @@ async def test_login_view_request_timeout(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)
assert cloud.login.call_args[1]["check_connection"] is False
assert req.status == HTTPStatus.BAD_GATEWAY
async def test_login_view_with_already_existing_connection(
cloud: MagicMock,
setup_cloud: None,
hass_client: ClientSessionGenerator,
) -> None:
"""Test request timeout while trying to log in."""
cloud_client = await hass_client()
cloud.login.side_effect = AlreadyConnectedError(
details={"remote_ip_address": "127.0.0.1", "connected_at": "1"}
)
req = await cloud_client.post(
"/api/cloud/login",
json={
"email": "my_username",
"password": "my_password",
"check_connection": True,
},
)
assert cloud.login.call_args[1]["check_connection"] is True
assert req.status == HTTPStatus.CONFLICT
resp = await req.json()
assert resp == {
"code": "alreadyconnectederror",
"message": '{"remote_ip_address": "127.0.0.1", "connected_at": "1"}',
}
async def test_login_view_invalid_credentials(
cloud: MagicMock,
setup_cloud: None,