Add reauthentication to Roborock (#104215)

* add reauth to roborock

* update reauth based on comments

* fix diagnostics?

* Update homeassistant/components/roborock/config_flow.py

Co-authored-by: Allen Porter <allen.porter@gmail.com>

* remove unneeded import

* fix tests coverage

---------

Co-authored-by: Allen Porter <allen.porter@gmail.com>
This commit is contained in:
Luke Lashley 2023-11-21 18:21:31 -05:00 committed by GitHub
parent aea15ee20c
commit 464270d849
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 143 additions and 22 deletions

View File

@ -5,6 +5,7 @@ import asyncio
from datetime import timedelta from datetime import timedelta
import logging import logging
from roborock import RoborockException, RoborockInvalidCredentials
from roborock.api import RoborockApiClient from roborock.api import RoborockApiClient
from roborock.cloud_api import RoborockMqttClient from roborock.cloud_api import RoborockMqttClient
from roborock.containers import DeviceData, HomeDataDevice, UserData from roborock.containers import DeviceData, HomeDataDevice, UserData
@ -12,7 +13,7 @@ from roborock.containers import DeviceData, HomeDataDevice, UserData
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_USERNAME from homeassistant.const import CONF_USERNAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from .const import CONF_BASE_URL, CONF_USER_DATA, DOMAIN, PLATFORMS from .const import CONF_BASE_URL, CONF_USER_DATA, DOMAIN, PLATFORMS
from .coordinator import RoborockDataUpdateCoordinator from .coordinator import RoborockDataUpdateCoordinator
@ -29,7 +30,12 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
user_data = UserData.from_dict(entry.data[CONF_USER_DATA]) user_data = UserData.from_dict(entry.data[CONF_USER_DATA])
api_client = RoborockApiClient(entry.data[CONF_USERNAME], entry.data[CONF_BASE_URL]) api_client = RoborockApiClient(entry.data[CONF_USERNAME], entry.data[CONF_BASE_URL])
_LOGGER.debug("Getting home data") _LOGGER.debug("Getting home data")
home_data = await api_client.get_home_data(user_data) try:
home_data = await api_client.get_home_data(user_data)
except RoborockInvalidCredentials as err:
raise ConfigEntryAuthFailed("Invalid credentials.") from err
except RoborockException as err:
raise ConfigEntryNotReady("Failed getting Roborock home_data.") from err
_LOGGER.debug("Got home data %s", home_data) _LOGGER.debug("Got home data %s", home_data)
device_map: dict[str, HomeDataDevice] = { device_map: dict[str, HomeDataDevice] = {
device.duid: device for device in home_data.devices + home_data.received_devices device.duid: device for device in home_data.devices + home_data.received_devices

View File

@ -1,6 +1,7 @@
"""Config flow for Roborock.""" """Config flow for Roborock."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping
import logging import logging
from typing import Any from typing import Any
@ -16,6 +17,7 @@ from roborock.exceptions import (
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_USERNAME from homeassistant.const import CONF_USERNAME
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
@ -28,6 +30,7 @@ class RoborockFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Roborock.""" """Handle a config flow for Roborock."""
VERSION = 1 VERSION = 1
reauth_entry: ConfigEntry | None = None
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the config flow.""" """Initialize the config flow."""
@ -47,21 +50,8 @@ class RoborockFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
self._username = username self._username = username
_LOGGER.debug("Requesting code for Roborock account") _LOGGER.debug("Requesting code for Roborock account")
self._client = RoborockApiClient(username) self._client = RoborockApiClient(username)
try: errors = await self._request_code()
await self._client.request_code() if not errors:
except RoborockAccountDoesNotExist:
errors["base"] = "invalid_email"
except RoborockUrlException:
errors["base"] = "unknown_url"
except RoborockInvalidEmail:
errors["base"] = "invalid_email_format"
except RoborockException as ex:
_LOGGER.exception(ex)
errors["base"] = "unknown_roborock"
except Exception as ex: # pylint: disable=broad-except
_LOGGER.exception(ex)
errors["base"] = "unknown"
else:
return await self.async_step_code() return await self.async_step_code()
return self.async_show_form( return self.async_show_form(
step_id="user", step_id="user",
@ -69,6 +59,25 @@ class RoborockFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
errors=errors, errors=errors,
) )
async def _request_code(self) -> dict:
assert self._client
errors: dict[str, str] = {}
try:
await self._client.request_code()
except RoborockAccountDoesNotExist:
errors["base"] = "invalid_email"
except RoborockUrlException:
errors["base"] = "unknown_url"
except RoborockInvalidEmail:
errors["base"] = "invalid_email_format"
except RoborockException as ex:
_LOGGER.exception(ex)
errors["base"] = "unknown_roborock"
except Exception as ex: # pylint: disable=broad-except
_LOGGER.exception(ex)
errors["base"] = "unknown"
return errors
async def async_step_code( async def async_step_code(
self, self,
user_input: dict[str, Any] | None = None, user_input: dict[str, Any] | None = None,
@ -91,6 +100,18 @@ class RoborockFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
_LOGGER.exception(ex) _LOGGER.exception(ex)
errors["base"] = "unknown" errors["base"] = "unknown"
else: else:
if self.reauth_entry is not None:
self.hass.config_entries.async_update_entry(
self.reauth_entry,
data={
**self.reauth_entry.data,
CONF_USER_DATA: login_data.as_dict(),
},
)
await self.hass.config_entries.async_reload(
self.reauth_entry.entry_id
)
return self.async_abort(reason="reauth_successful")
return self._create_entry(self._client, self._username, login_data) return self._create_entry(self._client, self._username, login_data)
return self.async_show_form( return self.async_show_form(
@ -99,6 +120,27 @@ class RoborockFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
errors=errors, errors=errors,
) )
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Perform reauth upon an API authentication error."""
self._username = entry_data[CONF_USERNAME]
assert self._username
self._client = RoborockApiClient(self._username)
self.reauth_entry = self.hass.config_entries.async_get_entry(
self.context["entry_id"]
)
return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Confirm reauth dialog."""
errors: dict[str, str] = {}
if user_input is not None:
errors = await self._request_code()
if not errors:
return await self.async_step_code()
return self.async_show_form(step_id="reauth_confirm", errors=errors)
def _create_entry( def _create_entry(
self, client: RoborockApiClient, username: str, user_data: UserData self, client: RoborockApiClient, username: str, user_data: UserData
) -> FlowResult: ) -> FlowResult:

View File

@ -12,6 +12,10 @@
"data": { "data": {
"code": "Verification code" "code": "Verification code"
} }
},
"reauth_confirm": {
"title": "[%key:common::config_flow::title::reauth%]",
"description": "The Roborock integration needs to re-authenticate your account"
} }
}, },
"error": { "error": {
@ -23,7 +27,8 @@
"unknown": "[%key:common::config_flow::error::unknown%]" "unknown": "[%key:common::config_flow::error::unknown%]"
}, },
"abort": { "abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]" "already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
} }
}, },
"entity": { "entity": {

View File

@ -1,4 +1,5 @@
"""Test Roborock config flow.""" """Test Roborock config flow."""
from copy import deepcopy
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -12,9 +13,11 @@ from roborock.exceptions import (
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.roborock.const import CONF_ENTRY_CODE, DOMAIN from homeassistant.components.roborock.const import CONF_ENTRY_CODE, DOMAIN
from homeassistant.const import CONF_USERNAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from ...common import MockConfigEntry
from .mock_data import MOCK_CONFIG, USER_DATA, USER_EMAIL from .mock_data import MOCK_CONFIG, USER_DATA, USER_EMAIL
@ -35,7 +38,7 @@ async def test_config_flow_success(
"homeassistant.components.roborock.config_flow.RoborockApiClient.request_code" "homeassistant.components.roborock.config_flow.RoborockApiClient.request_code"
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"username": USER_EMAIL} result["flow_id"], {CONF_USERNAME: USER_EMAIL}
) )
assert result["type"] == FlowResultType.FORM assert result["type"] == FlowResultType.FORM
@ -89,7 +92,7 @@ async def test_config_flow_failures_request_code(
side_effect=request_code_side_effect, side_effect=request_code_side_effect,
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"username": USER_EMAIL} result["flow_id"], {CONF_USERNAME: USER_EMAIL}
) )
assert result["type"] == FlowResultType.FORM assert result["type"] == FlowResultType.FORM
assert result["errors"] == request_code_errors assert result["errors"] == request_code_errors
@ -98,7 +101,7 @@ async def test_config_flow_failures_request_code(
"homeassistant.components.roborock.config_flow.RoborockApiClient.request_code" "homeassistant.components.roborock.config_flow.RoborockApiClient.request_code"
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"username": USER_EMAIL} result["flow_id"], {CONF_USERNAME: USER_EMAIL}
) )
assert result["type"] == FlowResultType.FORM assert result["type"] == FlowResultType.FORM
@ -149,7 +152,7 @@ async def test_config_flow_failures_code_login(
"homeassistant.components.roborock.config_flow.RoborockApiClient.request_code" "homeassistant.components.roborock.config_flow.RoborockApiClient.request_code"
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"username": USER_EMAIL} result["flow_id"], {CONF_USERNAME: USER_EMAIL}
) )
assert result["type"] == FlowResultType.FORM assert result["type"] == FlowResultType.FORM
@ -178,3 +181,39 @@ async def test_config_flow_failures_code_login(
assert result["data"] == MOCK_CONFIG assert result["data"] == MOCK_CONFIG
assert result["result"] assert result["result"]
assert len(mock_setup.mock_calls) == 1 assert len(mock_setup.mock_calls) == 1
async def test_reauth_flow(
hass: HomeAssistant, bypass_api_fixture, mock_roborock_entry: MockConfigEntry
) -> None:
"""Test reauth flow."""
# Start reauth
result = mock_roborock_entry.async_start_reauth(hass)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
[result] = flows
assert result["step_id"] == "reauth_confirm"
# Request a new code
with patch(
"homeassistant.components.roborock.config_flow.RoborockApiClient.request_code"
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={}
)
# Enter a new code
assert result["step_id"] == "code"
assert result["type"] == FlowResultType.FORM
new_user_data = deepcopy(USER_DATA)
new_user_data.rriot.s = "new_password_hash"
with patch(
"homeassistant.components.roborock.config_flow.RoborockApiClient.code_login",
return_value=new_user_data,
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={CONF_ENTRY_CODE: "123456"}
)
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "reauth_successful"
assert mock_roborock_entry.data["user_data"]["rriot"]["s"] == "new_password_hash"

View File

@ -1,6 +1,8 @@
"""Test for Roborock init.""" """Test for Roborock init."""
from unittest.mock import patch from unittest.mock import patch
from roborock import RoborockException, RoborockInvalidCredentials
from homeassistant.components.roborock.const import DOMAIN from homeassistant.components.roborock.const import DOMAIN
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -38,3 +40,30 @@ async def test_config_entry_not_ready(
): ):
await async_setup_component(hass, DOMAIN, {}) await async_setup_component(hass, DOMAIN, {})
assert mock_roborock_entry.state is ConfigEntryState.SETUP_RETRY assert mock_roborock_entry.state is ConfigEntryState.SETUP_RETRY
async def test_reauth_started(
hass: HomeAssistant, bypass_api_fixture, mock_roborock_entry: MockConfigEntry
) -> None:
"""Test reauth flow started."""
with patch(
"homeassistant.components.roborock.RoborockApiClient.get_home_data",
side_effect=RoborockInvalidCredentials(),
):
await async_setup_component(hass, DOMAIN, {})
assert mock_roborock_entry.state is ConfigEntryState.SETUP_ERROR
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
assert flows[0]["step_id"] == "reauth_confirm"
async def test_config_entry_not_ready_home_data(
hass: HomeAssistant, mock_roborock_entry: MockConfigEntry
) -> None:
"""Test that when we fail to get home data, entry retries."""
with patch(
"homeassistant.components.roborock.RoborockApiClient.get_home_data",
side_effect=RoborockException(),
):
await async_setup_component(hass, DOMAIN, {})
assert mock_roborock_entry.state is ConfigEntryState.SETUP_RETRY