Update Ecovacs config_flow to support self-hosted instances (#108944)

* Update Ecovacs config_flow to support  self-hosted instances

* Selfhosted should add their instance urls

* Improve config flow

* Improve and adapt to version bump

* Add test for self-hosted

* Make ruff happy

* Update homeassistant/components/ecovacs/strings.json

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>

* Implement suggestions

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Implement suggestions

* Remove ,

---------

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Robert Resch 2024-01-31 13:17:00 +01:00 committed by GitHub
parent f77e4b24e6
commit 4bad88b42c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 596 additions and 72 deletions

View File

@ -2,39 +2,81 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import ssl
from typing import Any, cast from typing import Any, cast
from urllib.parse import urlparse
from aiohttp import ClientError from aiohttp import ClientError
from deebot_client.authentication import Authenticator, create_rest_config from deebot_client.authentication import Authenticator, create_rest_config
from deebot_client.exceptions import InvalidAuthenticationError from deebot_client.const import UNDEFINED, UndefinedType
from deebot_client.exceptions import InvalidAuthenticationError, MqttError
from deebot_client.mqtt_client import MqttClient, create_mqtt_config
from deebot_client.util import md5 from deebot_client.util import md5
from deebot_client.util.continents import COUNTRIES_TO_CONTINENTS, get_continent from deebot_client.util.continents import COUNTRIES_TO_CONTINENTS, get_continent
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigFlow from homeassistant.config_entries import ConfigFlow
from homeassistant.const import CONF_COUNTRY, CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_COUNTRY, CONF_MODE, CONF_PASSWORD, CONF_USERNAME
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
from homeassistant.data_entry_flow import AbortFlow, FlowResult from homeassistant.data_entry_flow import AbortFlow, FlowResult
from homeassistant.helpers import aiohttp_client, selector from homeassistant.helpers import aiohttp_client, selector
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
from homeassistant.loader import async_get_issue_tracker from homeassistant.loader import async_get_issue_tracker
from homeassistant.util.ssl import get_default_no_verify_context
from .const import CONF_CONTINENT, DOMAIN from .const import (
CONF_CONTINENT,
CONF_OVERRIDE_MQTT_URL,
CONF_OVERRIDE_REST_URL,
CONF_VERIFY_MQTT_CERTIFICATE,
DOMAIN,
InstanceMode,
)
from .util import get_client_device_id from .util import get_client_device_id
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def _validate_url(
value: str,
field_name: str,
schema_list: set[str],
) -> dict[str, str]:
"""Validate an URL and return error dictionary."""
if urlparse(value).scheme not in schema_list:
return {field_name: f"invalid_url_schema_{field_name}"}
try:
vol.Schema(vol.Url())(value)
except vol.Invalid:
return {field_name: "invalid_url"}
return {}
async def _validate_input( async def _validate_input(
hass: HomeAssistant, user_input: dict[str, Any] hass: HomeAssistant, user_input: dict[str, Any]
) -> dict[str, str]: ) -> dict[str, str]:
"""Validate user input.""" """Validate user input."""
errors: dict[str, str] = {} errors: dict[str, str] = {}
if rest_url := user_input.get(CONF_OVERRIDE_REST_URL):
errors.update(
_validate_url(rest_url, CONF_OVERRIDE_REST_URL, {"http", "https"})
)
if mqtt_url := user_input.get(CONF_OVERRIDE_MQTT_URL):
errors.update(
_validate_url(mqtt_url, CONF_OVERRIDE_MQTT_URL, {"mqtt", "mqtts"})
)
if errors:
return errors
device_id = get_client_device_id()
country = user_input[CONF_COUNTRY]
rest_config = create_rest_config( rest_config = create_rest_config(
aiohttp_client.async_get_clientsession(hass), aiohttp_client.async_get_clientsession(hass),
device_id=get_client_device_id(), device_id=device_id,
country=user_input[CONF_COUNTRY], country=country,
override_rest_url=rest_url,
) )
authenticator = Authenticator( authenticator = Authenticator(
@ -54,6 +96,34 @@ async def _validate_input(
_LOGGER.exception("Unexpected exception during login") _LOGGER.exception("Unexpected exception during login")
errors["base"] = "unknown" errors["base"] = "unknown"
if errors:
return errors
ssl_context: UndefinedType | ssl.SSLContext = UNDEFINED
if not user_input.get(CONF_VERIFY_MQTT_CERTIFICATE, True) and mqtt_url:
ssl_context = get_default_no_verify_context()
mqtt_config = create_mqtt_config(
device_id=device_id,
country=country,
override_mqtt_url=mqtt_url,
ssl_context=ssl_context,
)
client = MqttClient(mqtt_config, authenticator)
cannot_connect_field = CONF_OVERRIDE_MQTT_URL if mqtt_url else "base"
try:
await client.verify_config()
except MqttError:
_LOGGER.debug("Cannot connect", exc_info=True)
errors[cannot_connect_field] = "cannot_connect"
except InvalidAuthenticationError:
errors["base"] = "invalid_auth"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception during mqtt connection verification")
errors["base"] = "unknown"
return errors return errors
@ -62,10 +132,42 @@ class EcovacsConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
_mode: InstanceMode = InstanceMode.CLOUD
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Handle the initial step.""" """Handle the initial step."""
if not self.show_advanced_options:
return await self.async_step_auth()
if user_input:
self._mode = user_input[CONF_MODE]
return await self.async_step_auth()
return self.async_show_form(
step_id="user",
data_schema=vol.Schema(
{
vol.Required(
CONF_MODE, default=InstanceMode.CLOUD
): selector.SelectSelector(
selector.SelectSelectorConfig(
options=list(InstanceMode),
translation_key="installation_mode",
mode=selector.SelectSelectorMode.DROPDOWN,
)
)
}
),
last_step=False,
)
async def async_step_auth(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the auth step."""
errors = {} errors = {}
if user_input: if user_input:
@ -78,30 +180,41 @@ class EcovacsConfigFlow(ConfigFlow, domain=DOMAIN):
title=user_input[CONF_USERNAME], data=user_input title=user_input[CONF_USERNAME], data=user_input
) )
return self.async_show_form( schema = {
step_id="user",
data_schema=self.add_suggested_values_to_schema(
data_schema=vol.Schema(
{
vol.Required(CONF_USERNAME): selector.TextSelector( vol.Required(CONF_USERNAME): selector.TextSelector(
selector.TextSelectorConfig( selector.TextSelectorConfig(type=selector.TextSelectorType.TEXT)
type=selector.TextSelectorType.TEXT
)
), ),
vol.Required(CONF_PASSWORD): selector.TextSelector( vol.Required(CONF_PASSWORD): selector.TextSelector(
selector.TextSelectorConfig( selector.TextSelectorConfig(type=selector.TextSelectorType.PASSWORD)
type=selector.TextSelectorType.PASSWORD
)
), ),
vol.Required(CONF_COUNTRY): selector.CountrySelector(), vol.Required(CONF_COUNTRY): selector.CountrySelector(),
} }
if self._mode == InstanceMode.SELF_HOSTED:
schema.update(
{
vol.Required(CONF_OVERRIDE_REST_URL): selector.TextSelector(
selector.TextSelectorConfig(type=selector.TextSelectorType.URL)
), ),
suggested_values=user_input vol.Required(CONF_OVERRIDE_MQTT_URL): selector.TextSelector(
or { selector.TextSelectorConfig(type=selector.TextSelectorType.URL)
),
}
)
if errors:
schema[vol.Optional(CONF_VERIFY_MQTT_CERTIFICATE, default=True)] = bool
if not user_input:
user_input = {
CONF_COUNTRY: self.hass.config.country, CONF_COUNTRY: self.hass.config.country,
}, }
return self.async_show_form(
step_id="auth",
data_schema=self.add_suggested_values_to_schema(
data_schema=vol.Schema(schema), suggested_values=user_input
), ),
errors=errors, errors=errors,
last_step=True,
) )
async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult: async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult:
@ -181,7 +294,7 @@ class EcovacsConfigFlow(ConfigFlow, domain=DOMAIN):
# Remove the continent from the user input as it is not needed anymore # Remove the continent from the user input as it is not needed anymore
user_input.pop(CONF_CONTINENT) user_input.pop(CONF_CONTINENT)
try: try:
result = await self.async_step_user(user_input) result = await self.async_step_auth(user_input)
except AbortFlow as ex: except AbortFlow as ex:
if ex.reason == "already_configured": if ex.reason == "already_configured":
create_repair() create_repair()

View File

@ -1,12 +1,24 @@
"""Ecovacs constants.""" """Ecovacs constants."""
from enum import StrEnum
from deebot_client.events import LifeSpan from deebot_client.events import LifeSpan
DOMAIN = "ecovacs" DOMAIN = "ecovacs"
CONF_CONTINENT = "continent" CONF_CONTINENT = "continent"
CONF_OVERRIDE_REST_URL = "override_rest_url"
CONF_OVERRIDE_MQTT_URL = "override_mqtt_url"
CONF_VERIFY_MQTT_CERTIFICATE = "verify_mqtt_certificate"
SUPPORTED_LIFESPANS = ( SUPPORTED_LIFESPANS = (
LifeSpan.BRUSH, LifeSpan.BRUSH,
LifeSpan.FILTER, LifeSpan.FILTER,
LifeSpan.SIDE_BRUSH, LifeSpan.SIDE_BRUSH,
) )
class InstanceMode(StrEnum):
"""Instance mode."""
CLOUD = "cloud"
SELF_HOSTED = "self_hosted"

View File

@ -3,10 +3,12 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import logging import logging
import ssl
from typing import Any from typing import Any
from deebot_client.api_client import ApiClient from deebot_client.api_client import ApiClient
from deebot_client.authentication import Authenticator, create_rest_config from deebot_client.authentication import Authenticator, create_rest_config
from deebot_client.const import UNDEFINED, UndefinedType
from deebot_client.device import Device from deebot_client.device import Device
from deebot_client.exceptions import DeebotError, InvalidAuthenticationError from deebot_client.exceptions import DeebotError, InvalidAuthenticationError
from deebot_client.models import DeviceInfo from deebot_client.models import DeviceInfo
@ -19,7 +21,13 @@ from homeassistant.const import CONF_COUNTRY, CONF_PASSWORD, CONF_USERNAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady
from homeassistant.helpers import aiohttp_client from homeassistant.helpers import aiohttp_client
from homeassistant.util.ssl import get_default_no_verify_context
from .const import (
CONF_OVERRIDE_MQTT_URL,
CONF_OVERRIDE_REST_URL,
CONF_VERIFY_MQTT_CERTIFICATE,
)
from .util import get_client_device_id from .util import get_client_device_id
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -42,15 +50,24 @@ class EcovacsController:
aiohttp_client.async_get_clientsession(self._hass), aiohttp_client.async_get_clientsession(self._hass),
device_id=self._device_id, device_id=self._device_id,
country=country, country=country,
override_rest_url=config.get(CONF_OVERRIDE_REST_URL),
), ),
config[CONF_USERNAME], config[CONF_USERNAME],
md5(config[CONF_PASSWORD]), md5(config[CONF_PASSWORD]),
) )
self._api_client = ApiClient(self._authenticator) self._api_client = ApiClient(self._authenticator)
mqtt_url = config.get(CONF_OVERRIDE_MQTT_URL)
ssl_context: UndefinedType | ssl.SSLContext = UNDEFINED
if not config.get(CONF_VERIFY_MQTT_CERTIFICATE, True) and mqtt_url:
ssl_context = get_default_no_verify_context()
self._mqtt = MqttClient( self._mqtt = MqttClient(
create_mqtt_config( create_mqtt_config(
device_id=self._device_id, device_id=self._device_id,
country=country, country=country,
override_mqtt_url=mqtt_url,
ssl_context=ssl_context,
), ),
self._authenticator, self._authenticator,
) )

View File

@ -8,10 +8,16 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_NAME, CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_NAME, CONF_PASSWORD, CONF_USERNAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .const import DOMAIN from .const import CONF_OVERRIDE_MQTT_URL, CONF_OVERRIDE_REST_URL, DOMAIN
from .controller import EcovacsController from .controller import EcovacsController
REDACT_CONFIG = {CONF_USERNAME, CONF_PASSWORD, "title"} REDACT_CONFIG = {
CONF_USERNAME,
CONF_PASSWORD,
"title",
CONF_OVERRIDE_MQTT_URL,
CONF_OVERRIDE_REST_URL,
}
REDACT_DEVICE = {"did", CONF_NAME, "homeId"} REDACT_DEVICE = {"did", CONF_NAME, "homeId"}

View File

@ -6,14 +6,32 @@
"error": { "error": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"invalid_url": "Invalid URL",
"invalid_url_schema_override_rest_url": "Invalid REST URL scheme.\nThe URL should start with `http://` or `https://`.",
"invalid_url_schema_override_mqtt_url": "Invalid MQTT URL scheme.\nThe URL should start with `mqtt://` or `mqtts://`.",
"unknown": "[%key:common::config_flow::error::unknown%]" "unknown": "[%key:common::config_flow::error::unknown%]"
}, },
"step": { "step": {
"user": { "auth": {
"data": { "data": {
"country": "Country", "country": "Country",
"override_rest_url": "REST URL",
"override_mqtt_url": "MQTT URL",
"password": "[%key:common::config_flow::data::password%]", "password": "[%key:common::config_flow::data::password%]",
"username": "[%key:common::config_flow::data::username%]" "username": "[%key:common::config_flow::data::username%]",
"verify_mqtt_certificate": "Verify MQTT SSL certificate"
},
"data_description": {
"override_rest_url": "Enter the REST URL of your self-hosted instance including the scheme (http/https).",
"override_mqtt_url": "Enter the MQTT URL of your self-hosted instance including the scheme (mqtt/mqtts)."
}
},
"user": {
"data": {
"mode": "[%key:common::config_flow::data::mode%]"
},
"data_description": {
"mode": "Select the mode you want to use to connect to Ecovacs. If you are unsure, select 'Cloud'.\n\nSelect 'Self-hosted' only if you have a working self-hosted instance."
} }
} }
} }
@ -157,5 +175,13 @@
"title": "The Ecovacs YAML configuration import failed", "title": "The Ecovacs YAML configuration import failed",
"description": "Configuring Ecovacs using YAML is being removed but there is an unexpected continent specified in the YAML configuration.\n\nFrom the given country, the continent '{continent}' is expected. Change the continent and restart Home Assistant to try again or remove the Ecovacs YAML configuration from your configuration.yaml file and continue to [set up the integration]({url}) manually.\n\nIf the contintent '{continent}' is not applicable, please open an issue on [GitHub]({github_issue_url})." "description": "Configuring Ecovacs using YAML is being removed but there is an unexpected continent specified in the YAML configuration.\n\nFrom the given country, the continent '{continent}' is expected. Change the continent and restart Home Assistant to try again or remove the Ecovacs YAML configuration from your configuration.yaml file and continue to [set up the integration]({url}) manually.\n\nIf the contintent '{continent}' is not applicable, please open an issue on [GitHub]({github_issue_url})."
} }
},
"selector": {
"installation_mode": {
"options": {
"cloud": "Cloud",
"self_hosted": "Self-hosted"
}
}
} }
} }

View File

@ -12,10 +12,10 @@ import pytest
from homeassistant.components.ecovacs import PLATFORMS from homeassistant.components.ecovacs import PLATFORMS
from homeassistant.components.ecovacs.const import DOMAIN from homeassistant.components.ecovacs.const import DOMAIN
from homeassistant.components.ecovacs.controller import EcovacsController from homeassistant.components.ecovacs.controller import EcovacsController
from homeassistant.const import Platform from homeassistant.const import CONF_USERNAME, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .const import VALID_ENTRY_DATA from .const import VALID_ENTRY_DATA_CLOUD
from tests.common import MockConfigEntry, load_json_object_fixture from tests.common import MockConfigEntry, load_json_object_fixture
@ -30,12 +30,18 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]:
@pytest.fixture @pytest.fixture
def mock_config_entry() -> MockConfigEntry: def mock_config_entry_data() -> dict[str, Any]:
"""Return the default mocked config entry data."""
return VALID_ENTRY_DATA_CLOUD
@pytest.fixture
def mock_config_entry(mock_config_entry_data: dict[str, Any]) -> MockConfigEntry:
"""Return the default mocked config entry.""" """Return the default mocked config entry."""
return MockConfigEntry( return MockConfigEntry(
title="username", title=mock_config_entry_data[CONF_USERNAME],
domain=DOMAIN, domain=DOMAIN,
data=VALID_ENTRY_DATA, data=mock_config_entry_data,
) )
@ -62,7 +68,7 @@ def mock_authenticator(device_fixture: str) -> Generator[Mock, None, None]:
load_json_object_fixture(f"devices/{device_fixture}/device.json", DOMAIN) load_json_object_fixture(f"devices/{device_fixture}/device.json", DOMAIN)
] ]
def post_authenticated( async def post_authenticated(
path: str, path: str,
json: dict[str, Any], json: dict[str, Any],
*, *,
@ -89,8 +95,11 @@ def mock_mqtt_client(mock_authenticator: Mock) -> Mock:
with patch( with patch(
"homeassistant.components.ecovacs.controller.MqttClient", "homeassistant.components.ecovacs.controller.MqttClient",
autospec=True, autospec=True,
) as mock_mqtt_client: ) as mock, patch(
client = mock_mqtt_client.return_value "homeassistant.components.ecovacs.config_flow.MqttClient",
new=mock,
):
client = mock.return_value
client._authenticator = mock_authenticator client._authenticator = mock_authenticator
client.subscribe.return_value = lambda: None client.subscribe.return_value = lambda: None
yield client yield client

View File

@ -1,13 +1,28 @@
"""Test ecovacs constants.""" """Test ecovacs constants."""
from homeassistant.components.ecovacs.const import CONF_CONTINENT from homeassistant.components.ecovacs.const import (
CONF_CONTINENT,
CONF_OVERRIDE_MQTT_URL,
CONF_OVERRIDE_REST_URL,
CONF_VERIFY_MQTT_CERTIFICATE,
)
from homeassistant.const import CONF_COUNTRY, CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_COUNTRY, CONF_PASSWORD, CONF_USERNAME
VALID_ENTRY_DATA = { VALID_ENTRY_DATA_CLOUD = {
CONF_USERNAME: "username", CONF_USERNAME: "username@cloud",
CONF_PASSWORD: "password", CONF_PASSWORD: "password",
CONF_COUNTRY: "IT", CONF_COUNTRY: "IT",
} }
IMPORT_DATA = VALID_ENTRY_DATA | {CONF_CONTINENT: "EU"} VALID_ENTRY_DATA_SELF_HOSTED = VALID_ENTRY_DATA_CLOUD | {
CONF_USERNAME: "username@self-hosted",
CONF_OVERRIDE_REST_URL: "http://localhost:8000",
CONF_OVERRIDE_MQTT_URL: "mqtt://localhost:1883",
}
VALID_ENTRY_DATA_SELF_HOSTED_WITH_VALIDATE_CERT = VALID_ENTRY_DATA_SELF_HOSTED | {
CONF_VERIFY_MQTT_CERTIFICATE: True,
}
IMPORT_DATA = VALID_ENTRY_DATA_CLOUD | {CONF_CONTINENT: "EU"}

View File

@ -1,5 +1,5 @@
# serializer version: 1 # serializer version: 1
# name: test_diagnostics # name: test_diagnostics[username@cloud]
dict({ dict({
'config': dict({ 'config': dict({
'data': dict({ 'data': dict({
@ -48,3 +48,54 @@
]), ]),
}) })
# --- # ---
# name: test_diagnostics[username@self-hosted]
dict({
'config': dict({
'data': dict({
'country': 'IT',
'override_mqtt_url': '**REDACTED**',
'override_rest_url': '**REDACTED**',
'password': '**REDACTED**',
'username': '**REDACTED**',
}),
'disabled_by': None,
'domain': 'ecovacs',
'minor_version': 1,
'options': dict({
}),
'pref_disable_new_entities': False,
'pref_disable_polling': False,
'source': 'user',
'title': '**REDACTED**',
'unique_id': None,
'version': 1,
}),
'devices': list([
dict({
'UILogicId': 'DX_9G',
'class': 'yna5xi',
'company': 'eco-ng',
'deviceName': 'DEEBOT OZMO 950 Series',
'did': '**REDACTED**',
'homeSort': 9999,
'icon': 'https://portal-ww.ecouser.net/api/pim/file/get/606278df4a84d700082b39f1',
'materialNo': '110-1820-0101',
'model': 'DX9G',
'name': '**REDACTED**',
'nick': 'Ozmo 950',
'otaUpgrade': dict({
}),
'pid': '5c19a91ca1e6ee000178224a',
'product_category': 'DEEBOT',
'resource': 'upQ6',
'service': dict({
'jmq': 'jmq-ngiot-eu.dc.ww.ecouser.net',
'mqs': 'api-ngiot.dc-as.ww.ecouser.net',
}),
'status': 1,
}),
]),
'legacy_devices': list([
]),
})
# ---

View File

@ -1,86 +1,307 @@
"""Test Ecovacs config flow.""" """Test Ecovacs config flow."""
from collections.abc import Awaitable, Callable
import ssl
from typing import Any from typing import Any
from unittest.mock import AsyncMock from unittest.mock import AsyncMock, Mock, patch
from aiohttp import ClientError from aiohttp import ClientError
from deebot_client.exceptions import InvalidAuthenticationError from deebot_client.exceptions import InvalidAuthenticationError, MqttError
from deebot_client.mqtt_client import create_mqtt_config
import pytest import pytest
from homeassistant.components.ecovacs.const import DOMAIN from homeassistant.components.ecovacs.const import (
CONF_CONTINENT,
CONF_OVERRIDE_MQTT_URL,
CONF_OVERRIDE_REST_URL,
CONF_VERIFY_MQTT_CERTIFICATE,
DOMAIN,
InstanceMode,
)
from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER
from homeassistant.const import CONF_USERNAME from homeassistant.const import CONF_COUNTRY, CONF_MODE, CONF_USERNAME
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import issue_registry as ir from homeassistant.helpers import issue_registry as ir
from .const import IMPORT_DATA, VALID_ENTRY_DATA from .const import (
IMPORT_DATA,
VALID_ENTRY_DATA_CLOUD,
VALID_ENTRY_DATA_SELF_HOSTED,
VALID_ENTRY_DATA_SELF_HOSTED_WITH_VALIDATE_CERT,
)
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
_USER_STEP_SELF_HOSTED = {CONF_MODE: InstanceMode.SELF_HOSTED}
async def _test_user_flow(hass: HomeAssistant) -> dict[str, Any]: _TEST_FN_AUTH_ARG = "user_input_auth"
_TEST_FN_USER_ARG = "user_input_user"
async def _test_user_flow(
hass: HomeAssistant,
user_input_auth: dict[str, Any],
) -> dict[str, Any]:
"""Test config flow.""" """Test config flow."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
context={"source": SOURCE_USER}, context={"source": SOURCE_USER},
) )
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "auth"
assert not result["errors"]
return await hass.config_entries.flow.async_configure( return await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
user_input=VALID_ENTRY_DATA, user_input=user_input_auth,
) )
async def _test_user_flow_show_advanced_options(
hass: HomeAssistant,
*,
user_input_auth: dict[str, Any],
user_input_user: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Test config flow."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER, "show_advanced_options": True},
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "user"
assert not result["errors"]
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=user_input_user or {},
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "auth"
assert not result["errors"]
return await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=user_input_auth,
)
@pytest.mark.parametrize(
("test_fn", "test_fn_args", "entry_data"),
[
(
_test_user_flow_show_advanced_options,
{_TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_CLOUD},
VALID_ENTRY_DATA_CLOUD,
),
(
_test_user_flow_show_advanced_options,
{
_TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_SELF_HOSTED,
_TEST_FN_USER_ARG: _USER_STEP_SELF_HOSTED,
},
VALID_ENTRY_DATA_SELF_HOSTED,
),
(
_test_user_flow,
{_TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_CLOUD},
VALID_ENTRY_DATA_CLOUD,
),
],
ids=["advanced_cloud", "advanced_self_hosted", "cloud"],
)
async def test_user_flow( async def test_user_flow(
hass: HomeAssistant, hass: HomeAssistant,
mock_setup_entry: AsyncMock, mock_setup_entry: AsyncMock,
mock_authenticator_authenticate: AsyncMock, mock_authenticator_authenticate: AsyncMock,
mock_mqtt_client: Mock,
test_fn: Callable[[HomeAssistant, dict[str, Any]], Awaitable[dict[str, Any]]]
| Callable[
[HomeAssistant, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]]
],
test_fn_args: dict[str, Any],
entry_data: dict[str, Any],
) -> None: ) -> None:
"""Test the user config flow.""" """Test the user config flow."""
result = await _test_user_flow(hass) result = await test_fn(
hass,
**test_fn_args,
)
assert result["type"] == FlowResultType.CREATE_ENTRY assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["title"] == VALID_ENTRY_DATA[CONF_USERNAME] assert result["title"] == entry_data[CONF_USERNAME]
assert result["data"] == VALID_ENTRY_DATA assert result["data"] == entry_data
mock_setup_entry.assert_called() mock_setup_entry.assert_called()
mock_authenticator_authenticate.assert_called() mock_authenticator_authenticate.assert_called()
mock_mqtt_client.verify_config.assert_called()
def _cannot_connect_error(user_input: dict[str, Any]) -> str:
field = "base"
if CONF_OVERRIDE_MQTT_URL in user_input:
field = CONF_OVERRIDE_MQTT_URL
return {field: "cannot_connect"}
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "reason"), ("side_effect_mqtt", "errors_mqtt"),
[
(MqttError, _cannot_connect_error),
(InvalidAuthenticationError, lambda _: {"base": "invalid_auth"}),
(Exception, lambda _: {"base": "unknown"}),
],
ids=["cannot_connect", "invalid_auth", "unknown"],
)
@pytest.mark.parametrize(
("side_effect_rest", "reason_rest"),
[ [
(ClientError, "cannot_connect"), (ClientError, "cannot_connect"),
(InvalidAuthenticationError, "invalid_auth"), (InvalidAuthenticationError, "invalid_auth"),
(Exception, "unknown"), (Exception, "unknown"),
], ],
ids=["cannot_connect", "invalid_auth", "unknown"],
) )
async def test_user_flow_error( @pytest.mark.parametrize(
("test_fn", "test_fn_args", "entry_data"),
[
(
_test_user_flow_show_advanced_options,
{_TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_CLOUD},
VALID_ENTRY_DATA_CLOUD,
),
(
_test_user_flow_show_advanced_options,
{
_TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_SELF_HOSTED,
_TEST_FN_USER_ARG: _USER_STEP_SELF_HOSTED,
},
VALID_ENTRY_DATA_SELF_HOSTED_WITH_VALIDATE_CERT,
),
(
_test_user_flow,
{_TEST_FN_AUTH_ARG: VALID_ENTRY_DATA_CLOUD},
VALID_ENTRY_DATA_CLOUD,
),
],
ids=["advanced_cloud", "advanced_self_hosted", "cloud"],
)
async def test_user_flow_raise_error(
hass: HomeAssistant, hass: HomeAssistant,
side_effect: Exception,
reason: str,
mock_setup_entry: AsyncMock, mock_setup_entry: AsyncMock,
mock_authenticator_authenticate: AsyncMock, mock_authenticator_authenticate: AsyncMock,
mock_mqtt_client: Mock,
side_effect_rest: Exception,
reason_rest: str,
side_effect_mqtt: Exception,
errors_mqtt: Callable[[dict[str, Any]], str],
test_fn: Callable[[HomeAssistant, dict[str, Any]], Awaitable[dict[str, Any]]]
| Callable[
[HomeAssistant, dict[str, Any], dict[str, Any]], Awaitable[dict[str, Any]]
],
test_fn_args: dict[str, Any],
entry_data: dict[str, Any],
) -> None: ) -> None:
"""Test handling invalid connection.""" """Test handling error on library calls."""
user_input_auth = test_fn_args[_TEST_FN_AUTH_ARG]
mock_authenticator_authenticate.side_effect = side_effect # Authenticator raises error
mock_authenticator_authenticate.side_effect = side_effect_rest
result = await _test_user_flow(hass) result = await test_fn(
hass,
**test_fn_args,
)
assert result["type"] == FlowResultType.FORM assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "user" assert result["step_id"] == "auth"
assert result["errors"] == {"base": reason} assert result["errors"] == {"base": reason_rest}
mock_authenticator_authenticate.assert_called() mock_authenticator_authenticate.assert_called()
mock_mqtt_client.verify_config.assert_not_called()
mock_setup_entry.assert_not_called() mock_setup_entry.assert_not_called()
mock_authenticator_authenticate.reset_mock(side_effect=True) mock_authenticator_authenticate.reset_mock(side_effect=True)
# MQTT raises error
mock_mqtt_client.verify_config.side_effect = side_effect_mqtt
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
user_input=VALID_ENTRY_DATA, user_input=user_input_auth,
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "auth"
assert result["errors"] == errors_mqtt(user_input_auth)
mock_authenticator_authenticate.assert_called()
mock_mqtt_client.verify_config.assert_called()
mock_setup_entry.assert_not_called()
mock_authenticator_authenticate.reset_mock(side_effect=True)
mock_mqtt_client.verify_config.reset_mock(side_effect=True)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=user_input_auth,
) )
assert result["type"] == FlowResultType.CREATE_ENTRY assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["title"] == VALID_ENTRY_DATA[CONF_USERNAME] assert result["title"] == entry_data[CONF_USERNAME]
assert result["data"] == VALID_ENTRY_DATA assert result["data"] == entry_data
mock_setup_entry.assert_called() mock_setup_entry.assert_called()
mock_authenticator_authenticate.assert_called() mock_authenticator_authenticate.assert_called()
mock_mqtt_client.verify_config.assert_called()
async def test_user_flow_self_hosted_error(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_authenticator_authenticate: AsyncMock,
mock_mqtt_client: Mock,
) -> None:
"""Test handling selfhosted errors and custom ssl context."""
result = await _test_user_flow_show_advanced_options(
hass,
user_input_auth=VALID_ENTRY_DATA_SELF_HOSTED
| {
CONF_OVERRIDE_REST_URL: "bla://localhost:8000",
CONF_OVERRIDE_MQTT_URL: "mqtt://",
},
user_input_user=_USER_STEP_SELF_HOSTED,
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "auth"
assert result["errors"] == {
CONF_OVERRIDE_REST_URL: "invalid_url_schema_override_rest_url",
CONF_OVERRIDE_MQTT_URL: "invalid_url",
}
mock_authenticator_authenticate.assert_not_called()
mock_mqtt_client.verify_config.assert_not_called()
mock_setup_entry.assert_not_called()
# Check that the schema includes select box to disable ssl verification of mqtt
assert CONF_VERIFY_MQTT_CERTIFICATE in result["data_schema"].schema
data = VALID_ENTRY_DATA_SELF_HOSTED | {CONF_VERIFY_MQTT_CERTIFICATE: False}
with patch(
"homeassistant.components.ecovacs.config_flow.create_mqtt_config",
wraps=create_mqtt_config,
) as mock_create_mqtt_config:
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=data,
)
mock_create_mqtt_config.assert_called_once()
ssl_context = mock_create_mqtt_config.call_args[1]["ssl_context"]
assert isinstance(ssl_context, ssl.SSLContext)
assert ssl_context.verify_mode == ssl.CERT_NONE
assert ssl_context.check_hostname is False
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["title"] == data[CONF_USERNAME]
assert result["data"] == data
mock_setup_entry.assert_called()
mock_authenticator_authenticate.assert_called()
mock_mqtt_client.verify_config.assert_called()
async def test_import_flow( async def test_import_flow(
@ -88,6 +309,7 @@ async def test_import_flow(
issue_registry: ir.IssueRegistry, issue_registry: ir.IssueRegistry,
mock_setup_entry: AsyncMock, mock_setup_entry: AsyncMock,
mock_authenticator_authenticate: AsyncMock, mock_authenticator_authenticate: AsyncMock,
mock_mqtt_client: Mock,
) -> None: ) -> None:
"""Test importing yaml config.""" """Test importing yaml config."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -98,17 +320,18 @@ async def test_import_flow(
mock_authenticator_authenticate.assert_called() mock_authenticator_authenticate.assert_called()
assert result["type"] == FlowResultType.CREATE_ENTRY assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["title"] == VALID_ENTRY_DATA[CONF_USERNAME] assert result["title"] == VALID_ENTRY_DATA_CLOUD[CONF_USERNAME]
assert result["data"] == VALID_ENTRY_DATA assert result["data"] == VALID_ENTRY_DATA_CLOUD
assert (HOMEASSISTANT_DOMAIN, f"deprecated_yaml_{DOMAIN}") in issue_registry.issues assert (HOMEASSISTANT_DOMAIN, f"deprecated_yaml_{DOMAIN}") in issue_registry.issues
mock_setup_entry.assert_called() mock_setup_entry.assert_called()
mock_mqtt_client.verify_config.assert_called()
async def test_import_flow_already_configured( async def test_import_flow_already_configured(
hass: HomeAssistant, issue_registry: ir.IssueRegistry hass: HomeAssistant, issue_registry: ir.IssueRegistry
) -> None: ) -> None:
"""Test importing yaml config where entry already configured.""" """Test importing yaml config where entry already configured."""
entry = MockConfigEntry(domain=DOMAIN, data=VALID_ENTRY_DATA) entry = MockConfigEntry(domain=DOMAIN, data=VALID_ENTRY_DATA_CLOUD)
entry.add_to_hass(hass) entry.add_to_hass(hass)
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -121,6 +344,7 @@ async def test_import_flow_already_configured(
assert (HOMEASSISTANT_DOMAIN, f"deprecated_yaml_{DOMAIN}") in issue_registry.issues assert (HOMEASSISTANT_DOMAIN, f"deprecated_yaml_{DOMAIN}") in issue_registry.issues
@pytest.mark.parametrize("show_advanced_options", [True, False])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "reason"), ("side_effect", "reason"),
[ [
@ -131,17 +355,22 @@ async def test_import_flow_already_configured(
) )
async def test_import_flow_error( async def test_import_flow_error(
hass: HomeAssistant, hass: HomeAssistant,
side_effect: Exception,
reason: str,
issue_registry: ir.IssueRegistry, issue_registry: ir.IssueRegistry,
mock_authenticator_authenticate: AsyncMock, mock_authenticator_authenticate: AsyncMock,
mock_mqtt_client: Mock,
side_effect: Exception,
reason: str,
show_advanced_options: bool,
) -> None: ) -> None:
"""Test handling invalid connection.""" """Test handling invalid connection."""
mock_authenticator_authenticate.side_effect = side_effect mock_authenticator_authenticate.side_effect = side_effect
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
context={"source": SOURCE_IMPORT}, context={
"source": SOURCE_IMPORT,
"show_advanced_options": show_advanced_options,
},
data=IMPORT_DATA.copy(), data=IMPORT_DATA.copy(),
) )
assert result["type"] == FlowResultType.ABORT assert result["type"] == FlowResultType.ABORT
@ -151,3 +380,38 @@ async def test_import_flow_error(
f"deprecated_yaml_import_issue_{reason}", f"deprecated_yaml_import_issue_{reason}",
) in issue_registry.issues ) in issue_registry.issues
mock_authenticator_authenticate.assert_called() mock_authenticator_authenticate.assert_called()
@pytest.mark.parametrize("show_advanced_options", [True, False])
@pytest.mark.parametrize(
("reason", "user_input"),
[
("invalid_country_length", IMPORT_DATA | {CONF_COUNTRY: "too_long"}),
("invalid_country_length", IMPORT_DATA | {CONF_COUNTRY: "a"}), # too short
("invalid_continent_length", IMPORT_DATA | {CONF_CONTINENT: "too_long"}),
("invalid_continent_length", IMPORT_DATA | {CONF_CONTINENT: "a"}), # too short
("continent_not_match", IMPORT_DATA | {CONF_CONTINENT: "AA"}),
],
)
async def test_import_flow_invalid_data(
hass: HomeAssistant,
issue_registry: ir.IssueRegistry,
reason: str,
user_input: dict[str, Any],
show_advanced_options: bool,
) -> None:
"""Test handling invalid connection."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={
"source": SOURCE_IMPORT,
"show_advanced_options": show_advanced_options,
},
data=user_input,
)
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == reason
assert (
DOMAIN,
f"deprecated_yaml_import_issue_{reason}",
) in issue_registry.issues

View File

@ -1,15 +1,24 @@
"""Tests for diagnostics data.""" """Tests for diagnostics data."""
import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from syrupy.filters import props from syrupy.filters import props
from homeassistant.const import CONF_USERNAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .const import VALID_ENTRY_DATA_CLOUD, VALID_ENTRY_DATA_SELF_HOSTED
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.components.diagnostics import get_diagnostics_for_config_entry from tests.components.diagnostics import get_diagnostics_for_config_entry
from tests.typing import ClientSessionGenerator from tests.typing import ClientSessionGenerator
@pytest.mark.parametrize(
"mock_config_entry_data",
[VALID_ENTRY_DATA_CLOUD, VALID_ENTRY_DATA_SELF_HOSTED],
ids=lambda data: data[CONF_USERNAME],
)
async def test_diagnostics( async def test_diagnostics(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,

View File

@ -87,6 +87,7 @@ async def test_async_setup_import(
config_entries_expected: int, config_entries_expected: int,
mock_setup_entry: AsyncMock, mock_setup_entry: AsyncMock,
mock_authenticator_authenticate: AsyncMock, mock_authenticator_authenticate: AsyncMock,
mock_mqtt_client: Mock,
) -> None: ) -> None:
"""Test async_setup config import.""" """Test async_setup config import."""
assert len(hass.config_entries.async_entries(DOMAIN)) == 0 assert len(hass.config_entries.async_entries(DOMAIN)) == 0
@ -95,6 +96,7 @@ async def test_async_setup_import(
assert len(hass.config_entries.async_entries(DOMAIN)) == config_entries_expected assert len(hass.config_entries.async_entries(DOMAIN)) == config_entries_expected
assert mock_setup_entry.call_count == config_entries_expected assert mock_setup_entry.call_count == config_entries_expected
assert mock_authenticator_authenticate.call_count == config_entries_expected assert mock_authenticator_authenticate.call_count == config_entries_expected
assert mock_mqtt_client.verify_config.call_count == config_entries_expected
async def test_devices_in_dr( async def test_devices_in_dr(