mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Fix MQTT config flow with advanced parameters (#67556)
* Fix MQTT config flow with advanced parameters * Add test
This commit is contained in:
parent
24e0c0b092
commit
9356bf1a8e
@ -70,11 +70,16 @@ from .const import (
|
|||||||
ATTR_TOPIC,
|
ATTR_TOPIC,
|
||||||
CONF_BIRTH_MESSAGE,
|
CONF_BIRTH_MESSAGE,
|
||||||
CONF_BROKER,
|
CONF_BROKER,
|
||||||
|
CONF_CERTIFICATE,
|
||||||
|
CONF_CLIENT_CERT,
|
||||||
|
CONF_CLIENT_KEY,
|
||||||
CONF_COMMAND_TOPIC,
|
CONF_COMMAND_TOPIC,
|
||||||
CONF_ENCODING,
|
CONF_ENCODING,
|
||||||
CONF_QOS,
|
CONF_QOS,
|
||||||
CONF_RETAIN,
|
CONF_RETAIN,
|
||||||
CONF_STATE_TOPIC,
|
CONF_STATE_TOPIC,
|
||||||
|
CONF_TLS_INSECURE,
|
||||||
|
CONF_TLS_VERSION,
|
||||||
CONF_TOPIC,
|
CONF_TOPIC,
|
||||||
CONF_WILL_MESSAGE,
|
CONF_WILL_MESSAGE,
|
||||||
DATA_MQTT_CONFIG,
|
DATA_MQTT_CONFIG,
|
||||||
@ -89,6 +94,7 @@ from .const import (
|
|||||||
DOMAIN,
|
DOMAIN,
|
||||||
MQTT_CONNECTED,
|
MQTT_CONNECTED,
|
||||||
MQTT_DISCONNECTED,
|
MQTT_DISCONNECTED,
|
||||||
|
PROTOCOL_31,
|
||||||
PROTOCOL_311,
|
PROTOCOL_311,
|
||||||
)
|
)
|
||||||
from .discovery import LAST_DISCOVERY
|
from .discovery import LAST_DISCOVERY
|
||||||
@ -113,13 +119,6 @@ SERVICE_DUMP = "dump"
|
|||||||
|
|
||||||
CONF_DISCOVERY_PREFIX = "discovery_prefix"
|
CONF_DISCOVERY_PREFIX = "discovery_prefix"
|
||||||
CONF_KEEPALIVE = "keepalive"
|
CONF_KEEPALIVE = "keepalive"
|
||||||
CONF_CERTIFICATE = "certificate"
|
|
||||||
CONF_CLIENT_KEY = "client_key"
|
|
||||||
CONF_CLIENT_CERT = "client_cert"
|
|
||||||
CONF_TLS_INSECURE = "tls_insecure"
|
|
||||||
CONF_TLS_VERSION = "tls_version"
|
|
||||||
|
|
||||||
PROTOCOL_31 = "3.1"
|
|
||||||
|
|
||||||
DEFAULT_PORT = 1883
|
DEFAULT_PORT = 1883
|
||||||
DEFAULT_KEEPALIVE = 60
|
DEFAULT_KEEPALIVE = 60
|
||||||
@ -751,6 +750,58 @@ class Subscription:
|
|||||||
encoding: str | None = attr.ib(default="utf-8")
|
encoding: str | None = attr.ib(default="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
class MqttClientSetup:
|
||||||
|
"""Helper class to setup the paho mqtt client from config."""
|
||||||
|
|
||||||
|
# We don't import on the top because some integrations
|
||||||
|
# should be able to optionally rely on MQTT.
|
||||||
|
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
def __init__(self, config: ConfigType) -> None:
|
||||||
|
"""Initialize the MQTT client setup helper."""
|
||||||
|
|
||||||
|
if config[CONF_PROTOCOL] == PROTOCOL_31:
|
||||||
|
proto = self.mqtt.MQTTv31
|
||||||
|
else:
|
||||||
|
proto = self.mqtt.MQTTv311
|
||||||
|
|
||||||
|
if (client_id := config.get(CONF_CLIENT_ID)) is None:
|
||||||
|
# PAHO MQTT relies on the MQTT server to generate random client IDs.
|
||||||
|
# However, that feature is not mandatory so we generate our own.
|
||||||
|
client_id = self.mqtt.base62(uuid.uuid4().int, padding=22)
|
||||||
|
self._client = self.mqtt.Client(client_id, protocol=proto)
|
||||||
|
|
||||||
|
# Enable logging
|
||||||
|
self._client.enable_logger()
|
||||||
|
|
||||||
|
username = config.get(CONF_USERNAME)
|
||||||
|
password = config.get(CONF_PASSWORD)
|
||||||
|
if username is not None:
|
||||||
|
self._client.username_pw_set(username, password)
|
||||||
|
|
||||||
|
if (certificate := config.get(CONF_CERTIFICATE)) == "auto":
|
||||||
|
certificate = certifi.where()
|
||||||
|
|
||||||
|
client_key = config.get(CONF_CLIENT_KEY)
|
||||||
|
client_cert = config.get(CONF_CLIENT_CERT)
|
||||||
|
tls_insecure = config.get(CONF_TLS_INSECURE)
|
||||||
|
if certificate is not None:
|
||||||
|
self._client.tls_set(
|
||||||
|
certificate,
|
||||||
|
certfile=client_cert,
|
||||||
|
keyfile=client_key,
|
||||||
|
tls_version=ssl.PROTOCOL_TLS,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tls_insecure is not None:
|
||||||
|
self._client.tls_insecure_set(tls_insecure)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> mqtt.Client:
|
||||||
|
"""Return the paho MQTT client."""
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
|
||||||
class MQTT:
|
class MQTT:
|
||||||
"""Home Assistant MQTT client."""
|
"""Home Assistant MQTT client."""
|
||||||
|
|
||||||
@ -815,46 +866,7 @@ class MQTT:
|
|||||||
|
|
||||||
def init_client(self):
|
def init_client(self):
|
||||||
"""Initialize paho client."""
|
"""Initialize paho client."""
|
||||||
# We don't import on the top because some integrations
|
self._mqttc = MqttClientSetup(self.conf).client
|
||||||
# should be able to optionally rely on MQTT.
|
|
||||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
|
||||||
|
|
||||||
if self.conf[CONF_PROTOCOL] == PROTOCOL_31:
|
|
||||||
proto: int = mqtt.MQTTv31
|
|
||||||
else:
|
|
||||||
proto = mqtt.MQTTv311
|
|
||||||
|
|
||||||
if (client_id := self.conf.get(CONF_CLIENT_ID)) is None:
|
|
||||||
# PAHO MQTT relies on the MQTT server to generate random client IDs.
|
|
||||||
# However, that feature is not mandatory so we generate our own.
|
|
||||||
client_id = mqtt.base62(uuid.uuid4().int, padding=22)
|
|
||||||
self._mqttc = mqtt.Client(client_id, protocol=proto)
|
|
||||||
|
|
||||||
# Enable logging
|
|
||||||
self._mqttc.enable_logger()
|
|
||||||
|
|
||||||
username = self.conf.get(CONF_USERNAME)
|
|
||||||
password = self.conf.get(CONF_PASSWORD)
|
|
||||||
if username is not None:
|
|
||||||
self._mqttc.username_pw_set(username, password)
|
|
||||||
|
|
||||||
if (certificate := self.conf.get(CONF_CERTIFICATE)) == "auto":
|
|
||||||
certificate = certifi.where()
|
|
||||||
|
|
||||||
client_key = self.conf.get(CONF_CLIENT_KEY)
|
|
||||||
client_cert = self.conf.get(CONF_CLIENT_CERT)
|
|
||||||
tls_insecure = self.conf.get(CONF_TLS_INSECURE)
|
|
||||||
if certificate is not None:
|
|
||||||
self._mqttc.tls_set(
|
|
||||||
certificate,
|
|
||||||
certfile=client_cert,
|
|
||||||
keyfile=client_key,
|
|
||||||
tls_version=ssl.PROTOCOL_TLS,
|
|
||||||
)
|
|
||||||
|
|
||||||
if tls_insecure is not None:
|
|
||||||
self._mqttc.tls_insecure_set(tls_insecure)
|
|
||||||
|
|
||||||
self._mqttc.on_connect = self._mqtt_on_connect
|
self._mqttc.on_connect = self._mqtt_on_connect
|
||||||
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
||||||
self._mqttc.on_message = self._mqtt_on_message
|
self._mqttc.on_message = self._mqtt_on_message
|
||||||
|
@ -17,6 +17,7 @@ from homeassistant.const import (
|
|||||||
)
|
)
|
||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
|
|
||||||
|
from . import MqttClientSetup
|
||||||
from .const import (
|
from .const import (
|
||||||
ATTR_PAYLOAD,
|
ATTR_PAYLOAD,
|
||||||
ATTR_QOS,
|
ATTR_QOS,
|
||||||
@ -62,6 +63,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
can_connect = await self.hass.async_add_executor_job(
|
can_connect = await self.hass.async_add_executor_job(
|
||||||
try_connection,
|
try_connection,
|
||||||
|
self.hass,
|
||||||
user_input[CONF_BROKER],
|
user_input[CONF_BROKER],
|
||||||
user_input[CONF_PORT],
|
user_input[CONF_PORT],
|
||||||
user_input.get(CONF_USERNAME),
|
user_input.get(CONF_USERNAME),
|
||||||
@ -102,6 +104,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||||||
data = self._hassio_discovery
|
data = self._hassio_discovery
|
||||||
can_connect = await self.hass.async_add_executor_job(
|
can_connect = await self.hass.async_add_executor_job(
|
||||||
try_connection,
|
try_connection,
|
||||||
|
self.hass,
|
||||||
data[CONF_HOST],
|
data[CONF_HOST],
|
||||||
data[CONF_PORT],
|
data[CONF_PORT],
|
||||||
data.get(CONF_USERNAME),
|
data.get(CONF_USERNAME),
|
||||||
@ -152,6 +155,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
|||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
can_connect = await self.hass.async_add_executor_job(
|
can_connect = await self.hass.async_add_executor_job(
|
||||||
try_connection,
|
try_connection,
|
||||||
|
self.hass,
|
||||||
user_input[CONF_BROKER],
|
user_input[CONF_BROKER],
|
||||||
user_input[CONF_PORT],
|
user_input[CONF_PORT],
|
||||||
user_input.get(CONF_USERNAME),
|
user_input.get(CONF_USERNAME),
|
||||||
@ -313,25 +317,24 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def try_connection(broker, port, username, password, protocol="3.1"):
|
def try_connection(hass, broker, port, username, password, protocol="3.1"):
|
||||||
"""Test if we can connect to an MQTT broker."""
|
"""Test if we can connect to an MQTT broker."""
|
||||||
# pylint: disable-next=import-outside-toplevel
|
# Get the config from configuration.yaml
|
||||||
import paho.mqtt.client as mqtt
|
yaml_config = hass.data.get(DATA_MQTT_CONFIG, {})
|
||||||
|
entry_config = {
|
||||||
if protocol == "3.1":
|
CONF_BROKER: broker,
|
||||||
proto = mqtt.MQTTv31
|
CONF_PORT: port,
|
||||||
else:
|
CONF_USERNAME: username,
|
||||||
proto = mqtt.MQTTv311
|
CONF_PASSWORD: password,
|
||||||
|
CONF_PROTOCOL: protocol,
|
||||||
client = mqtt.Client(protocol=proto)
|
}
|
||||||
if username and password:
|
client = MqttClientSetup({**yaml_config, **entry_config}).client
|
||||||
client.username_pw_set(username, password)
|
|
||||||
|
|
||||||
result = queue.Queue(maxsize=1)
|
result = queue.Queue(maxsize=1)
|
||||||
|
|
||||||
def on_connect(client_, userdata, flags, result_code):
|
def on_connect(client_, userdata, flags, result_code):
|
||||||
"""Handle connection result."""
|
"""Handle connection result."""
|
||||||
result.put(result_code == mqtt.CONNACK_ACCEPTED)
|
result.put(result_code == MqttClientSetup.mqtt.CONNACK_ACCEPTED)
|
||||||
|
|
||||||
client.on_connect = on_connect
|
client.on_connect = on_connect
|
||||||
|
|
||||||
|
@ -22,6 +22,12 @@ CONF_STATE_VALUE_TEMPLATE = "state_value_template"
|
|||||||
CONF_TOPIC = "topic"
|
CONF_TOPIC = "topic"
|
||||||
CONF_WILL_MESSAGE = "will_message"
|
CONF_WILL_MESSAGE = "will_message"
|
||||||
|
|
||||||
|
CONF_CERTIFICATE = "certificate"
|
||||||
|
CONF_CLIENT_KEY = "client_key"
|
||||||
|
CONF_CLIENT_CERT = "client_cert"
|
||||||
|
CONF_TLS_INSECURE = "tls_insecure"
|
||||||
|
CONF_TLS_VERSION = "tls_version"
|
||||||
|
|
||||||
DATA_MQTT_CONFIG = "mqtt_config"
|
DATA_MQTT_CONFIG = "mqtt_config"
|
||||||
DATA_MQTT_RELOAD_NEEDED = "mqtt_reload_needed"
|
DATA_MQTT_RELOAD_NEEDED = "mqtt_reload_needed"
|
||||||
|
|
||||||
@ -56,4 +62,5 @@ MQTT_DISCONNECTED = "mqtt_disconnected"
|
|||||||
PAYLOAD_EMPTY_JSON = "{}"
|
PAYLOAD_EMPTY_JSON = "{}"
|
||||||
PAYLOAD_NONE = "None"
|
PAYLOAD_NONE = "None"
|
||||||
|
|
||||||
|
PROTOCOL_31 = "3.1"
|
||||||
PROTOCOL_311 = "3.1.1"
|
PROTOCOL_311 = "3.1.1"
|
||||||
|
@ -3,8 +3,9 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
import yaml
|
||||||
|
|
||||||
from homeassistant import config_entries, data_entry_flow
|
from homeassistant import config as hass_config, config_entries, data_entry_flow
|
||||||
from homeassistant.components import mqtt
|
from homeassistant.components import mqtt
|
||||||
from homeassistant.components.hassio import HassioServiceInfo
|
from homeassistant.components.hassio import HassioServiceInfo
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -151,7 +152,7 @@ async def test_manual_config_set(
|
|||||||
"discovery": True,
|
"discovery": True,
|
||||||
}
|
}
|
||||||
# Check we tried the connection, with precedence for config entry settings
|
# Check we tried the connection, with precedence for config entry settings
|
||||||
mock_try_connection.assert_called_once_with("127.0.0.1", 1883, None, None)
|
mock_try_connection.assert_called_once_with(hass, "127.0.0.1", 1883, None, None)
|
||||||
# Check config entry got setup
|
# Check config entry got setup
|
||||||
assert len(mock_finish_setup.mock_calls) == 1
|
assert len(mock_finish_setup.mock_calls) == 1
|
||||||
|
|
||||||
@ -642,3 +643,95 @@ async def test_options_bad_will_message_fails(hass, mock_try_connection):
|
|||||||
mqtt.CONF_BROKER: "test-broker",
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
mqtt.CONF_PORT: 1234,
|
mqtt.CONF_PORT: 1234,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_try_connection_with_advanced_parameters(
|
||||||
|
hass, mock_try_connection_success, tmp_path
|
||||||
|
):
|
||||||
|
"""Test config flow with advanced parameters from config."""
|
||||||
|
# Mock certificate files
|
||||||
|
certfile = tmp_path / "cert.pem"
|
||||||
|
certfile.write_text("## mock certificate file ##")
|
||||||
|
keyfile = tmp_path / "key.pem"
|
||||||
|
keyfile.write_text("## mock key file ##")
|
||||||
|
config = {
|
||||||
|
"certificate": "auto",
|
||||||
|
"tls_insecure": True,
|
||||||
|
"client_cert": certfile,
|
||||||
|
"client_key": keyfile,
|
||||||
|
}
|
||||||
|
new_yaml_config_file = tmp_path / "configuration.yaml"
|
||||||
|
new_yaml_config = yaml.dump({mqtt.DOMAIN: config})
|
||||||
|
new_yaml_config_file.write_text(new_yaml_config)
|
||||||
|
assert new_yaml_config_file.read_text() == new_yaml_config
|
||||||
|
|
||||||
|
with patch.object(hass_config, "YAML_CONFIG_FILE", new_yaml_config_file):
|
||||||
|
await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
config_entry.data = {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_PORT: 1234,
|
||||||
|
mqtt.CONF_USERNAME: "user",
|
||||||
|
mqtt.CONF_PASSWORD: "pass",
|
||||||
|
mqtt.CONF_DISCOVERY: True,
|
||||||
|
mqtt.CONF_BIRTH_MESSAGE: {
|
||||||
|
mqtt.ATTR_TOPIC: "ha_state/online",
|
||||||
|
mqtt.ATTR_PAYLOAD: "online",
|
||||||
|
mqtt.ATTR_QOS: 1,
|
||||||
|
mqtt.ATTR_RETAIN: True,
|
||||||
|
},
|
||||||
|
mqtt.CONF_WILL_MESSAGE: {
|
||||||
|
mqtt.ATTR_TOPIC: "ha_state/offline",
|
||||||
|
mqtt.ATTR_PAYLOAD: "offline",
|
||||||
|
mqtt.ATTR_QOS: 2,
|
||||||
|
mqtt.ATTR_RETAIN: False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test default/suggested values from config
|
||||||
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result["step_id"] == "broker"
|
||||||
|
defaults = {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
mqtt.CONF_PORT: 1234,
|
||||||
|
}
|
||||||
|
suggested = {
|
||||||
|
mqtt.CONF_USERNAME: "user",
|
||||||
|
mqtt.CONF_PASSWORD: "pass",
|
||||||
|
}
|
||||||
|
for k, v in defaults.items():
|
||||||
|
assert get_default(result["data_schema"].schema, k) == v
|
||||||
|
for k, v in suggested.items():
|
||||||
|
assert get_suggested(result["data_schema"].schema, k) == v
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={
|
||||||
|
mqtt.CONF_BROKER: "another-broker",
|
||||||
|
mqtt.CONF_PORT: 2345,
|
||||||
|
mqtt.CONF_USERNAME: "us3r",
|
||||||
|
mqtt.CONF_PASSWORD: "p4ss",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||||
|
assert result["step_id"] == "options"
|
||||||
|
|
||||||
|
# check if the username and password was set from config flow and not from configuration.yaml
|
||||||
|
assert mock_try_connection_success.username_pw_set.mock_calls[0][1] == (
|
||||||
|
"us3r",
|
||||||
|
"p4ss",
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if tls_insecure_set is called
|
||||||
|
assert mock_try_connection_success.tls_insecure_set.mock_calls[0][1] == (True,)
|
||||||
|
|
||||||
|
# check if the certificate settings were set from configuration.yaml
|
||||||
|
assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[
|
||||||
|
"certfile"
|
||||||
|
] == str(certfile)
|
||||||
|
assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[
|
||||||
|
"keyfile"
|
||||||
|
] == str(keyfile)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user