From 9356bf1a8e6df72a498a9b3517e1a8f2a0abefcb Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Thu, 3 Mar 2022 21:40:15 +0100 Subject: [PATCH] Fix MQTT config flow with advanced parameters (#67556) * Fix MQTT config flow with advanced parameters * Add test --- homeassistant/components/mqtt/__init__.py | 106 +++++++++++-------- homeassistant/components/mqtt/config_flow.py | 29 ++--- homeassistant/components/mqtt/const.py | 7 ++ tests/components/mqtt/test_config_flow.py | 97 ++++++++++++++++- 4 files changed, 177 insertions(+), 62 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 090d9cdfa73..bbe773f141c 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -70,11 +70,16 @@ from .const import ( ATTR_TOPIC, CONF_BIRTH_MESSAGE, CONF_BROKER, + CONF_CERTIFICATE, + CONF_CLIENT_CERT, + CONF_CLIENT_KEY, CONF_COMMAND_TOPIC, CONF_ENCODING, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, + CONF_TLS_INSECURE, + CONF_TLS_VERSION, CONF_TOPIC, CONF_WILL_MESSAGE, DATA_MQTT_CONFIG, @@ -89,6 +94,7 @@ from .const import ( DOMAIN, MQTT_CONNECTED, MQTT_DISCONNECTED, + PROTOCOL_31, PROTOCOL_311, ) from .discovery import LAST_DISCOVERY @@ -113,13 +119,6 @@ SERVICE_DUMP = "dump" CONF_DISCOVERY_PREFIX = "discovery_prefix" CONF_KEEPALIVE = "keepalive" -CONF_CERTIFICATE = "certificate" -CONF_CLIENT_KEY = "client_key" -CONF_CLIENT_CERT = "client_cert" -CONF_TLS_INSECURE = "tls_insecure" -CONF_TLS_VERSION = "tls_version" - -PROTOCOL_31 = "3.1" DEFAULT_PORT = 1883 DEFAULT_KEEPALIVE = 60 @@ -751,6 +750,58 @@ class Subscription: encoding: str | None = attr.ib(default="utf-8") +class MqttClientSetup: + """Helper class to setup the paho mqtt client from config.""" + + # We don't import on the top because some integrations + # should be able to optionally rely on MQTT. + import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel + + def __init__(self, config: ConfigType) -> None: + """Initialize the MQTT client setup helper.""" + + if config[CONF_PROTOCOL] == PROTOCOL_31: + proto = self.mqtt.MQTTv31 + else: + proto = self.mqtt.MQTTv311 + + if (client_id := config.get(CONF_CLIENT_ID)) is None: + # PAHO MQTT relies on the MQTT server to generate random client IDs. + # However, that feature is not mandatory so we generate our own. + client_id = self.mqtt.base62(uuid.uuid4().int, padding=22) + self._client = self.mqtt.Client(client_id, protocol=proto) + + # Enable logging + self._client.enable_logger() + + username = config.get(CONF_USERNAME) + password = config.get(CONF_PASSWORD) + if username is not None: + self._client.username_pw_set(username, password) + + if (certificate := config.get(CONF_CERTIFICATE)) == "auto": + certificate = certifi.where() + + client_key = config.get(CONF_CLIENT_KEY) + client_cert = config.get(CONF_CLIENT_CERT) + tls_insecure = config.get(CONF_TLS_INSECURE) + if certificate is not None: + self._client.tls_set( + certificate, + certfile=client_cert, + keyfile=client_key, + tls_version=ssl.PROTOCOL_TLS, + ) + + if tls_insecure is not None: + self._client.tls_insecure_set(tls_insecure) + + @property + def client(self) -> mqtt.Client: + """Return the paho MQTT client.""" + return self._client + + class MQTT: """Home Assistant MQTT client.""" @@ -815,46 +866,7 @@ class MQTT: def init_client(self): """Initialize paho client.""" - # We don't import on the top because some integrations - # should be able to optionally rely on MQTT. - import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel - - if self.conf[CONF_PROTOCOL] == PROTOCOL_31: - proto: int = mqtt.MQTTv31 - else: - proto = mqtt.MQTTv311 - - if (client_id := self.conf.get(CONF_CLIENT_ID)) is None: - # PAHO MQTT relies on the MQTT server to generate random client IDs. - # However, that feature is not mandatory so we generate our own. - client_id = mqtt.base62(uuid.uuid4().int, padding=22) - self._mqttc = mqtt.Client(client_id, protocol=proto) - - # Enable logging - self._mqttc.enable_logger() - - username = self.conf.get(CONF_USERNAME) - password = self.conf.get(CONF_PASSWORD) - if username is not None: - self._mqttc.username_pw_set(username, password) - - if (certificate := self.conf.get(CONF_CERTIFICATE)) == "auto": - certificate = certifi.where() - - client_key = self.conf.get(CONF_CLIENT_KEY) - client_cert = self.conf.get(CONF_CLIENT_CERT) - tls_insecure = self.conf.get(CONF_TLS_INSECURE) - if certificate is not None: - self._mqttc.tls_set( - certificate, - certfile=client_cert, - keyfile=client_key, - tls_version=ssl.PROTOCOL_TLS, - ) - - if tls_insecure is not None: - self._mqttc.tls_insecure_set(tls_insecure) - + self._mqttc = MqttClientSetup(self.conf).client self._mqttc.on_connect = self._mqtt_on_connect self._mqttc.on_disconnect = self._mqtt_on_disconnect self._mqttc.on_message = self._mqtt_on_message diff --git a/homeassistant/components/mqtt/config_flow.py b/homeassistant/components/mqtt/config_flow.py index 3f93e50829a..99e7e9718d0 100644 --- a/homeassistant/components/mqtt/config_flow.py +++ b/homeassistant/components/mqtt/config_flow.py @@ -17,6 +17,7 @@ from homeassistant.const import ( ) from homeassistant.data_entry_flow import FlowResult +from . import MqttClientSetup from .const import ( ATTR_PAYLOAD, ATTR_QOS, @@ -62,6 +63,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): if user_input is not None: can_connect = await self.hass.async_add_executor_job( try_connection, + self.hass, user_input[CONF_BROKER], user_input[CONF_PORT], user_input.get(CONF_USERNAME), @@ -102,6 +104,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): data = self._hassio_discovery can_connect = await self.hass.async_add_executor_job( try_connection, + self.hass, data[CONF_HOST], data[CONF_PORT], data.get(CONF_USERNAME), @@ -152,6 +155,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): if user_input is not None: can_connect = await self.hass.async_add_executor_job( try_connection, + self.hass, user_input[CONF_BROKER], user_input[CONF_PORT], user_input.get(CONF_USERNAME), @@ -313,25 +317,24 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): ) -def try_connection(broker, port, username, password, protocol="3.1"): +def try_connection(hass, broker, port, username, password, protocol="3.1"): """Test if we can connect to an MQTT broker.""" - # pylint: disable-next=import-outside-toplevel - import paho.mqtt.client as mqtt - - if protocol == "3.1": - proto = mqtt.MQTTv31 - else: - proto = mqtt.MQTTv311 - - client = mqtt.Client(protocol=proto) - if username and password: - client.username_pw_set(username, password) + # Get the config from configuration.yaml + yaml_config = hass.data.get(DATA_MQTT_CONFIG, {}) + entry_config = { + CONF_BROKER: broker, + CONF_PORT: port, + CONF_USERNAME: username, + CONF_PASSWORD: password, + CONF_PROTOCOL: protocol, + } + client = MqttClientSetup({**yaml_config, **entry_config}).client result = queue.Queue(maxsize=1) def on_connect(client_, userdata, flags, result_code): """Handle connection result.""" - result.put(result_code == mqtt.CONNACK_ACCEPTED) + result.put(result_code == MqttClientSetup.mqtt.CONNACK_ACCEPTED) client.on_connect = on_connect diff --git a/homeassistant/components/mqtt/const.py b/homeassistant/components/mqtt/const.py index f04348ee002..69865733763 100644 --- a/homeassistant/components/mqtt/const.py +++ b/homeassistant/components/mqtt/const.py @@ -22,6 +22,12 @@ CONF_STATE_VALUE_TEMPLATE = "state_value_template" CONF_TOPIC = "topic" CONF_WILL_MESSAGE = "will_message" +CONF_CERTIFICATE = "certificate" +CONF_CLIENT_KEY = "client_key" +CONF_CLIENT_CERT = "client_cert" +CONF_TLS_INSECURE = "tls_insecure" +CONF_TLS_VERSION = "tls_version" + DATA_MQTT_CONFIG = "mqtt_config" DATA_MQTT_RELOAD_NEEDED = "mqtt_reload_needed" @@ -56,4 +62,5 @@ MQTT_DISCONNECTED = "mqtt_disconnected" PAYLOAD_EMPTY_JSON = "{}" PAYLOAD_NONE = "None" +PROTOCOL_31 = "3.1" PROTOCOL_311 = "3.1.1" diff --git a/tests/components/mqtt/test_config_flow.py b/tests/components/mqtt/test_config_flow.py index d9aab02e821..88c6137bf94 100644 --- a/tests/components/mqtt/test_config_flow.py +++ b/tests/components/mqtt/test_config_flow.py @@ -3,8 +3,9 @@ from unittest.mock import patch import pytest import voluptuous as vol +import yaml -from homeassistant import config_entries, data_entry_flow +from homeassistant import config as hass_config, config_entries, data_entry_flow from homeassistant.components import mqtt from homeassistant.components.hassio import HassioServiceInfo from homeassistant.core import HomeAssistant @@ -151,7 +152,7 @@ async def test_manual_config_set( "discovery": True, } # Check we tried the connection, with precedence for config entry settings - mock_try_connection.assert_called_once_with("127.0.0.1", 1883, None, None) + mock_try_connection.assert_called_once_with(hass, "127.0.0.1", 1883, None, None) # Check config entry got setup assert len(mock_finish_setup.mock_calls) == 1 @@ -642,3 +643,95 @@ async def test_options_bad_will_message_fails(hass, mock_try_connection): mqtt.CONF_BROKER: "test-broker", mqtt.CONF_PORT: 1234, } + + +async def test_try_connection_with_advanced_parameters( + hass, mock_try_connection_success, tmp_path +): + """Test config flow with advanced parameters from config.""" + # Mock certificate files + certfile = tmp_path / "cert.pem" + certfile.write_text("## mock certificate file ##") + keyfile = tmp_path / "key.pem" + keyfile.write_text("## mock key file ##") + config = { + "certificate": "auto", + "tls_insecure": True, + "client_cert": certfile, + "client_key": keyfile, + } + new_yaml_config_file = tmp_path / "configuration.yaml" + new_yaml_config = yaml.dump({mqtt.DOMAIN: config}) + new_yaml_config_file.write_text(new_yaml_config) + assert new_yaml_config_file.read_text() == new_yaml_config + + with patch.object(hass_config, "YAML_CONFIG_FILE", new_yaml_config_file): + await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config}) + await hass.async_block_till_done() + config_entry = MockConfigEntry(domain=mqtt.DOMAIN) + config_entry.add_to_hass(hass) + config_entry.data = { + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_PORT: 1234, + mqtt.CONF_USERNAME: "user", + mqtt.CONF_PASSWORD: "pass", + mqtt.CONF_DISCOVERY: True, + mqtt.CONF_BIRTH_MESSAGE: { + mqtt.ATTR_TOPIC: "ha_state/online", + mqtt.ATTR_PAYLOAD: "online", + mqtt.ATTR_QOS: 1, + mqtt.ATTR_RETAIN: True, + }, + mqtt.CONF_WILL_MESSAGE: { + mqtt.ATTR_TOPIC: "ha_state/offline", + mqtt.ATTR_PAYLOAD: "offline", + mqtt.ATTR_QOS: 2, + mqtt.ATTR_RETAIN: False, + }, + } + + # Test default/suggested values from config + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "broker" + defaults = { + mqtt.CONF_BROKER: "test-broker", + mqtt.CONF_PORT: 1234, + } + suggested = { + mqtt.CONF_USERNAME: "user", + mqtt.CONF_PASSWORD: "pass", + } + for k, v in defaults.items(): + assert get_default(result["data_schema"].schema, k) == v + for k, v in suggested.items(): + assert get_suggested(result["data_schema"].schema, k) == v + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={ + mqtt.CONF_BROKER: "another-broker", + mqtt.CONF_PORT: 2345, + mqtt.CONF_USERNAME: "us3r", + mqtt.CONF_PASSWORD: "p4ss", + }, + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "options" + + # check if the username and password was set from config flow and not from configuration.yaml + assert mock_try_connection_success.username_pw_set.mock_calls[0][1] == ( + "us3r", + "p4ss", + ) + + # check if tls_insecure_set is called + assert mock_try_connection_success.tls_insecure_set.mock_calls[0][1] == (True,) + + # check if the certificate settings were set from configuration.yaml + assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[ + "certfile" + ] == str(certfile) + assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[ + "keyfile" + ] == str(keyfile)