mirror of
https://github.com/home-assistant/core.git
synced 2025-04-26 18:27:51 +00:00
Add check_connection parameter to cloud login methods and handle AlreadyConnectedError (#138699)
This commit is contained in:
parent
38efe94def
commit
618bdba4d3
@ -8,14 +8,15 @@ from contextlib import suppress
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Concatenate
|
from typing import Any, Concatenate, cast
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import attr
|
import attr
|
||||||
from hass_nabucasa import Cloud, auth, thingtalk
|
from hass_nabucasa import AlreadyConnectedError, Cloud, auth, thingtalk
|
||||||
from hass_nabucasa.const import STATE_DISCONNECTED
|
from hass_nabucasa.const import STATE_DISCONNECTED
|
||||||
from hass_nabucasa.voice import TTS_VOICES
|
from hass_nabucasa.voice import TTS_VOICES
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -64,7 +65,9 @@ from .subscription import async_subscription_info
|
|||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_CLOUD_ERRORS: dict[type[Exception], tuple[HTTPStatus, str]] = {
|
_CLOUD_ERRORS: dict[
|
||||||
|
type[Exception], tuple[HTTPStatus, Callable[[Exception], str] | str]
|
||||||
|
] = {
|
||||||
TimeoutError: (
|
TimeoutError: (
|
||||||
HTTPStatus.BAD_GATEWAY,
|
HTTPStatus.BAD_GATEWAY,
|
||||||
"Unable to reach the Home Assistant cloud.",
|
"Unable to reach the Home Assistant cloud.",
|
||||||
@ -133,6 +136,10 @@ def async_setup(hass: HomeAssistant) -> None:
|
|||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"Multi-factor authentication expired, or not started. Please try again.",
|
"Multi-factor authentication expired, or not started. Please try again.",
|
||||||
),
|
),
|
||||||
|
AlreadyConnectedError: (
|
||||||
|
HTTPStatus.CONFLICT,
|
||||||
|
lambda x: json.dumps(cast(AlreadyConnectedError, x).details),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -197,7 +204,11 @@ def _process_cloud_exception(exc: Exception, where: str) -> tuple[HTTPStatus, st
|
|||||||
|
|
||||||
for err, value_info in _CLOUD_ERRORS.items():
|
for err, value_info in _CLOUD_ERRORS.items():
|
||||||
if isinstance(exc, err):
|
if isinstance(exc, err):
|
||||||
err_info = value_info
|
status, content = value_info
|
||||||
|
err_info = (
|
||||||
|
status,
|
||||||
|
content if isinstance(content, str) else content(exc),
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
if err_info is None:
|
if err_info is None:
|
||||||
@ -240,6 +251,7 @@ class CloudLoginView(HomeAssistantView):
|
|||||||
vol.All(
|
vol.All(
|
||||||
{
|
{
|
||||||
vol.Required("email"): str,
|
vol.Required("email"): str,
|
||||||
|
vol.Optional("check_connection", default=False): bool,
|
||||||
vol.Exclusive("password", "login"): str,
|
vol.Exclusive("password", "login"): str,
|
||||||
vol.Exclusive("code", "login"): str,
|
vol.Exclusive("code", "login"): str,
|
||||||
},
|
},
|
||||||
@ -258,7 +270,11 @@ class CloudLoginView(HomeAssistantView):
|
|||||||
code = data.get("code")
|
code = data.get("code")
|
||||||
|
|
||||||
if email and password:
|
if email and password:
|
||||||
await cloud.login(email, password)
|
await cloud.login(
|
||||||
|
email,
|
||||||
|
password,
|
||||||
|
check_connection=data["check_connection"],
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if (
|
if (
|
||||||
@ -270,7 +286,12 @@ class CloudLoginView(HomeAssistantView):
|
|||||||
# Voluptuous should ensure that code is not None because password is
|
# Voluptuous should ensure that code is not None because password is
|
||||||
assert code is not None
|
assert code is not None
|
||||||
|
|
||||||
await cloud.login_verify_totp(email, code, self._mfa_tokens)
|
await cloud.login_verify_totp(
|
||||||
|
email,
|
||||||
|
code,
|
||||||
|
self._mfa_tokens,
|
||||||
|
check_connection=data["check_connection"],
|
||||||
|
)
|
||||||
self._mfa_tokens = {}
|
self._mfa_tokens = {}
|
||||||
self._mfa_tokens_set_time = 0
|
self._mfa_tokens_set_time = 0
|
||||||
|
|
||||||
|
@ -145,7 +145,12 @@ async def cloud_fixture() -> AsyncGenerator[MagicMock]:
|
|||||||
|
|
||||||
# Methods that we mock with a custom side effect.
|
# Methods that we mock with a custom side effect.
|
||||||
|
|
||||||
async def mock_login(email: str, password: str) -> None:
|
async def mock_login(
|
||||||
|
email: str,
|
||||||
|
password: str,
|
||||||
|
*,
|
||||||
|
check_connection: bool = False,
|
||||||
|
) -> None:
|
||||||
"""Mock login.
|
"""Mock login.
|
||||||
|
|
||||||
When called, it should call the on_start callback.
|
When called, it should call the on_start callback.
|
||||||
|
@ -11,7 +11,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from freezegun.api import FrozenDateTimeFactory
|
from freezegun.api import FrozenDateTimeFactory
|
||||||
from hass_nabucasa import thingtalk
|
from hass_nabucasa import AlreadyConnectedError, thingtalk
|
||||||
from hass_nabucasa.auth import (
|
from hass_nabucasa.auth import (
|
||||||
InvalidTotpCode,
|
InvalidTotpCode,
|
||||||
MFARequired,
|
MFARequired,
|
||||||
@ -373,9 +373,40 @@ async def test_login_view_request_timeout(
|
|||||||
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert cloud.login.call_args[1]["check_connection"] is False
|
||||||
|
|
||||||
assert req.status == HTTPStatus.BAD_GATEWAY
|
assert req.status == HTTPStatus.BAD_GATEWAY
|
||||||
|
|
||||||
|
|
||||||
|
async def test_login_view_with_already_existing_connection(
|
||||||
|
cloud: MagicMock,
|
||||||
|
setup_cloud: None,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test request timeout while trying to log in."""
|
||||||
|
cloud_client = await hass_client()
|
||||||
|
cloud.login.side_effect = AlreadyConnectedError(
|
||||||
|
details={"remote_ip_address": "127.0.0.1", "connected_at": "1"}
|
||||||
|
)
|
||||||
|
|
||||||
|
req = await cloud_client.post(
|
||||||
|
"/api/cloud/login",
|
||||||
|
json={
|
||||||
|
"email": "my_username",
|
||||||
|
"password": "my_password",
|
||||||
|
"check_connection": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cloud.login.call_args[1]["check_connection"] is True
|
||||||
|
assert req.status == HTTPStatus.CONFLICT
|
||||||
|
resp = await req.json()
|
||||||
|
assert resp == {
|
||||||
|
"code": "alreadyconnectederror",
|
||||||
|
"message": '{"remote_ip_address": "127.0.0.1", "connected_at": "1"}',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_login_view_invalid_credentials(
|
async def test_login_view_invalid_credentials(
|
||||||
cloud: MagicMock,
|
cloud: MagicMock,
|
||||||
setup_cloud: None,
|
setup_cloud: None,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user