Add check_connection parameter to cloud login methods and handle AlreadyConnectedError (#138699)

This commit is contained in:
Joakim Sørensen 2025-02-19 11:19:03 +01:00 committed by GitHub
parent 38efe94def
commit 618bdba4d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 65 additions and 8 deletions

View File

@ -8,14 +8,15 @@ from contextlib import suppress
import dataclasses import dataclasses
from functools import wraps from functools import wraps
from http import HTTPStatus from http import HTTPStatus
import json
import logging import logging
import time import time
from typing import Any, Concatenate from typing import Any, Concatenate, cast
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
import attr import attr
from hass_nabucasa import Cloud, auth, thingtalk from hass_nabucasa import AlreadyConnectedError, Cloud, auth, thingtalk
from hass_nabucasa.const import STATE_DISCONNECTED from hass_nabucasa.const import STATE_DISCONNECTED
from hass_nabucasa.voice import TTS_VOICES from hass_nabucasa.voice import TTS_VOICES
import voluptuous as vol import voluptuous as vol
@ -64,7 +65,9 @@ from .subscription import async_subscription_info
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_CLOUD_ERRORS: dict[type[Exception], tuple[HTTPStatus, str]] = { _CLOUD_ERRORS: dict[
type[Exception], tuple[HTTPStatus, Callable[[Exception], str] | str]
] = {
TimeoutError: ( TimeoutError: (
HTTPStatus.BAD_GATEWAY, HTTPStatus.BAD_GATEWAY,
"Unable to reach the Home Assistant cloud.", "Unable to reach the Home Assistant cloud.",
@ -133,6 +136,10 @@ def async_setup(hass: HomeAssistant) -> None:
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"Multi-factor authentication expired, or not started. Please try again.", "Multi-factor authentication expired, or not started. Please try again.",
), ),
AlreadyConnectedError: (
HTTPStatus.CONFLICT,
lambda x: json.dumps(cast(AlreadyConnectedError, x).details),
),
} }
) )
@ -197,7 +204,11 @@ def _process_cloud_exception(exc: Exception, where: str) -> tuple[HTTPStatus, st
for err, value_info in _CLOUD_ERRORS.items(): for err, value_info in _CLOUD_ERRORS.items():
if isinstance(exc, err): if isinstance(exc, err):
err_info = value_info status, content = value_info
err_info = (
status,
content if isinstance(content, str) else content(exc),
)
break break
if err_info is None: if err_info is None:
@ -240,6 +251,7 @@ class CloudLoginView(HomeAssistantView):
vol.All( vol.All(
{ {
vol.Required("email"): str, vol.Required("email"): str,
vol.Optional("check_connection", default=False): bool,
vol.Exclusive("password", "login"): str, vol.Exclusive("password", "login"): str,
vol.Exclusive("code", "login"): str, vol.Exclusive("code", "login"): str,
}, },
@ -258,7 +270,11 @@ class CloudLoginView(HomeAssistantView):
code = data.get("code") code = data.get("code")
if email and password: if email and password:
await cloud.login(email, password) await cloud.login(
email,
password,
check_connection=data["check_connection"],
)
else: else:
if ( if (
@ -270,7 +286,12 @@ class CloudLoginView(HomeAssistantView):
# Voluptuous should ensure that code is not None because password is # Voluptuous should ensure that code is not None because password is
assert code is not None assert code is not None
await cloud.login_verify_totp(email, code, self._mfa_tokens) await cloud.login_verify_totp(
email,
code,
self._mfa_tokens,
check_connection=data["check_connection"],
)
self._mfa_tokens = {} self._mfa_tokens = {}
self._mfa_tokens_set_time = 0 self._mfa_tokens_set_time = 0

View File

@ -145,7 +145,12 @@ async def cloud_fixture() -> AsyncGenerator[MagicMock]:
# Methods that we mock with a custom side effect. # Methods that we mock with a custom side effect.
async def mock_login(email: str, password: str) -> None: async def mock_login(
email: str,
password: str,
*,
check_connection: bool = False,
) -> None:
"""Mock login. """Mock login.
When called, it should call the on_start callback. When called, it should call the on_start callback.

View File

@ -11,7 +11,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import aiohttp import aiohttp
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
from hass_nabucasa import thingtalk from hass_nabucasa import AlreadyConnectedError, thingtalk
from hass_nabucasa.auth import ( from hass_nabucasa.auth import (
InvalidTotpCode, InvalidTotpCode,
MFARequired, MFARequired,
@ -373,9 +373,40 @@ async def test_login_view_request_timeout(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"} "/api/cloud/login", json={"email": "my_username", "password": "my_password"}
) )
assert cloud.login.call_args[1]["check_connection"] is False
assert req.status == HTTPStatus.BAD_GATEWAY assert req.status == HTTPStatus.BAD_GATEWAY
async def test_login_view_with_already_existing_connection(
cloud: MagicMock,
setup_cloud: None,
hass_client: ClientSessionGenerator,
) -> None:
"""Test request timeout while trying to log in."""
cloud_client = await hass_client()
cloud.login.side_effect = AlreadyConnectedError(
details={"remote_ip_address": "127.0.0.1", "connected_at": "1"}
)
req = await cloud_client.post(
"/api/cloud/login",
json={
"email": "my_username",
"password": "my_password",
"check_connection": True,
},
)
assert cloud.login.call_args[1]["check_connection"] is True
assert req.status == HTTPStatus.CONFLICT
resp = await req.json()
assert resp == {
"code": "alreadyconnectederror",
"message": '{"remote_ip_address": "127.0.0.1", "connected_at": "1"}',
}
async def test_login_view_invalid_credentials( async def test_login_view_invalid_credentials(
cloud: MagicMock, cloud: MagicMock,
setup_cloud: None, setup_cloud: None,