mirror of
https://github.com/home-assistant/core.git
synced 2025-07-08 05:47:10 +00:00
Rewrite re-auth mechanism in Synology DSM integration (#54298)
This commit is contained in:
parent
0d1412ea17
commit
2f3a11f930
@ -26,7 +26,7 @@ from synology_dsm.exceptions import (
|
|||||||
SynologyDSMRequestException,
|
SynologyDSMRequestException,
|
||||||
)
|
)
|
||||||
|
|
||||||
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_ATTRIBUTION,
|
ATTR_ATTRIBUTION,
|
||||||
CONF_HOST,
|
CONF_HOST,
|
||||||
@ -40,7 +40,7 @@ from homeassistant.const import (
|
|||||||
CONF_VERIFY_SSL,
|
CONF_VERIFY_SSL,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
||||||
from homeassistant.exceptions import ConfigEntryNotReady
|
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
||||||
from homeassistant.helpers import device_registry, entity_registry
|
from homeassistant.helpers import device_registry, entity_registry
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.helpers.device_registry import (
|
from homeassistant.helpers.device_registry import (
|
||||||
@ -193,27 +193,14 @@ async def async_setup_entry( # noqa: C901
|
|||||||
details = err.args[0].get(EXCEPTION_DETAILS, EXCEPTION_UNKNOWN)
|
details = err.args[0].get(EXCEPTION_DETAILS, EXCEPTION_UNKNOWN)
|
||||||
else:
|
else:
|
||||||
details = EXCEPTION_UNKNOWN
|
details = EXCEPTION_UNKNOWN
|
||||||
_LOGGER.debug(
|
raise ConfigEntryAuthFailed(f"reason: {details}") from err
|
||||||
"Reauthentication for DSM '%s' needed - reason: %s",
|
|
||||||
entry.unique_id,
|
|
||||||
details,
|
|
||||||
)
|
|
||||||
hass.async_create_task(
|
|
||||||
hass.config_entries.flow.async_init(
|
|
||||||
DOMAIN,
|
|
||||||
context={
|
|
||||||
"source": SOURCE_REAUTH,
|
|
||||||
"data": {**entry.data},
|
|
||||||
EXCEPTION_DETAILS: details,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
except (SynologyDSMLoginFailedException, SynologyDSMRequestException) as err:
|
except (SynologyDSMLoginFailedException, SynologyDSMRequestException) as err:
|
||||||
_LOGGER.debug(
|
if err.args[0] and isinstance(err.args[0], dict):
|
||||||
"Unable to connect to DSM '%s' during setup: %s", entry.unique_id, err
|
# pylint: disable=no-member
|
||||||
)
|
details = err.args[0].get(EXCEPTION_DETAILS, EXCEPTION_UNKNOWN)
|
||||||
raise ConfigEntryNotReady from err
|
else:
|
||||||
|
details = EXCEPTION_UNKNOWN
|
||||||
|
raise ConfigEntryNotReady(details) from err
|
||||||
|
|
||||||
hass.data.setdefault(DOMAIN, {})
|
hass.data.setdefault(DOMAIN, {})
|
||||||
hass.data[DOMAIN][entry.unique_id] = {
|
hass.data[DOMAIN][entry.unique_id] = {
|
||||||
|
@ -46,7 +46,6 @@ from .const import (
|
|||||||
DEFAULT_USE_SSL,
|
DEFAULT_USE_SSL,
|
||||||
DEFAULT_VERIFY_SSL,
|
DEFAULT_VERIFY_SSL,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
EXCEPTION_DETAILS,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@ -58,11 +57,11 @@ def _discovery_schema_with_defaults(discovery_info: DiscoveryInfoType) -> vol.Sc
|
|||||||
return vol.Schema(_ordered_shared_schema(discovery_info))
|
return vol.Schema(_ordered_shared_schema(discovery_info))
|
||||||
|
|
||||||
|
|
||||||
def _reauth_schema_with_defaults(user_input: dict[str, Any]) -> vol.Schema:
|
def _reauth_schema() -> vol.Schema:
|
||||||
return vol.Schema(
|
return vol.Schema(
|
||||||
{
|
{
|
||||||
vol.Required(CONF_USERNAME, default=user_input.get(CONF_USERNAME, "")): str,
|
vol.Required(CONF_USERNAME): str,
|
||||||
vol.Required(CONF_PASSWORD, default=user_input.get(CONF_PASSWORD, "")): str,
|
vol.Required(CONF_PASSWORD): str,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -113,8 +112,9 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
|||||||
self.reauth_conf: dict[str, Any] = {}
|
self.reauth_conf: dict[str, Any] = {}
|
||||||
self.reauth_reason: str | None = None
|
self.reauth_reason: str | None = None
|
||||||
|
|
||||||
async def _show_setup_form(
|
def _show_form(
|
||||||
self,
|
self,
|
||||||
|
step_id: str,
|
||||||
user_input: dict[str, Any] | None = None,
|
user_input: dict[str, Any] | None = None,
|
||||||
errors: dict[str, str] | None = None,
|
errors: dict[str, str] | None = None,
|
||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
@ -123,19 +123,15 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
|||||||
user_input = {}
|
user_input = {}
|
||||||
|
|
||||||
description_placeholders = {}
|
description_placeholders = {}
|
||||||
|
data_schema = {}
|
||||||
|
|
||||||
if self.discovered_conf:
|
if step_id == "link":
|
||||||
user_input.update(self.discovered_conf)
|
user_input.update(self.discovered_conf)
|
||||||
step_id = "link"
|
|
||||||
data_schema = _discovery_schema_with_defaults(user_input)
|
data_schema = _discovery_schema_with_defaults(user_input)
|
||||||
description_placeholders = self.discovered_conf
|
description_placeholders = self.discovered_conf
|
||||||
elif self.reauth_conf:
|
elif step_id == "reauth_confirm":
|
||||||
user_input.update(self.reauth_conf)
|
data_schema = _reauth_schema()
|
||||||
step_id = "reauth"
|
elif step_id == "user":
|
||||||
data_schema = _reauth_schema_with_defaults(user_input)
|
|
||||||
description_placeholders = {EXCEPTION_DETAILS: self.reauth_reason}
|
|
||||||
else:
|
|
||||||
step_id = "user"
|
|
||||||
data_schema = _user_schema_with_defaults(user_input)
|
data_schema = _user_schema_with_defaults(user_input)
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
@ -145,27 +141,10 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
|||||||
description_placeholders=description_placeholders,
|
description_placeholders=description_placeholders,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_step_user(
|
async def async_validate_input_create_entry(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any], step_id: str
|
||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Handle a flow initiated by the user."""
|
"""Process user input and create new or update existing config entry."""
|
||||||
errors = {}
|
|
||||||
|
|
||||||
if user_input is None:
|
|
||||||
return await self._show_setup_form(user_input, None)
|
|
||||||
|
|
||||||
if self.discovered_conf:
|
|
||||||
user_input.update(self.discovered_conf)
|
|
||||||
|
|
||||||
if self.reauth_conf:
|
|
||||||
self.reauth_conf.update(
|
|
||||||
{
|
|
||||||
CONF_USERNAME: user_input[CONF_USERNAME],
|
|
||||||
CONF_PASSWORD: user_input[CONF_PASSWORD],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
user_input.update(self.reauth_conf)
|
|
||||||
|
|
||||||
host = user_input[CONF_HOST]
|
host = user_input[CONF_HOST]
|
||||||
port = user_input.get(CONF_PORT)
|
port = user_input.get(CONF_PORT)
|
||||||
username = user_input[CONF_USERNAME]
|
username = user_input[CONF_USERNAME]
|
||||||
@ -184,6 +163,7 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
|||||||
host, port, username, password, use_ssl, verify_ssl, timeout=30
|
host, port, username, password, use_ssl, verify_ssl, timeout=30
|
||||||
)
|
)
|
||||||
|
|
||||||
|
errors = {}
|
||||||
try:
|
try:
|
||||||
serial = await self.hass.async_add_executor_job(
|
serial = await self.hass.async_add_executor_job(
|
||||||
_login_and_fetch_syno_info, api, otp_code
|
_login_and_fetch_syno_info, api, otp_code
|
||||||
@ -207,7 +187,7 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
|||||||
errors["base"] = "missing_data"
|
errors["base"] = "missing_data"
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
return await self._show_setup_form(user_input, errors)
|
return self._show_form(step_id, user_input, errors)
|
||||||
|
|
||||||
# unique_id should be serial for services purpose
|
# unique_id should be serial for services purpose
|
||||||
existing_entry = await self.async_set_unique_id(serial, raise_on_progress=False)
|
existing_entry = await self.async_set_unique_id(serial, raise_on_progress=False)
|
||||||
@ -239,6 +219,15 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
|||||||
|
|
||||||
return self.async_create_entry(title=host, data=config_data)
|
return self.async_create_entry(title=host, data=config_data)
|
||||||
|
|
||||||
|
async def async_step_user(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Handle a flow initiated by the user."""
|
||||||
|
step = "user"
|
||||||
|
if not user_input:
|
||||||
|
return self._show_form(step)
|
||||||
|
return await self.async_validate_input_create_entry(user_input, step_id=step)
|
||||||
|
|
||||||
async def async_step_ssdp(self, discovery_info: DiscoveryInfoType) -> FlowResult:
|
async def async_step_ssdp(self, discovery_info: DiscoveryInfoType) -> FlowResult:
|
||||||
"""Handle a discovered synology_dsm."""
|
"""Handle a discovered synology_dsm."""
|
||||||
parsed_url = urlparse(discovery_info[ssdp.ATTR_SSDP_LOCATION])
|
parsed_url = urlparse(discovery_info[ssdp.ATTR_SSDP_LOCATION])
|
||||||
@ -272,21 +261,32 @@ class SynologyDSMFlowHandler(ConfigFlow, domain=DOMAIN):
|
|||||||
CONF_HOST: parsed_url.hostname,
|
CONF_HOST: parsed_url.hostname,
|
||||||
}
|
}
|
||||||
self.context["title_placeholders"] = self.discovered_conf
|
self.context["title_placeholders"] = self.discovered_conf
|
||||||
return await self.async_step_user()
|
return await self.async_step_link()
|
||||||
|
|
||||||
async def async_step_reauth(
|
async def async_step_link(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Perform reauth upon an API authentication error."""
|
|
||||||
self.reauth_conf = self.context.get("data", {})
|
|
||||||
self.reauth_reason = self.context.get(EXCEPTION_DETAILS)
|
|
||||||
if user_input is None:
|
|
||||||
return await self.async_step_user()
|
|
||||||
return await self.async_step_user(user_input)
|
|
||||||
|
|
||||||
async def async_step_link(self, user_input: dict[str, Any]) -> FlowResult:
|
|
||||||
"""Link a config entry from discovery."""
|
"""Link a config entry from discovery."""
|
||||||
return await self.async_step_user(user_input)
|
step = "link"
|
||||||
|
if not user_input:
|
||||||
|
return self._show_form(step)
|
||||||
|
user_input = {**self.discovered_conf, **user_input}
|
||||||
|
return await self.async_validate_input_create_entry(user_input, step_id=step)
|
||||||
|
|
||||||
|
async def async_step_reauth(self, data: dict[str, Any]) -> FlowResult:
|
||||||
|
"""Perform reauth upon an API authentication error."""
|
||||||
|
self.reauth_conf = data.copy()
|
||||||
|
return await self.async_step_reauth_confirm()
|
||||||
|
|
||||||
|
async def async_step_reauth_confirm(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Perform reauth confirm upon an API authentication error."""
|
||||||
|
step = "reauth_confirm"
|
||||||
|
if not user_input:
|
||||||
|
return self._show_form(step)
|
||||||
|
user_input = {**self.reauth_conf, **user_input}
|
||||||
|
return await self.async_validate_input_create_entry(user_input, step_id=step)
|
||||||
|
|
||||||
async def async_step_2sa(
|
async def async_step_2sa(
|
||||||
self, user_input: dict[str, Any], errors: dict[str, str] | None = None
|
self, user_input: dict[str, Any], errors: dict[str, str] | None = None
|
||||||
|
@ -30,9 +30,8 @@
|
|||||||
"port": "[%key:common::config_flow::data::port%]"
|
"port": "[%key:common::config_flow::data::port%]"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"reauth": {
|
"reauth_confirm": {
|
||||||
"title": "Synology DSM [%key:common::config_flow::title::reauth%]",
|
"title": "Synology DSM [%key:common::config_flow::title::reauth%]",
|
||||||
"description": "Reason: {details}",
|
|
||||||
"data": {
|
"data": {
|
||||||
"username": "[%key:common::config_flow::data::username%]",
|
"username": "[%key:common::config_flow::data::username%]",
|
||||||
"password": "[%key:common::config_flow::data::password%]"
|
"password": "[%key:common::config_flow::data::password%]"
|
||||||
|
@ -31,12 +31,11 @@
|
|||||||
"description": "Do you want to setup {name} ({host})?",
|
"description": "Do you want to setup {name} ({host})?",
|
||||||
"title": "Synology DSM"
|
"title": "Synology DSM"
|
||||||
},
|
},
|
||||||
"reauth": {
|
"reauth_confirm": {
|
||||||
"data": {
|
"data": {
|
||||||
"password": "Password",
|
"password": "Password",
|
||||||
"username": "Username"
|
"username": "Username"
|
||||||
},
|
},
|
||||||
"description": "Reason: {details}",
|
|
||||||
"title": "Synology DSM Reauthenticate Integration"
|
"title": "Synology DSM Reauthenticate Integration"
|
||||||
},
|
},
|
||||||
"user": {
|
"user": {
|
||||||
|
@ -257,7 +257,7 @@ async def test_user_vdsm(hass: HomeAssistant, service_vdsm: MagicMock):
|
|||||||
|
|
||||||
async def test_reauth(hass: HomeAssistant, service: MagicMock):
|
async def test_reauth(hass: HomeAssistant, service: MagicMock):
|
||||||
"""Test reauthentication."""
|
"""Test reauthentication."""
|
||||||
MockConfigEntry(
|
entry = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
data={
|
data={
|
||||||
CONF_HOST: HOST,
|
CONF_HOST: HOST,
|
||||||
@ -265,7 +265,8 @@ async def test_reauth(hass: HomeAssistant, service: MagicMock):
|
|||||||
CONF_PASSWORD: f"{PASSWORD}_invalid",
|
CONF_PASSWORD: f"{PASSWORD}_invalid",
|
||||||
},
|
},
|
||||||
unique_id=SERIAL,
|
unique_id=SERIAL,
|
||||||
).add_to_hass(hass)
|
)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.config_entries.ConfigEntries.async_reload",
|
"homeassistant.config_entries.ConfigEntries.async_reload",
|
||||||
@ -276,27 +277,21 @@ async def test_reauth(hass: HomeAssistant, service: MagicMock):
|
|||||||
DOMAIN,
|
DOMAIN,
|
||||||
context={
|
context={
|
||||||
"source": SOURCE_REAUTH,
|
"source": SOURCE_REAUTH,
|
||||||
"data": {
|
"entry_id": entry.entry_id,
|
||||||
|
"unique_id": entry.unique_id,
|
||||||
|
},
|
||||||
|
data={
|
||||||
CONF_HOST: HOST,
|
CONF_HOST: HOST,
|
||||||
CONF_USERNAME: USERNAME,
|
CONF_USERNAME: USERNAME,
|
||||||
CONF_PASSWORD: PASSWORD,
|
CONF_PASSWORD: PASSWORD,
|
||||||
},
|
},
|
||||||
},
|
|
||||||
)
|
)
|
||||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
assert result["step_id"] == "reauth"
|
assert result["step_id"] == "reauth_confirm"
|
||||||
|
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_configure(
|
||||||
DOMAIN,
|
result["flow_id"],
|
||||||
context={
|
{
|
||||||
"source": SOURCE_REAUTH,
|
|
||||||
"data": {
|
|
||||||
CONF_HOST: HOST,
|
|
||||||
CONF_USERNAME: USERNAME,
|
|
||||||
CONF_PASSWORD: PASSWORD,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
data={
|
|
||||||
CONF_USERNAME: USERNAME,
|
CONF_USERNAME: USERNAME,
|
||||||
CONF_PASSWORD: PASSWORD,
|
CONF_PASSWORD: PASSWORD,
|
||||||
},
|
},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user