diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index bd734318938..b0ce53d75fd 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -54,6 +54,7 @@ from .const import ( CONF_TLS_INSECURE, CONF_WILL_MESSAGE, DEFAULT_ENCODING, + DEFAULT_PROTOCOL, DEFAULT_QOS, MQTT_CONNECTED, MQTT_DISCONNECTED, @@ -272,7 +273,7 @@ class MqttClientSetup: # should be able to optionally rely on MQTT. import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel - if config[CONF_PROTOCOL] == PROTOCOL_31: + if config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL) == PROTOCOL_31: proto = mqtt.MQTTv31 else: proto = mqtt.MQTTv311 diff --git a/homeassistant/components/mqtt/config_flow.py b/homeassistant/components/mqtt/config_flow.py index 5d21619c498..df7b6137549 100644 --- a/homeassistant/components/mqtt/config_flow.py +++ b/homeassistant/components/mqtt/config_flow.py @@ -2,7 +2,9 @@ from __future__ import annotations from collections import OrderedDict +from collections.abc import Callable import queue +from types import MappingProxyType from typing import Any import voluptuous as vol @@ -15,10 +17,9 @@ from homeassistant.const import ( CONF_PASSWORD, CONF_PAYLOAD, CONF_PORT, - CONF_PROTOCOL, CONF_USERNAME, ) -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers.typing import ConfigType @@ -33,6 +34,7 @@ from .const import ( CONF_WILL_MESSAGE, DEFAULT_BIRTH, DEFAULT_DISCOVERY, + DEFAULT_PORT, DEFAULT_WILL, DOMAIN, ) @@ -56,9 +58,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Get the options flow for this handler.""" return MQTTOptionsFlowHandler(config_entry) - async def async_step_user( - self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + async def async_step_user(self, user_input: ConfigType | None = None) -> FlowResult: """Handle a flow initialized by the user.""" if self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") @@ -66,35 +66,38 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): return await self.async_step_broker() async def async_step_broker( - self, user_input: dict[str, Any] | None = None + self, user_input: ConfigType | None = None ) -> FlowResult: """Confirm the setup.""" - errors = {} - - if user_input is not None: + yaml_config: ConfigType = get_mqtt_data(self.hass, True).config or {} + errors: dict[str, str] = {} + fields: OrderedDict[Any, Any] = OrderedDict() + validated_user_input: ConfigType = {} + if await async_get_broker_settings( + self.hass, + fields, + yaml_config, + None, + user_input, + validated_user_input, + errors, + ): + test_config: ConfigType = yaml_config.copy() + test_config.update(validated_user_input) can_connect = await self.hass.async_add_executor_job( try_connection, - get_mqtt_data(self.hass, True).config or {}, - user_input[CONF_BROKER], - user_input[CONF_PORT], - user_input.get(CONF_USERNAME), - user_input.get(CONF_PASSWORD), + test_config, ) if can_connect: - user_input[CONF_DISCOVERY] = DEFAULT_DISCOVERY + validated_user_input[CONF_DISCOVERY] = DEFAULT_DISCOVERY return self.async_create_entry( - title=user_input[CONF_BROKER], data=user_input + title=validated_user_input[CONF_BROKER], + data=validated_user_input, ) errors["base"] = "cannot_connect" - fields = OrderedDict() - fields[vol.Required(CONF_BROKER)] = str - fields[vol.Required(CONF_PORT, default=1883)] = vol.Coerce(int) - fields[vol.Optional(CONF_USERNAME)] = str - fields[vol.Optional(CONF_PASSWORD)] = str - return self.async_show_form( step_id="broker", data_schema=vol.Schema(fields), errors=errors ) @@ -111,26 +114,22 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Confirm a Hass.io discovery.""" - errors = {} + errors: dict[str, str] = {} assert self._hassio_discovery if user_input is not None: - data = self._hassio_discovery + data: ConfigType = self._hassio_discovery.copy() + data[CONF_BROKER] = data.pop(CONF_HOST) can_connect = await self.hass.async_add_executor_job( try_connection, - get_mqtt_data(self.hass, True).config or {}, - data[CONF_HOST], - data[CONF_PORT], - data.get(CONF_USERNAME), - data.get(CONF_PASSWORD), - data.get(CONF_PROTOCOL), + data, ) if can_connect: return self.async_create_entry( title=data["addon"], data={ - CONF_BROKER: data[CONF_HOST], + CONF_BROKER: data[CONF_BROKER], CONF_PORT: data[CONF_PORT], CONF_USERNAME: data.get(CONF_USERNAME), CONF_PASSWORD: data.get(CONF_PASSWORD), @@ -164,46 +163,32 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Manage the MQTT broker configuration.""" - mqtt_data = get_mqtt_data(self.hass, True) - yaml_config = mqtt_data.config or {} - errors = {} - current_config = self.config_entry.data - if user_input is not None: + errors: dict[str, str] = {} + yaml_config: ConfigType = get_mqtt_data(self.hass, True).config or {} + fields: OrderedDict[Any, Any] = OrderedDict() + validated_user_input: ConfigType = {} + if await async_get_broker_settings( + self.hass, + fields, + yaml_config, + self.config_entry.data, + user_input, + validated_user_input, + errors, + ): + test_config: ConfigType = yaml_config.copy() + test_config.update(validated_user_input) can_connect = await self.hass.async_add_executor_job( try_connection, - yaml_config, - user_input[CONF_BROKER], - user_input[CONF_PORT], - user_input.get(CONF_USERNAME), - user_input.get(CONF_PASSWORD), + test_config, ) if can_connect: - self.broker_config.update(user_input) + self.broker_config.update(validated_user_input) return await self.async_step_options() errors["base"] = "cannot_connect" - fields = OrderedDict() - current_broker = current_config.get(CONF_BROKER, yaml_config.get(CONF_BROKER)) - current_port = current_config.get(CONF_PORT, yaml_config.get(CONF_PORT)) - current_user = current_config.get(CONF_USERNAME, yaml_config.get(CONF_USERNAME)) - current_pass = current_config.get(CONF_PASSWORD, yaml_config.get(CONF_PASSWORD)) - fields[vol.Required(CONF_BROKER, default=current_broker)] = str - fields[vol.Required(CONF_PORT, default=current_port)] = vol.Coerce(int) - fields[ - vol.Optional( - CONF_USERNAME, - description={"suggested_value": current_user}, - ) - ] = str - fields[ - vol.Optional( - CONF_PASSWORD, - description={"suggested_value": current_pass}, - ) - ] = str - return self.async_show_form( step_id="broker", data_schema=vol.Schema(fields), @@ -212,53 +197,61 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): ) async def async_step_options( - self, user_input: dict[str, Any] | None = None + self, user_input: ConfigType | None = None ) -> FlowResult: """Manage the MQTT options.""" - mqtt_data = get_mqtt_data(self.hass, True) errors = {} current_config = self.config_entry.data - yaml_config = mqtt_data.config or {} - options_config: dict[str, Any] = {} - if user_input is not None: - bad_birth = False - bad_will = False + yaml_config = get_mqtt_data(self.hass, True).config or {} + options_config: ConfigType = {} + bad_input: bool = False + def _birth_will(birt_or_will: str) -> dict: + """Return the user input for birth or will.""" + assert user_input + return { + ATTR_TOPIC: user_input[f"{birt_or_will}_topic"], + ATTR_PAYLOAD: user_input.get(f"{birt_or_will}_payload", ""), + ATTR_QOS: user_input[f"{birt_or_will}_qos"], + ATTR_RETAIN: user_input[f"{birt_or_will}_retain"], + } + + def _validate( + field: str, values: ConfigType, error_code: str, schema: Callable + ): + """Validate the user input.""" + nonlocal bad_input + try: + option_values = schema(values) + options_config[field] = option_values + except vol.Invalid: + errors["base"] = error_code + bad_input = True + + if user_input is not None: + # validate input + options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY] if "birth_topic" in user_input: - birth_message = { - ATTR_TOPIC: user_input["birth_topic"], - ATTR_PAYLOAD: user_input.get("birth_payload", ""), - ATTR_QOS: user_input["birth_qos"], - ATTR_RETAIN: user_input["birth_retain"], - } - try: - birth_message = MQTT_WILL_BIRTH_SCHEMA(birth_message) - options_config[CONF_BIRTH_MESSAGE] = birth_message - except vol.Invalid: - errors["base"] = "bad_birth" - bad_birth = True + _validate( + CONF_BIRTH_MESSAGE, + _birth_will("birth"), + "bad_birth", + MQTT_WILL_BIRTH_SCHEMA, + ) if not user_input["birth_enable"]: options_config[CONF_BIRTH_MESSAGE] = {} if "will_topic" in user_input: - will_message = { - ATTR_TOPIC: user_input["will_topic"], - ATTR_PAYLOAD: user_input.get("will_payload", ""), - ATTR_QOS: user_input["will_qos"], - ATTR_RETAIN: user_input["will_retain"], - } - try: - will_message = MQTT_WILL_BIRTH_SCHEMA(will_message) - options_config[CONF_WILL_MESSAGE] = will_message - except vol.Invalid: - errors["base"] = "bad_will" - bad_will = True + _validate( + CONF_WILL_MESSAGE, + _birth_will("will"), + "bad_will", + MQTT_WILL_BIRTH_SCHEMA, + ) if not user_input["will_enable"]: options_config[CONF_WILL_MESSAGE] = {} - options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY] - - if not bad_birth and not bad_will: + if not bad_input: updated_config = {} updated_config.update(self.broker_config) updated_config.update(options_config) @@ -285,6 +278,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): CONF_DISCOVERY, yaml_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY) ) + # build form fields: OrderedDict[vol.Marker, Any] = OrderedDict() fields[vol.Optional(CONF_DISCOVERY, default=discovery)] = bool @@ -338,28 +332,66 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): ) -def try_connection( +async def async_get_broker_settings( + hass: HomeAssistant, + fields: OrderedDict[Any, Any], yaml_config: ConfigType, - broker: str, - port: int, - username: str | None, - password: str | None, - protocol: str = "3.1", + entry_config: MappingProxyType[str, Any] | None, + user_input: ConfigType | None, + validated_user_input: ConfigType, + errors: dict[str, str], +) -> bool: + """Build the config flow schema to collect the broker settings. + + Returns True when settings are collected successfully. + """ + user_input_basic: ConfigType = ConfigType() + current_config = entry_config.copy() if entry_config is not None else ConfigType() + + if user_input is not None: + validated_user_input.update(user_input) + return True + + # Update the current settings the the new posted data to fill the defaults + current_config.update(user_input_basic) + + # Get default settings (if any) + current_broker = current_config.get(CONF_BROKER, yaml_config.get(CONF_BROKER)) + current_port = current_config.get( + CONF_PORT, yaml_config.get(CONF_PORT, DEFAULT_PORT) + ) + current_user = current_config.get(CONF_USERNAME, yaml_config.get(CONF_USERNAME)) + current_pass = current_config.get(CONF_PASSWORD, yaml_config.get(CONF_PASSWORD)) + + # Build form + fields[vol.Required(CONF_BROKER, default=current_broker)] = str + fields[vol.Required(CONF_PORT, default=current_port)] = vol.Coerce(int) + fields[ + vol.Optional( + CONF_USERNAME, + description={"suggested_value": current_user}, + ) + ] = str + fields[ + vol.Optional( + CONF_PASSWORD, + description={"suggested_value": current_pass}, + ) + ] = str + + # Show form + return False + + +def try_connection( + user_input: ConfigType, ) -> bool: """Test if we can connect to an MQTT broker.""" # We don't import on the top because some integrations # should be able to optionally rely on MQTT. import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel - # Get the config from configuration.yaml - entry_config = { - CONF_BROKER: broker, - CONF_PORT: port, - CONF_USERNAME: username, - CONF_PASSWORD: password, - CONF_PROTOCOL: protocol, - } - client = MqttClientSetup({**yaml_config, **entry_config}).client + client = MqttClientSetup(user_input).client result: queue.Queue[bool] = queue.Queue(maxsize=1) @@ -369,7 +401,7 @@ def try_connection( client.on_connect = on_connect - client.connect_async(broker, port) + client.connect_async(user_input[CONF_BROKER], user_input[CONF_PORT]) client.loop_start() try: diff --git a/homeassistant/components/mqtt/const.py b/homeassistant/components/mqtt/const.py index 93410f0c792..d266ed231ba 100644 --- a/homeassistant/components/mqtt/const.py +++ b/homeassistant/components/mqtt/const.py @@ -40,6 +40,7 @@ DEFAULT_ENCODING = "utf-8" DEFAULT_QOS = 0 DEFAULT_PAYLOAD_AVAILABLE = "online" DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline" +DEFAULT_PORT = 1883 DEFAULT_RETAIN = False DEFAULT_BIRTH = { @@ -67,6 +68,8 @@ PAYLOAD_NONE = "None" PROTOCOL_31 = "3.1" PROTOCOL_311 = "3.1.1" +DEFAULT_PROTOCOL = PROTOCOL_311 + PLATFORMS = [ Platform.ALARM_CONTROL_PANEL, Platform.BINARY_SENSOR, diff --git a/tests/components/mqtt/test_config_flow.py b/tests/components/mqtt/test_config_flow.py index 5d67b34db5d..631f373316b 100644 --- a/tests/components/mqtt/test_config_flow.py +++ b/tests/components/mqtt/test_config_flow.py @@ -188,15 +188,12 @@ async def test_manual_config_set( # Check we tried the connection, with precedence for config entry settings mock_try_connection.assert_called_once_with( { - "broker": "bla", + "broker": "127.0.0.1", + "protocol": "3.1.1", "keepalive": 60, "discovery_prefix": "homeassistant", - "protocol": "3.1.1", + "port": 1883, }, - "127.0.0.1", - 1883, - None, - None, ) # Check config entry got setup assert len(mock_finish_setup.mock_calls) == 1 @@ -291,6 +288,44 @@ async def test_hassio_confirm(hass, mock_try_connection_success, mock_finish_set assert len(mock_finish_setup.mock_calls) == 1 +async def test_hassio_cannot_connect( + hass, mock_try_connection_time_out, mock_finish_setup +): + """Test a config flow is aborted when a connection was not successful.""" + mock_try_connection.return_value = True + + result = await hass.config_entries.flow.async_init( + "mqtt", + data=HassioServiceInfo( + config={ + "addon": "Mock Addon", + "host": "mock-broker", + "port": 1883, + "username": "mock-user", + "password": "mock-pass", + "protocol": "3.1.1", # Set by the addon's discovery, ignored by HA + "ssl": False, # Set by the addon's discovery, ignored by HA + } + ), + context={"source": config_entries.SOURCE_HASSIO}, + ) + assert result["type"] == "form" + assert result["step_id"] == "hassio_confirm" + assert result["description_placeholders"] == {"addon": "Mock Addon"} + + mock_try_connection_time_out.reset_mock() + result = await hass.config_entries.flow.async_configure( + result["flow_id"], {"discovery": True} + ) + + assert result["type"] == "form" + assert result["errors"]["base"] == "cannot_connect" + # Check we tried the connection + assert len(mock_try_connection_time_out.mock_calls) + # Check config entry got setup + assert len(mock_finish_setup.mock_calls) == 0 + + @patch( "homeassistant.config.async_hass_config_yaml", AsyncMock(return_value={}), @@ -299,7 +334,7 @@ async def test_option_flow( hass, mqtt_mock_entry_no_yaml_config, mock_try_connection, - mock_reload_after_entry_update, + caplog, ): """Test config flow options.""" mqtt_mock = await mqtt_mock_entry_no_yaml_config() @@ -372,7 +407,10 @@ async def test_option_flow( await hass.async_block_till_done() assert config_entry.title == "another-broker" # assert that the entry was reloaded with the new config - assert mock_reload_after_entry_update.call_count == 1 + assert ( + "" + in caplog.text + ) async def test_disable_birth_will(