Add re-auth flow for MQTT broker username and password (#116011)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Jan Bouwhuis 2024-04-23 22:26:01 +02:00 committed by GitHub
parent 0c583bb1d9
commit 31d11b2362
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 197 additions and 1 deletions

View File

@ -895,10 +895,18 @@ class MQTT:
import paho.mqtt.client as mqtt
if result_code != mqtt.CONNACK_ACCEPTED:
if result_code in (
mqtt.CONNACK_REFUSED_BAD_USERNAME_PASSWORD,
mqtt.CONNACK_REFUSED_NOT_AUTHORIZED,
):
self._should_reconnect = False
self.hass.async_create_task(self.async_disconnect())
self.config_entry.async_start_reauth(self.hass)
_LOGGER.error(
"Unable to connect to the MQTT broker: %s",
mqtt.connack_string(result_code),
)
self._async_connection_result(False)
return
self.connected = True

View File

@ -3,7 +3,7 @@
from __future__ import annotations
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Callable, Mapping
import queue
from ssl import PROTOCOL_TLS_CLIENT, SSLContext, SSLError
from types import MappingProxyType
@ -158,13 +158,23 @@ CERT_UPLOAD_SELECTOR = FileSelector(
)
KEY_UPLOAD_SELECTOR = FileSelector(FileSelectorConfig(accept=".key,application/pkcs8"))
REAUTH_SCHEMA = vol.Schema(
{
vol.Required(CONF_USERNAME): TEXT_SELECTOR,
vol.Required(CONF_PASSWORD): PASSWORD_SELECTOR,
}
)
PWD_NOT_CHANGED = "__**password_not_changed**__"
class FlowHandler(ConfigFlow, domain=DOMAIN):
"""Handle a config flow."""
VERSION = 1
entry: ConfigEntry | None
_hassio_discovery: dict[str, Any] | None = None
_reauth_config_entry: ConfigEntry | None = None
@staticmethod
@callback
@ -183,6 +193,55 @@ class FlowHandler(ConfigFlow, domain=DOMAIN):
return await self.async_step_broker()
async def async_step_reauth(
self, entry_data: Mapping[str, Any]
) -> ConfigFlowResult:
"""Handle re-authentication with Aladdin Connect."""
self.entry = self.hass.config_entries.async_get_entry(self.context["entry_id"])
return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Confirm re-authentication with MQTT broker."""
errors: dict[str, str] = {}
assert self.entry is not None
if user_input:
password_changed = (
user_password := user_input[CONF_PASSWORD]
) != PWD_NOT_CHANGED
entry_password = self.entry.data.get(CONF_PASSWORD)
password = user_password if password_changed else entry_password
new_entry_data = {
**self.entry.data,
CONF_USERNAME: user_input.get(CONF_USERNAME),
CONF_PASSWORD: password,
}
if await self.hass.async_add_executor_job(
try_connection,
new_entry_data,
):
return self.async_update_reload_and_abort(
self.entry, data=new_entry_data
)
errors["base"] = "invalid_auth"
schema = self.add_suggested_values_to_schema(
REAUTH_SCHEMA,
{
CONF_USERNAME: self.entry.data.get(CONF_USERNAME),
CONF_PASSWORD: PWD_NOT_CHANGED,
},
)
return self.async_show_form(
step_id="reauth_confirm",
data_schema=schema,
errors=errors,
)
async def async_step_broker(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:

View File

@ -68,10 +68,23 @@
"data_description": {
"discovery": "Option to enable MQTT automatic discovery."
}
},
"reauth_confirm": {
"title": "Re-authentication required with the MQTT broker",
"description": "The MQTT broker reported an authentication error. Please confirm the brokers correct usernname and password.",
"data": {
"username": "[%key:common::config_flow::data::username%]",
"password": "[%key:common::config_flow::data::password%]"
},
"data_description": {
"username": "[%key:component::mqtt::config::step::broker::data_description::username%]",
"password": "[%key:component::mqtt::config::step::broker::data_description::password%]"
}
}
},
"abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]",
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]"
},
"error": {
@ -84,6 +97,7 @@
"bad_client_cert_key": "Client certificate and private key are not a valid pair",
"bad_ws_headers": "Supply valid HTTP headers as a JSON object",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"invalid_inclusion": "The client certificate and private key must be configurered together"
}
},

View File

@ -14,6 +14,7 @@ import voluptuous as vol
from homeassistant import config_entries
from homeassistant.components import mqtt
from homeassistant.components.hassio import HassioServiceInfo
from homeassistant.components.mqtt.config_flow import PWD_NOT_CHANGED
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
@ -1060,6 +1061,102 @@ async def test_skipping_advanced_options(
assert result["step_id"] == step_id
@pytest.mark.parametrize(
("test_input", "user_input", "new_password"),
[
(
{
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_USERNAME: "username",
mqtt.CONF_PASSWORD: "verysecret",
},
{
mqtt.CONF_USERNAME: "username",
mqtt.CONF_PASSWORD: "newpassword",
},
"newpassword",
),
(
{
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_USERNAME: "username",
mqtt.CONF_PASSWORD: "verysecret",
},
{
mqtt.CONF_USERNAME: "username",
mqtt.CONF_PASSWORD: PWD_NOT_CHANGED,
},
"verysecret",
),
],
)
async def test_step_reauth(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
mqtt_client_mock: MqttMockPahoClient,
mock_try_connection: MagicMock,
mock_reload_after_entry_update: MagicMock,
test_input: dict[str, Any],
user_input: dict[str, Any],
new_password: str,
) -> None:
"""Test that the reauth step works."""
# Prepare the config entry
config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
hass.config_entries.async_update_entry(
config_entry,
data=test_input,
)
await mqtt_mock_entry()
# Start reauth flow
config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
result = flows[0]
assert result["step_id"] == "reauth_confirm"
assert result["context"]["source"] == "reauth"
# Show the form
result = await hass.config_entries.flow.async_init(
mqtt.DOMAIN,
context={
"source": config_entries.SOURCE_REAUTH,
"entry_id": config_entry.entry_id,
},
data=config_entry.data,
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "reauth_confirm"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "reauth_confirm"
# Simulate re-auth fails
mock_try_connection.return_value = False
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=user_input
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {"base": "invalid_auth"}
# Simulate re-auth succeeds
mock_try_connection.return_value = True
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input=user_input
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "reauth_successful"
assert len(hass.config_entries.async_entries()) == 1
assert config_entry.data.get(mqtt.CONF_PASSWORD) == new_password
await hass.async_block_till_done()
async def test_options_user_connection_fails(
hass: HomeAssistant, mock_try_connection_time_out: MagicMock
) -> None:

View File

@ -2046,6 +2046,24 @@ async def test_logs_error_if_no_connect_broker(
)
@pytest.mark.parametrize("return_code", [4, 5])
async def test_triggers_reauth_flow_if_auth_fails(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
mqtt_mock_entry: MqttMockHAClientGenerator,
mqtt_client_mock: MqttMockPahoClient,
return_code: int,
) -> None:
"""Test re-auth is triggered if authentication is failing."""
await mqtt_mock_entry()
# test with rc = 4 -> CONNACK_REFUSED_NOT_AUTHORIZED and 5 -> CONNACK_REFUSED_BAD_USERNAME_PASSWORD
mqtt_client_mock.on_connect(mqtt_client_mock, None, None, return_code)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
assert flows[0]["context"]["source"] == "reauth"
@patch("homeassistant.components.mqtt.client.TIMEOUT_ACK", 0.3)
async def test_handle_mqtt_on_callback(
hass: HomeAssistant,