Fix MQTT config flow with advanced parameters (#67556)

* Fix MQTT config flow with advanced parameters

* Add test
This commit is contained in:
Jan Bouwhuis 2022-03-03 21:40:15 +01:00 committed by GitHub
parent 24e0c0b092
commit 9356bf1a8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 177 additions and 62 deletions

View File

@ -70,11 +70,16 @@ from .const import (
ATTR_TOPIC,
CONF_BIRTH_MESSAGE,
CONF_BROKER,
CONF_CERTIFICATE,
CONF_CLIENT_CERT,
CONF_CLIENT_KEY,
CONF_COMMAND_TOPIC,
CONF_ENCODING,
CONF_QOS,
CONF_RETAIN,
CONF_STATE_TOPIC,
CONF_TLS_INSECURE,
CONF_TLS_VERSION,
CONF_TOPIC,
CONF_WILL_MESSAGE,
DATA_MQTT_CONFIG,
@ -89,6 +94,7 @@ from .const import (
DOMAIN,
MQTT_CONNECTED,
MQTT_DISCONNECTED,
PROTOCOL_31,
PROTOCOL_311,
)
from .discovery import LAST_DISCOVERY
@ -113,13 +119,6 @@ SERVICE_DUMP = "dump"
CONF_DISCOVERY_PREFIX = "discovery_prefix"
CONF_KEEPALIVE = "keepalive"
CONF_CERTIFICATE = "certificate"
CONF_CLIENT_KEY = "client_key"
CONF_CLIENT_CERT = "client_cert"
CONF_TLS_INSECURE = "tls_insecure"
CONF_TLS_VERSION = "tls_version"
PROTOCOL_31 = "3.1"
DEFAULT_PORT = 1883
DEFAULT_KEEPALIVE = 60
@ -751,6 +750,58 @@ class Subscription:
encoding: str | None = attr.ib(default="utf-8")
class MqttClientSetup:
"""Helper class to setup the paho mqtt client from config."""
# We don't import on the top because some integrations
# should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
def __init__(self, config: ConfigType) -> None:
"""Initialize the MQTT client setup helper."""
if config[CONF_PROTOCOL] == PROTOCOL_31:
proto = self.mqtt.MQTTv31
else:
proto = self.mqtt.MQTTv311
if (client_id := config.get(CONF_CLIENT_ID)) is None:
# PAHO MQTT relies on the MQTT server to generate random client IDs.
# However, that feature is not mandatory so we generate our own.
client_id = self.mqtt.base62(uuid.uuid4().int, padding=22)
self._client = self.mqtt.Client(client_id, protocol=proto)
# Enable logging
self._client.enable_logger()
username = config.get(CONF_USERNAME)
password = config.get(CONF_PASSWORD)
if username is not None:
self._client.username_pw_set(username, password)
if (certificate := config.get(CONF_CERTIFICATE)) == "auto":
certificate = certifi.where()
client_key = config.get(CONF_CLIENT_KEY)
client_cert = config.get(CONF_CLIENT_CERT)
tls_insecure = config.get(CONF_TLS_INSECURE)
if certificate is not None:
self._client.tls_set(
certificate,
certfile=client_cert,
keyfile=client_key,
tls_version=ssl.PROTOCOL_TLS,
)
if tls_insecure is not None:
self._client.tls_insecure_set(tls_insecure)
@property
def client(self) -> mqtt.Client:
"""Return the paho MQTT client."""
return self._client
class MQTT:
"""Home Assistant MQTT client."""
@ -815,46 +866,7 @@ class MQTT:
def init_client(self):
"""Initialize paho client."""
# We don't import on the top because some integrations
# should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
if self.conf[CONF_PROTOCOL] == PROTOCOL_31:
proto: int = mqtt.MQTTv31
else:
proto = mqtt.MQTTv311
if (client_id := self.conf.get(CONF_CLIENT_ID)) is None:
# PAHO MQTT relies on the MQTT server to generate random client IDs.
# However, that feature is not mandatory so we generate our own.
client_id = mqtt.base62(uuid.uuid4().int, padding=22)
self._mqttc = mqtt.Client(client_id, protocol=proto)
# Enable logging
self._mqttc.enable_logger()
username = self.conf.get(CONF_USERNAME)
password = self.conf.get(CONF_PASSWORD)
if username is not None:
self._mqttc.username_pw_set(username, password)
if (certificate := self.conf.get(CONF_CERTIFICATE)) == "auto":
certificate = certifi.where()
client_key = self.conf.get(CONF_CLIENT_KEY)
client_cert = self.conf.get(CONF_CLIENT_CERT)
tls_insecure = self.conf.get(CONF_TLS_INSECURE)
if certificate is not None:
self._mqttc.tls_set(
certificate,
certfile=client_cert,
keyfile=client_key,
tls_version=ssl.PROTOCOL_TLS,
)
if tls_insecure is not None:
self._mqttc.tls_insecure_set(tls_insecure)
self._mqttc = MqttClientSetup(self.conf).client
self._mqttc.on_connect = self._mqtt_on_connect
self._mqttc.on_disconnect = self._mqtt_on_disconnect
self._mqttc.on_message = self._mqtt_on_message

View File

@ -17,6 +17,7 @@ from homeassistant.const import (
)
from homeassistant.data_entry_flow import FlowResult
from . import MqttClientSetup
from .const import (
ATTR_PAYLOAD,
ATTR_QOS,
@ -62,6 +63,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
if user_input is not None:
can_connect = await self.hass.async_add_executor_job(
try_connection,
self.hass,
user_input[CONF_BROKER],
user_input[CONF_PORT],
user_input.get(CONF_USERNAME),
@ -102,6 +104,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
data = self._hassio_discovery
can_connect = await self.hass.async_add_executor_job(
try_connection,
self.hass,
data[CONF_HOST],
data[CONF_PORT],
data.get(CONF_USERNAME),
@ -152,6 +155,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
if user_input is not None:
can_connect = await self.hass.async_add_executor_job(
try_connection,
self.hass,
user_input[CONF_BROKER],
user_input[CONF_PORT],
user_input.get(CONF_USERNAME),
@ -313,25 +317,24 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
)
def try_connection(broker, port, username, password, protocol="3.1"):
def try_connection(hass, broker, port, username, password, protocol="3.1"):
"""Test if we can connect to an MQTT broker."""
# pylint: disable-next=import-outside-toplevel
import paho.mqtt.client as mqtt
if protocol == "3.1":
proto = mqtt.MQTTv31
else:
proto = mqtt.MQTTv311
client = mqtt.Client(protocol=proto)
if username and password:
client.username_pw_set(username, password)
# Get the config from configuration.yaml
yaml_config = hass.data.get(DATA_MQTT_CONFIG, {})
entry_config = {
CONF_BROKER: broker,
CONF_PORT: port,
CONF_USERNAME: username,
CONF_PASSWORD: password,
CONF_PROTOCOL: protocol,
}
client = MqttClientSetup({**yaml_config, **entry_config}).client
result = queue.Queue(maxsize=1)
def on_connect(client_, userdata, flags, result_code):
"""Handle connection result."""
result.put(result_code == mqtt.CONNACK_ACCEPTED)
result.put(result_code == MqttClientSetup.mqtt.CONNACK_ACCEPTED)
client.on_connect = on_connect

View File

@ -22,6 +22,12 @@ CONF_STATE_VALUE_TEMPLATE = "state_value_template"
CONF_TOPIC = "topic"
CONF_WILL_MESSAGE = "will_message"
CONF_CERTIFICATE = "certificate"
CONF_CLIENT_KEY = "client_key"
CONF_CLIENT_CERT = "client_cert"
CONF_TLS_INSECURE = "tls_insecure"
CONF_TLS_VERSION = "tls_version"
DATA_MQTT_CONFIG = "mqtt_config"
DATA_MQTT_RELOAD_NEEDED = "mqtt_reload_needed"
@ -56,4 +62,5 @@ MQTT_DISCONNECTED = "mqtt_disconnected"
PAYLOAD_EMPTY_JSON = "{}"
PAYLOAD_NONE = "None"
PROTOCOL_31 = "3.1"
PROTOCOL_311 = "3.1.1"

View File

@ -3,8 +3,9 @@ from unittest.mock import patch
import pytest
import voluptuous as vol
import yaml
from homeassistant import config_entries, data_entry_flow
from homeassistant import config as hass_config, config_entries, data_entry_flow
from homeassistant.components import mqtt
from homeassistant.components.hassio import HassioServiceInfo
from homeassistant.core import HomeAssistant
@ -151,7 +152,7 @@ async def test_manual_config_set(
"discovery": True,
}
# Check we tried the connection, with precedence for config entry settings
mock_try_connection.assert_called_once_with("127.0.0.1", 1883, None, None)
mock_try_connection.assert_called_once_with(hass, "127.0.0.1", 1883, None, None)
# Check config entry got setup
assert len(mock_finish_setup.mock_calls) == 1
@ -642,3 +643,95 @@ async def test_options_bad_will_message_fails(hass, mock_try_connection):
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
async def test_try_connection_with_advanced_parameters(
hass, mock_try_connection_success, tmp_path
):
"""Test config flow with advanced parameters from config."""
# Mock certificate files
certfile = tmp_path / "cert.pem"
certfile.write_text("## mock certificate file ##")
keyfile = tmp_path / "key.pem"
keyfile.write_text("## mock key file ##")
config = {
"certificate": "auto",
"tls_insecure": True,
"client_cert": certfile,
"client_key": keyfile,
}
new_yaml_config_file = tmp_path / "configuration.yaml"
new_yaml_config = yaml.dump({mqtt.DOMAIN: config})
new_yaml_config_file.write_text(new_yaml_config)
assert new_yaml_config_file.read_text() == new_yaml_config
with patch.object(hass_config, "YAML_CONFIG_FILE", new_yaml_config_file):
await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
config_entry.data = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass",
mqtt.CONF_DISCOVERY: True,
mqtt.CONF_BIRTH_MESSAGE: {
mqtt.ATTR_TOPIC: "ha_state/online",
mqtt.ATTR_PAYLOAD: "online",
mqtt.ATTR_QOS: 1,
mqtt.ATTR_RETAIN: True,
},
mqtt.CONF_WILL_MESSAGE: {
mqtt.ATTR_TOPIC: "ha_state/offline",
mqtt.ATTR_PAYLOAD: "offline",
mqtt.ATTR_QOS: 2,
mqtt.ATTR_RETAIN: False,
},
}
# Test default/suggested values from config
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "broker"
defaults = {
mqtt.CONF_BROKER: "test-broker",
mqtt.CONF_PORT: 1234,
}
suggested = {
mqtt.CONF_USERNAME: "user",
mqtt.CONF_PASSWORD: "pass",
}
for k, v in defaults.items():
assert get_default(result["data_schema"].schema, k) == v
for k, v in suggested.items():
assert get_suggested(result["data_schema"].schema, k) == v
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_BROKER: "another-broker",
mqtt.CONF_PORT: 2345,
mqtt.CONF_USERNAME: "us3r",
mqtt.CONF_PASSWORD: "p4ss",
},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "options"
# check if the username and password was set from config flow and not from configuration.yaml
assert mock_try_connection_success.username_pw_set.mock_calls[0][1] == (
"us3r",
"p4ss",
)
# check if tls_insecure_set is called
assert mock_try_connection_success.tls_insecure_set.mock_calls[0][1] == (True,)
# check if the certificate settings were set from configuration.yaml
assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[
"certfile"
] == str(certfile)
assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[
"keyfile"
] == str(keyfile)