mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Fix MQTT config flow with advanced parameters (#67556)
* Fix MQTT config flow with advanced parameters * Add test
This commit is contained in:
parent
24e0c0b092
commit
9356bf1a8e
@ -70,11 +70,16 @@ from .const import (
|
||||
ATTR_TOPIC,
|
||||
CONF_BIRTH_MESSAGE,
|
||||
CONF_BROKER,
|
||||
CONF_CERTIFICATE,
|
||||
CONF_CLIENT_CERT,
|
||||
CONF_CLIENT_KEY,
|
||||
CONF_COMMAND_TOPIC,
|
||||
CONF_ENCODING,
|
||||
CONF_QOS,
|
||||
CONF_RETAIN,
|
||||
CONF_STATE_TOPIC,
|
||||
CONF_TLS_INSECURE,
|
||||
CONF_TLS_VERSION,
|
||||
CONF_TOPIC,
|
||||
CONF_WILL_MESSAGE,
|
||||
DATA_MQTT_CONFIG,
|
||||
@ -89,6 +94,7 @@ from .const import (
|
||||
DOMAIN,
|
||||
MQTT_CONNECTED,
|
||||
MQTT_DISCONNECTED,
|
||||
PROTOCOL_31,
|
||||
PROTOCOL_311,
|
||||
)
|
||||
from .discovery import LAST_DISCOVERY
|
||||
@ -113,13 +119,6 @@ SERVICE_DUMP = "dump"
|
||||
|
||||
CONF_DISCOVERY_PREFIX = "discovery_prefix"
|
||||
CONF_KEEPALIVE = "keepalive"
|
||||
CONF_CERTIFICATE = "certificate"
|
||||
CONF_CLIENT_KEY = "client_key"
|
||||
CONF_CLIENT_CERT = "client_cert"
|
||||
CONF_TLS_INSECURE = "tls_insecure"
|
||||
CONF_TLS_VERSION = "tls_version"
|
||||
|
||||
PROTOCOL_31 = "3.1"
|
||||
|
||||
DEFAULT_PORT = 1883
|
||||
DEFAULT_KEEPALIVE = 60
|
||||
@ -751,6 +750,58 @@ class Subscription:
|
||||
encoding: str | None = attr.ib(default="utf-8")
|
||||
|
||||
|
||||
class MqttClientSetup:
|
||||
"""Helper class to setup the paho mqtt client from config."""
|
||||
|
||||
# We don't import on the top because some integrations
|
||||
# should be able to optionally rely on MQTT.
|
||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||
|
||||
def __init__(self, config: ConfigType) -> None:
|
||||
"""Initialize the MQTT client setup helper."""
|
||||
|
||||
if config[CONF_PROTOCOL] == PROTOCOL_31:
|
||||
proto = self.mqtt.MQTTv31
|
||||
else:
|
||||
proto = self.mqtt.MQTTv311
|
||||
|
||||
if (client_id := config.get(CONF_CLIENT_ID)) is None:
|
||||
# PAHO MQTT relies on the MQTT server to generate random client IDs.
|
||||
# However, that feature is not mandatory so we generate our own.
|
||||
client_id = self.mqtt.base62(uuid.uuid4().int, padding=22)
|
||||
self._client = self.mqtt.Client(client_id, protocol=proto)
|
||||
|
||||
# Enable logging
|
||||
self._client.enable_logger()
|
||||
|
||||
username = config.get(CONF_USERNAME)
|
||||
password = config.get(CONF_PASSWORD)
|
||||
if username is not None:
|
||||
self._client.username_pw_set(username, password)
|
||||
|
||||
if (certificate := config.get(CONF_CERTIFICATE)) == "auto":
|
||||
certificate = certifi.where()
|
||||
|
||||
client_key = config.get(CONF_CLIENT_KEY)
|
||||
client_cert = config.get(CONF_CLIENT_CERT)
|
||||
tls_insecure = config.get(CONF_TLS_INSECURE)
|
||||
if certificate is not None:
|
||||
self._client.tls_set(
|
||||
certificate,
|
||||
certfile=client_cert,
|
||||
keyfile=client_key,
|
||||
tls_version=ssl.PROTOCOL_TLS,
|
||||
)
|
||||
|
||||
if tls_insecure is not None:
|
||||
self._client.tls_insecure_set(tls_insecure)
|
||||
|
||||
@property
|
||||
def client(self) -> mqtt.Client:
|
||||
"""Return the paho MQTT client."""
|
||||
return self._client
|
||||
|
||||
|
||||
class MQTT:
|
||||
"""Home Assistant MQTT client."""
|
||||
|
||||
@ -815,46 +866,7 @@ class MQTT:
|
||||
|
||||
def init_client(self):
|
||||
"""Initialize paho client."""
|
||||
# We don't import on the top because some integrations
|
||||
# should be able to optionally rely on MQTT.
|
||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||
|
||||
if self.conf[CONF_PROTOCOL] == PROTOCOL_31:
|
||||
proto: int = mqtt.MQTTv31
|
||||
else:
|
||||
proto = mqtt.MQTTv311
|
||||
|
||||
if (client_id := self.conf.get(CONF_CLIENT_ID)) is None:
|
||||
# PAHO MQTT relies on the MQTT server to generate random client IDs.
|
||||
# However, that feature is not mandatory so we generate our own.
|
||||
client_id = mqtt.base62(uuid.uuid4().int, padding=22)
|
||||
self._mqttc = mqtt.Client(client_id, protocol=proto)
|
||||
|
||||
# Enable logging
|
||||
self._mqttc.enable_logger()
|
||||
|
||||
username = self.conf.get(CONF_USERNAME)
|
||||
password = self.conf.get(CONF_PASSWORD)
|
||||
if username is not None:
|
||||
self._mqttc.username_pw_set(username, password)
|
||||
|
||||
if (certificate := self.conf.get(CONF_CERTIFICATE)) == "auto":
|
||||
certificate = certifi.where()
|
||||
|
||||
client_key = self.conf.get(CONF_CLIENT_KEY)
|
||||
client_cert = self.conf.get(CONF_CLIENT_CERT)
|
||||
tls_insecure = self.conf.get(CONF_TLS_INSECURE)
|
||||
if certificate is not None:
|
||||
self._mqttc.tls_set(
|
||||
certificate,
|
||||
certfile=client_cert,
|
||||
keyfile=client_key,
|
||||
tls_version=ssl.PROTOCOL_TLS,
|
||||
)
|
||||
|
||||
if tls_insecure is not None:
|
||||
self._mqttc.tls_insecure_set(tls_insecure)
|
||||
|
||||
self._mqttc = MqttClientSetup(self.conf).client
|
||||
self._mqttc.on_connect = self._mqtt_on_connect
|
||||
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
||||
self._mqttc.on_message = self._mqtt_on_message
|
||||
|
@ -17,6 +17,7 @@ from homeassistant.const import (
|
||||
)
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
|
||||
from . import MqttClientSetup
|
||||
from .const import (
|
||||
ATTR_PAYLOAD,
|
||||
ATTR_QOS,
|
||||
@ -62,6 +63,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
if user_input is not None:
|
||||
can_connect = await self.hass.async_add_executor_job(
|
||||
try_connection,
|
||||
self.hass,
|
||||
user_input[CONF_BROKER],
|
||||
user_input[CONF_PORT],
|
||||
user_input.get(CONF_USERNAME),
|
||||
@ -102,6 +104,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
data = self._hassio_discovery
|
||||
can_connect = await self.hass.async_add_executor_job(
|
||||
try_connection,
|
||||
self.hass,
|
||||
data[CONF_HOST],
|
||||
data[CONF_PORT],
|
||||
data.get(CONF_USERNAME),
|
||||
@ -152,6 +155,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
||||
if user_input is not None:
|
||||
can_connect = await self.hass.async_add_executor_job(
|
||||
try_connection,
|
||||
self.hass,
|
||||
user_input[CONF_BROKER],
|
||||
user_input[CONF_PORT],
|
||||
user_input.get(CONF_USERNAME),
|
||||
@ -313,25 +317,24 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
||||
)
|
||||
|
||||
|
||||
def try_connection(broker, port, username, password, protocol="3.1"):
|
||||
def try_connection(hass, broker, port, username, password, protocol="3.1"):
|
||||
"""Test if we can connect to an MQTT broker."""
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
if protocol == "3.1":
|
||||
proto = mqtt.MQTTv31
|
||||
else:
|
||||
proto = mqtt.MQTTv311
|
||||
|
||||
client = mqtt.Client(protocol=proto)
|
||||
if username and password:
|
||||
client.username_pw_set(username, password)
|
||||
# Get the config from configuration.yaml
|
||||
yaml_config = hass.data.get(DATA_MQTT_CONFIG, {})
|
||||
entry_config = {
|
||||
CONF_BROKER: broker,
|
||||
CONF_PORT: port,
|
||||
CONF_USERNAME: username,
|
||||
CONF_PASSWORD: password,
|
||||
CONF_PROTOCOL: protocol,
|
||||
}
|
||||
client = MqttClientSetup({**yaml_config, **entry_config}).client
|
||||
|
||||
result = queue.Queue(maxsize=1)
|
||||
|
||||
def on_connect(client_, userdata, flags, result_code):
|
||||
"""Handle connection result."""
|
||||
result.put(result_code == mqtt.CONNACK_ACCEPTED)
|
||||
result.put(result_code == MqttClientSetup.mqtt.CONNACK_ACCEPTED)
|
||||
|
||||
client.on_connect = on_connect
|
||||
|
||||
|
@ -22,6 +22,12 @@ CONF_STATE_VALUE_TEMPLATE = "state_value_template"
|
||||
CONF_TOPIC = "topic"
|
||||
CONF_WILL_MESSAGE = "will_message"
|
||||
|
||||
CONF_CERTIFICATE = "certificate"
|
||||
CONF_CLIENT_KEY = "client_key"
|
||||
CONF_CLIENT_CERT = "client_cert"
|
||||
CONF_TLS_INSECURE = "tls_insecure"
|
||||
CONF_TLS_VERSION = "tls_version"
|
||||
|
||||
DATA_MQTT_CONFIG = "mqtt_config"
|
||||
DATA_MQTT_RELOAD_NEEDED = "mqtt_reload_needed"
|
||||
|
||||
@ -56,4 +62,5 @@ MQTT_DISCONNECTED = "mqtt_disconnected"
|
||||
PAYLOAD_EMPTY_JSON = "{}"
|
||||
PAYLOAD_NONE = "None"
|
||||
|
||||
PROTOCOL_31 = "3.1"
|
||||
PROTOCOL_311 = "3.1.1"
|
||||
|
@ -3,8 +3,9 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
import yaml
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant import config as hass_config, config_entries, data_entry_flow
|
||||
from homeassistant.components import mqtt
|
||||
from homeassistant.components.hassio import HassioServiceInfo
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -151,7 +152,7 @@ async def test_manual_config_set(
|
||||
"discovery": True,
|
||||
}
|
||||
# Check we tried the connection, with precedence for config entry settings
|
||||
mock_try_connection.assert_called_once_with("127.0.0.1", 1883, None, None)
|
||||
mock_try_connection.assert_called_once_with(hass, "127.0.0.1", 1883, None, None)
|
||||
# Check config entry got setup
|
||||
assert len(mock_finish_setup.mock_calls) == 1
|
||||
|
||||
@ -642,3 +643,95 @@ async def test_options_bad_will_message_fails(hass, mock_try_connection):
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
}
|
||||
|
||||
|
||||
async def test_try_connection_with_advanced_parameters(
|
||||
hass, mock_try_connection_success, tmp_path
|
||||
):
|
||||
"""Test config flow with advanced parameters from config."""
|
||||
# Mock certificate files
|
||||
certfile = tmp_path / "cert.pem"
|
||||
certfile.write_text("## mock certificate file ##")
|
||||
keyfile = tmp_path / "key.pem"
|
||||
keyfile.write_text("## mock key file ##")
|
||||
config = {
|
||||
"certificate": "auto",
|
||||
"tls_insecure": True,
|
||||
"client_cert": certfile,
|
||||
"client_key": keyfile,
|
||||
}
|
||||
new_yaml_config_file = tmp_path / "configuration.yaml"
|
||||
new_yaml_config = yaml.dump({mqtt.DOMAIN: config})
|
||||
new_yaml_config_file.write_text(new_yaml_config)
|
||||
assert new_yaml_config_file.read_text() == new_yaml_config
|
||||
|
||||
with patch.object(hass_config, "YAML_CONFIG_FILE", new_yaml_config_file):
|
||||
await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config})
|
||||
await hass.async_block_till_done()
|
||||
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
|
||||
config_entry.add_to_hass(hass)
|
||||
config_entry.data = {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
mqtt.CONF_USERNAME: "user",
|
||||
mqtt.CONF_PASSWORD: "pass",
|
||||
mqtt.CONF_DISCOVERY: True,
|
||||
mqtt.CONF_BIRTH_MESSAGE: {
|
||||
mqtt.ATTR_TOPIC: "ha_state/online",
|
||||
mqtt.ATTR_PAYLOAD: "online",
|
||||
mqtt.ATTR_QOS: 1,
|
||||
mqtt.ATTR_RETAIN: True,
|
||||
},
|
||||
mqtt.CONF_WILL_MESSAGE: {
|
||||
mqtt.ATTR_TOPIC: "ha_state/offline",
|
||||
mqtt.ATTR_PAYLOAD: "offline",
|
||||
mqtt.ATTR_QOS: 2,
|
||||
mqtt.ATTR_RETAIN: False,
|
||||
},
|
||||
}
|
||||
|
||||
# Test default/suggested values from config
|
||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "broker"
|
||||
defaults = {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
mqtt.CONF_PORT: 1234,
|
||||
}
|
||||
suggested = {
|
||||
mqtt.CONF_USERNAME: "user",
|
||||
mqtt.CONF_PASSWORD: "pass",
|
||||
}
|
||||
for k, v in defaults.items():
|
||||
assert get_default(result["data_schema"].schema, k) == v
|
||||
for k, v in suggested.items():
|
||||
assert get_suggested(result["data_schema"].schema, k) == v
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
mqtt.CONF_BROKER: "another-broker",
|
||||
mqtt.CONF_PORT: 2345,
|
||||
mqtt.CONF_USERNAME: "us3r",
|
||||
mqtt.CONF_PASSWORD: "p4ss",
|
||||
},
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "options"
|
||||
|
||||
# check if the username and password was set from config flow and not from configuration.yaml
|
||||
assert mock_try_connection_success.username_pw_set.mock_calls[0][1] == (
|
||||
"us3r",
|
||||
"p4ss",
|
||||
)
|
||||
|
||||
# check if tls_insecure_set is called
|
||||
assert mock_try_connection_success.tls_insecure_set.mock_calls[0][1] == (True,)
|
||||
|
||||
# check if the certificate settings were set from configuration.yaml
|
||||
assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[
|
||||
"certfile"
|
||||
] == str(certfile)
|
||||
assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[
|
||||
"keyfile"
|
||||
] == str(keyfile)
|
||||
|
Loading…
x
Reference in New Issue
Block a user