mirror of
https://github.com/home-assistant/core.git
synced 2025-07-07 13:27:09 +00:00
Rewrite re-auth mechanism in Synology DSM integration (#54298)
This commit is contained in:
parent
0d1412ea17
commit
2f3a11f930
@ -26,7 +26,7 @@ from synology_dsm.exceptions import (
|
||||
SynologyDSMRequestException,
|
||||
)
|
||||
|
||||
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import (
|
||||
ATTR_ATTRIBUTION,
|
||||
CONF_HOST,
|
||||
@ -40,7 +40,7 @@ from homeassistant.const import (
|
||||
CONF_VERIFY_SSL,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
||||
from homeassistant.helpers import device_registry, entity_registry
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.device_registry import (
|
||||
@ -193,27 +193,14 @@ async def async_setup_entry( # noqa: C901
|
||||
details = err.args[0].get(EXCEPTION_DETAILS, EXCEPTION_UNKNOWN)
|
||||
else:
|
||||
details = EXCEPTION_UNKNOWN
|
||||
_LOGGER.debug(
|
||||
"Reauthentication for DSM '%s' needed - reason: %s",
|
||||
entry.unique_id,
|
||||
details,
|
||||
)
|
||||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={
|
||||
"source": SOURCE_REAUTH,
|
||||
"data": {**entry.data},
|
||||
EXCEPTION_DETAILS: details,
|
||||
},
|
||||
)
|
||||
)
|
||||
return False
|
||||
raise ConfigEntryAuthFailed(f"reason: {details}") from err
|
||||
except (SynologyDSMLoginFailedException, SynologyDSMRequestException) as err:
|
||||
_LOGGER.debug(
|
||||
"Unable to connect to DSM '%s' during setup: %s", entry.unique_id, err
|
||||
)
|
||||
raise ConfigEntryNotReady from err
|
||||
if err.args[0] and isinstance(err.args[0], dict):
|
||||
# pylint: disable=no-member
|
||||
details = err.args[0].get(EXCEPTION_DETAILS, EXCEPTION_UNKNOWN)
|
||||
else:
|
||||
details = EXCEPTION_UNKNOWN
|
||||
raise ConfigEntryNotReady(details) from err
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})
|
||||
hass.data[DOMAIN][entry.unique_id] = {
|
||||
|
@ -46,7 +46,6 @@ from .const import (
|
||||
DEFAULT_USE_SSL,
|
||||
DEFAULT_VERIFY_SSL,
|
||||
DOMAIN,
|
||||
EXCEPTION_DETAILS,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -58,11 +57,11 @@ def _discovery_schema_with_defaults(discovery_info: DiscoveryInfoType) -> vol.Sc
|
||||
return vol.Schema(_ordered_shared_schema(discovery_info))
|
||||
|
||||
|
||||
def _reauth_schema_with_defaults(user_input: dict[str, Any]) -> vol.Schema:
|
||||
def _reauth_schema() -> vol.Schema:
|
||||
return vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_USERNAME, default=user_input.get(CONF_USERNAME, "")): str,
|
||||
vol.Required(CONF_PASSWORD, default=user_input.get(CONF_PASSWORD, "")): str,
|
||||
vol.Required(CONF_USERNAME): str,
|
||||
vol.Required(CONF_PASSWORD): str,
|
||||
}
|
||||
)
|
||||
|
||||
@ -113,8 +112,9 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
||||
self.reauth_conf: dict[str, Any] = {}
|
||||
self.reauth_reason: str | None = None
|
||||
|
||||
async def _show_setup_form(
|
||||
def _show_form(
|
||||
self,
|
||||
step_id: str,
|
||||
user_input: dict[str, Any] | None = None,
|
||||
errors: dict[str, str] | None = None,
|
||||
) -> FlowResult:
|
||||
@ -123,19 +123,15 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
||||
user_input = {}
|
||||
|
||||
description_placeholders = {}
|
||||
data_schema = {}
|
||||
|
||||
if self.discovered_conf:
|
||||
if step_id == "link":
|
||||
user_input.update(self.discovered_conf)
|
||||
step_id = "link"
|
||||
data_schema = _discovery_schema_with_defaults(user_input)
|
||||
description_placeholders = self.discovered_conf
|
||||
elif self.reauth_conf:
|
||||
user_input.update(self.reauth_conf)
|
||||
step_id = "reauth"
|
||||
data_schema = _reauth_schema_with_defaults(user_input)
|
||||
description_placeholders = {EXCEPTION_DETAILS: self.reauth_reason}
|
||||
else:
|
||||
step_id = "user"
|
||||
elif step_id == "reauth_confirm":
|
||||
data_schema = _reauth_schema()
|
||||
elif step_id == "user":
|
||||
data_schema = _user_schema_with_defaults(user_input)
|
||||
|
||||
return self.async_show_form(
|
||||
@ -145,27 +141,10 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
||||
description_placeholders=description_placeholders,
|
||||
)
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
async def async_validate_input_create_entry(
|
||||
self, user_input: dict[str, Any], step_id: str
|
||||
) -> FlowResult:
|
||||
"""Handle a flow initiated by the user."""
|
||||
errors = {}
|
||||
|
||||
if user_input is None:
|
||||
return await self._show_setup_form(user_input, None)
|
||||
|
||||
if self.discovered_conf:
|
||||
user_input.update(self.discovered_conf)
|
||||
|
||||
if self.reauth_conf:
|
||||
self.reauth_conf.update(
|
||||
{
|
||||
CONF_USERNAME: user_input[CONF_USERNAME],
|
||||
CONF_PASSWORD: user_input[CONF_PASSWORD],
|
||||
}
|
||||
)
|
||||
user_input.update(self.reauth_conf)
|
||||
|
||||
"""Process user input and create new or update existing config entry."""
|
||||
host = user_input[CONF_HOST]
|
||||
port = user_input.get(CONF_PORT)
|
||||
username = user_input[CONF_USERNAME]
|
||||
@ -184,6 +163,7 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
||||
host, port, username, password, use_ssl, verify_ssl, timeout=30
|
||||
)
|
||||
|
||||
errors = {}
|
||||
try:
|
||||
serial = await self.hass.async_add_executor_job(
|
||||
_login_and_fetch_syno_info, api, otp_code
|
||||
@ -207,7 +187,7 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
||||
errors["base"] = "missing_data"
|
||||
|
||||
if errors:
|
||||
return await self._show_setup_form(user_input, errors)
|
||||
return self._show_form(step_id, user_input, errors)
|
||||
|
||||
# unique_id should be serial for services purpose
|
||||
existing_entry = await self.async_set_unique_id(serial, raise_on_progress=False)
|
||||
@ -239,6 +219,15 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
||||
|
||||
return self.async_create_entry(title=host, data=config_data)
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle a flow initiated by the user."""
|
||||
step = "user"
|
||||
if not user_input:
|
||||
return self._show_form(step)
|
||||
return await self.async_validate_input_create_entry(user_input, step_id=step)
|
||||
|
||||
async def async_step_ssdp(self, discovery_info: DiscoveryInfoType) -> FlowResult:
|
||||
"""Handle a discovered synology_dsm."""
|
||||
parsed_url = urlparse(discovery_info[ssdp.ATTR_SSDP_LOCATION])
|
||||
@ -272,21 +261,32 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
||||
CONF_HOST: parsed_url.hostname,
|
||||
}
|
||||
self.context["title_placeholders"] = self.discovered_conf
|
||||
return await self.async_step_user()
|
||||
return await self.async_step_link()
|
||||
|
||||
async def async_step_reauth(
|
||||
async def async_step_link(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Perform reauth upon an API authentication error."""
|
||||
self.reauth_conf = self.context.get("data", {})
|
||||
self.reauth_reason = self.context.get(EXCEPTION_DETAILS)
|
||||
if user_input is None:
|
||||
return await self.async_step_user()
|
||||
return await self.async_step_user(user_input)
|
||||
|
||||
async def async_step_link(self, user_input: dict[str, Any]) -> FlowResult:
|
||||
"""Link a config entry from discovery."""
|
||||
return await self.async_step_user(user_input)
|
||||
step = "link"
|
||||
if not user_input:
|
||||
return self._show_form(step)
|
||||
user_input = {**self.discovered_conf, **user_input}
|
||||
return await self.async_validate_input_create_entry(user_input, step_id=step)
|
||||
|
||||
async def async_step_reauth(self, data: dict[str, Any]) -> FlowResult:
|
||||
"""Perform reauth upon an API authentication error."""
|
||||
self.reauth_conf = data.copy()
|
||||
return await self.async_step_reauth_confirm()
|
||||
|
||||
async def async_step_reauth_confirm(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Perform reauth confirm upon an API authentication error."""
|
||||
step = "reauth_confirm"
|
||||
if not user_input:
|
||||
return self._show_form(step)
|
||||
user_input = {**self.reauth_conf, **user_input}
|
||||
return await self.async_validate_input_create_entry(user_input, step_id=step)
|
||||
|
||||
async def async_step_2sa(
|
||||
self, user_input: dict[str, Any], errors: dict[str, str] | None = None
|
||||
|
@ -30,9 +30,8 @@
|
||||
"port": "[%key:common::config_flow::data::port%]"
|
||||
}
|
||||
},
|
||||
"reauth": {
|
||||
"reauth_confirm": {
|
||||
"title": "Synology DSM [%key:common::config_flow::title::reauth%]",
|
||||
"description": "Reason: {details}",
|
||||
"data": {
|
||||
"username": "[%key:common::config_flow::data::username%]",
|
||||
"password": "[%key:common::config_flow::data::password%]"
|
||||
|
@ -31,12 +31,11 @@
|
||||
"description": "Do you want to setup {name} ({host})?",
|
||||
"title": "Synology DSM"
|
||||
},
|
||||
"reauth": {
|
||||
"reauth_confirm": {
|
||||
"data": {
|
||||
"password": "Password",
|
||||
"username": "Username"
|
||||
},
|
||||
"description": "Reason: {details}",
|
||||
"title": "Synology DSM Reauthenticate Integration"
|
||||
},
|
||||
"user": {
|
||||
|
@ -257,7 +257,7 @@ async def test_user_vdsm(hass: HomeAssistant, service_vdsm: MagicMock):
|
||||
|
||||
async def test_reauth(hass: HomeAssistant, service: MagicMock):
|
||||
"""Test reauthentication."""
|
||||
MockConfigEntry(
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_HOST: HOST,
|
||||
@ -265,7 +265,8 @@ async def test_reauth(hass: HomeAssistant, service: MagicMock):
|
||||
CONF_PASSWORD: f"{PASSWORD}_invalid",
|
||||
},
|
||||
unique_id=SERIAL,
|
||||
).add_to_hass(hass)
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.config_entries.ConfigEntries.async_reload",
|
||||
@ -276,27 +277,21 @@ async def test_reauth(hass: HomeAssistant, service: MagicMock):
|
||||
DOMAIN,
|
||||
context={
|
||||
"source": SOURCE_REAUTH,
|
||||
"data": {
|
||||
CONF_HOST: HOST,
|
||||
CONF_USERNAME: USERNAME,
|
||||
CONF_PASSWORD: PASSWORD,
|
||||
},
|
||||
"entry_id": entry.entry_id,
|
||||
"unique_id": entry.unique_id,
|
||||
},
|
||||
data={
|
||||
CONF_HOST: HOST,
|
||||
CONF_USERNAME: USERNAME,
|
||||
CONF_PASSWORD: PASSWORD,
|
||||
},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "reauth"
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={
|
||||
"source": SOURCE_REAUTH,
|
||||
"data": {
|
||||
CONF_HOST: HOST,
|
||||
CONF_USERNAME: USERNAME,
|
||||
CONF_PASSWORD: PASSWORD,
|
||||
},
|
||||
},
|
||||
data={
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_USERNAME: USERNAME,
|
||||
CONF_PASSWORD: PASSWORD,
|
||||
},
|
||||
|
Loading…
x
Reference in New Issue
Block a user