Mask current password in MQTT option flow (#116098)

* Mask current password in MQTT option flow

* Update docstr

* Typo
This commit is contained in:
Jan Bouwhuis 2024-04-24 13:29:42 +02:00 committed by GitHub
parent 5aa61cb6d5
commit 18132916fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 16 deletions

View File

@ -167,6 +167,29 @@ REAUTH_SCHEMA = vol.Schema(
PWD_NOT_CHANGED = "__**password_not_changed**__" PWD_NOT_CHANGED = "__**password_not_changed**__"
@callback
def update_password_from_user_input(
entry_password: str | None, user_input: dict[str, Any]
) -> dict[str, Any]:
"""Update the password if the entry has been updated.
As we want to avoid reflecting the stored password in the UI,
we replace the suggested value in the UI with a sentitel,
and we change it back here if it was changed.
"""
substituted_used_data = dict(user_input)
# Take out the password submitted
user_password: str | None = substituted_used_data.pop(CONF_PASSWORD, None)
# Only add the password if it has changed.
# If the sentinel password is submitted, we replace that with our current
# password from the config entry data.
password_changed = user_password is not None and user_password != PWD_NOT_CHANGED
password = user_password if password_changed else entry_password
if password is not None:
substituted_used_data[CONF_PASSWORD] = password
return substituted_used_data
class FlowHandler(ConfigFlow, domain=DOMAIN): class FlowHandler(ConfigFlow, domain=DOMAIN):
"""Handle a config flow.""" """Handle a config flow."""
@ -209,16 +232,10 @@ class FlowHandler(ConfigFlow, domain=DOMAIN):
assert self.entry is not None assert self.entry is not None
if user_input: if user_input:
password_changed = ( substituted_used_data = update_password_from_user_input(
user_password := user_input[CONF_PASSWORD] self.entry.data.get(CONF_PASSWORD), user_input
) != PWD_NOT_CHANGED )
entry_password = self.entry.data.get(CONF_PASSWORD) new_entry_data = {**self.entry.data, **substituted_used_data}
password = user_password if password_changed else entry_password
new_entry_data = {
**self.entry.data,
CONF_USERNAME: user_input.get(CONF_USERNAME),
CONF_PASSWORD: password,
}
if await self.hass.async_add_executor_job( if await self.hass.async_add_executor_job(
try_connection, try_connection,
new_entry_data, new_entry_data,
@ -350,13 +367,17 @@ class MQTTOptionsFlowHandler(OptionsFlow):
validated_user_input, validated_user_input,
errors, errors,
): ):
self.broker_config.update(
update_password_from_user_input(
self.config_entry.data.get(CONF_PASSWORD), validated_user_input
),
)
can_connect = await self.hass.async_add_executor_job( can_connect = await self.hass.async_add_executor_job(
try_connection, try_connection,
validated_user_input, self.broker_config,
) )
if can_connect: if can_connect:
self.broker_config.update(validated_user_input)
return await self.async_step_options() return await self.async_step_options()
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
@ -657,7 +678,9 @@ async def async_get_broker_settings(
current_broker = current_config.get(CONF_BROKER) current_broker = current_config.get(CONF_BROKER)
current_port = current_config.get(CONF_PORT, DEFAULT_PORT) current_port = current_config.get(CONF_PORT, DEFAULT_PORT)
current_user = current_config.get(CONF_USERNAME) current_user = current_config.get(CONF_USERNAME)
current_pass = current_config.get(CONF_PASSWORD) # Return the sentinel password to avoid exposure
current_entry_pass = current_config.get(CONF_PASSWORD)
current_pass = PWD_NOT_CHANGED if current_entry_pass else None
# Treat the previous post as an update of the current settings # Treat the previous post as an update of the current settings
# (if there was a basic broker setup step) # (if there was a basic broker setup step)

View File

@ -902,7 +902,7 @@ async def test_option_flow_default_suggested_values(
} }
suggested = { suggested = {
mqtt.CONF_USERNAME: "user", mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass", mqtt.CONF_PASSWORD: PWD_NOT_CHANGED,
} }
for key, value in defaults.items(): for key, value in defaults.items():
assert get_default(result["data_schema"].schema, key) == value assert get_default(result["data_schema"].schema, key) == value
@ -964,7 +964,7 @@ async def test_option_flow_default_suggested_values(
} }
suggested = { suggested = {
mqtt.CONF_USERNAME: "us3r", mqtt.CONF_USERNAME: "us3r",
mqtt.CONF_PASSWORD: "p4ss", mqtt.CONF_PASSWORD: PWD_NOT_CHANGED,
} }
for key, value in defaults.items(): for key, value in defaults.items():
assert get_default(result["data_schema"].schema, key) == value assert get_default(result["data_schema"].schema, key) == value
@ -1329,7 +1329,7 @@ async def test_try_connection_with_advanced_parameters(
} }
suggested = { suggested = {
mqtt.CONF_USERNAME: "user", mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass", mqtt.CONF_PASSWORD: PWD_NOT_CHANGED,
mqtt.CONF_TLS_INSECURE: True, mqtt.CONF_TLS_INSECURE: True,
mqtt.CONF_PROTOCOL: "3.1.1", mqtt.CONF_PROTOCOL: "3.1.1",
mqtt.CONF_TRANSPORT: "websockets", mqtt.CONF_TRANSPORT: "websockets",