"""Config flow for MQTT.""" from __future__ import annotations import asyncio from collections import OrderedDict from collections.abc import Callable, Mapping from copy import deepcopy from dataclasses import dataclass from enum import IntEnum import json import logging import queue from ssl import PROTOCOL_TLS_CLIENT, SSLContext, SSLError from types import MappingProxyType from typing import TYPE_CHECKING, Any, cast from uuid import uuid4 from cryptography.hazmat.primitives.serialization import ( Encoding, NoEncryption, PrivateFormat, load_der_private_key, load_pem_private_key, ) from cryptography.x509 import load_der_x509_certificate, load_pem_x509_certificate import voluptuous as vol import yaml from homeassistant.components.file_upload import process_uploaded_file from homeassistant.components.hassio import AddonError, AddonManager, AddonState from homeassistant.components.sensor import ( CONF_STATE_CLASS, DEVICE_CLASS_UNITS, SensorDeviceClass, SensorStateClass, ) from homeassistant.components.switch import SwitchDeviceClass from homeassistant.config_entries import ( SOURCE_RECONFIGURE, ConfigEntry, ConfigFlow, ConfigFlowResult, ConfigSubentryFlow, OptionsFlow, SubentryFlowResult, ) from homeassistant.const import ( ATTR_CONFIGURATION_URL, ATTR_HW_VERSION, ATTR_MODEL, ATTR_MODEL_ID, ATTR_NAME, ATTR_SW_VERSION, CONF_CLIENT_ID, CONF_DEVICE, CONF_DEVICE_CLASS, CONF_DISCOVERY, CONF_HOST, CONF_NAME, CONF_OPTIMISTIC, CONF_PASSWORD, CONF_PAYLOAD, CONF_PLATFORM, CONF_PORT, CONF_PROTOCOL, CONF_UNIQUE_ID, CONF_UNIT_OF_MEASUREMENT, CONF_USERNAME, CONF_VALUE_TEMPLATE, ) from homeassistant.core import HomeAssistant, callback from homeassistant.data_entry_flow import AbortFlow, SectionConfig, section from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.hassio import is_hassio from homeassistant.helpers.json import json_dumps from homeassistant.helpers.selector import ( BooleanSelector, FileSelector, FileSelectorConfig, NumberSelector, NumberSelectorConfig, NumberSelectorMode, SelectOptionDict, Selector, SelectSelector, SelectSelectorConfig, SelectSelectorMode, TemplateSelector, TemplateSelectorConfig, TextSelector, TextSelectorConfig, TextSelectorType, ) from homeassistant.helpers.service_info.hassio import HassioServiceInfo from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads from .addon import get_addon_manager from .client import MqttClientSetup from .const import ( ATTR_PAYLOAD, ATTR_QOS, ATTR_RETAIN, ATTR_TOPIC, CONF_AVAILABILITY_TEMPLATE, CONF_AVAILABILITY_TOPIC, CONF_BIRTH_MESSAGE, CONF_BROKER, CONF_CERTIFICATE, CONF_CLIENT_CERT, CONF_CLIENT_KEY, CONF_COMMAND_TEMPLATE, CONF_COMMAND_TOPIC, CONF_DISCOVERY_PREFIX, CONF_ENTITY_PICTURE, CONF_EXPIRE_AFTER, CONF_KEEPALIVE, CONF_LAST_RESET_VALUE_TEMPLATE, CONF_OPTIONS, CONF_PAYLOAD_AVAILABLE, CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, CONF_SUGGESTED_DISPLAY_PRECISION, CONF_TLS_INSECURE, CONF_TRANSPORT, CONF_WILL_MESSAGE, CONF_WS_HEADERS, CONF_WS_PATH, CONFIG_ENTRY_MINOR_VERSION, CONFIG_ENTRY_VERSION, DEFAULT_BIRTH, DEFAULT_DISCOVERY, DEFAULT_ENCODING, DEFAULT_KEEPALIVE, DEFAULT_PAYLOAD_AVAILABLE, DEFAULT_PAYLOAD_NOT_AVAILABLE, DEFAULT_PORT, DEFAULT_PREFIX, DEFAULT_PROTOCOL, DEFAULT_QOS, DEFAULT_TRANSPORT, DEFAULT_WILL, DEFAULT_WS_PATH, DOMAIN, SUPPORTED_PROTOCOLS, TRANSPORT_TCP, TRANSPORT_WEBSOCKETS, Platform, ) from .models import MqttAvailabilityData, MqttDeviceData, MqttSubentryData from .util import ( async_create_certificate_temp_files, get_file_path, learn_more_url, valid_birth_will, valid_publish_topic, valid_subscribe_topic, valid_subscribe_topic_template, ) _LOGGER = logging.getLogger(__name__) ADDON_SETUP_TIMEOUT = 5 ADDON_SETUP_TIMEOUT_ROUNDS = 5 CONF_CLIENT_KEY_PASSWORD = "client_key_password" MQTT_TIMEOUT = 5 ADVANCED_OPTIONS = "advanced_options" SET_CA_CERT = "set_ca_cert" SET_CLIENT_CERT = "set_client_cert" BOOLEAN_SELECTOR = BooleanSelector() TEXT_SELECTOR = TextSelector(TextSelectorConfig(type=TextSelectorType.TEXT)) URL_SELECTOR = TextSelector(TextSelectorConfig(type=TextSelectorType.URL)) PUBLISH_TOPIC_SELECTOR = TextSelector(TextSelectorConfig(type=TextSelectorType.TEXT)) PORT_SELECTOR = vol.All( NumberSelector(NumberSelectorConfig(mode=NumberSelectorMode.BOX, min=1, max=65535)), vol.Coerce(int), ) PASSWORD_SELECTOR = TextSelector(TextSelectorConfig(type=TextSelectorType.PASSWORD)) QOS_SELECTOR = NumberSelector( NumberSelectorConfig(mode=NumberSelectorMode.BOX, min=0, max=2) ) KEEPALIVE_SELECTOR = vol.All( NumberSelector( NumberSelectorConfig( mode=NumberSelectorMode.BOX, min=15, step="any", unit_of_measurement="sec" ) ), vol.Coerce(int), ) PROTOCOL_SELECTOR = SelectSelector( SelectSelectorConfig( options=SUPPORTED_PROTOCOLS, mode=SelectSelectorMode.DROPDOWN, ) ) SUPPORTED_TRANSPORTS = [ SelectOptionDict(value=TRANSPORT_TCP, label="TCP"), SelectOptionDict(value=TRANSPORT_WEBSOCKETS, label="WebSocket"), ] TRANSPORT_SELECTOR = SelectSelector( SelectSelectorConfig( options=SUPPORTED_TRANSPORTS, mode=SelectSelectorMode.DROPDOWN, ) ) WS_HEADERS_SELECTOR = TextSelector( TextSelectorConfig(type=TextSelectorType.TEXT, multiline=True) ) CA_VERIFICATION_MODES = [ "off", "auto", "custom", ] BROKER_VERIFICATION_SELECTOR = SelectSelector( SelectSelectorConfig( options=CA_VERIFICATION_MODES, mode=SelectSelectorMode.DROPDOWN, translation_key=SET_CA_CERT, ) ) # mime configuration from https://pki-tutorial.readthedocs.io/en/latest/mime.html CA_CERT_UPLOAD_SELECTOR = FileSelector( FileSelectorConfig(accept=".pem,.crt,.cer,.der,application/x-x509-ca-cert") ) CERT_UPLOAD_SELECTOR = FileSelector( FileSelectorConfig(accept=".pem,.crt,.cer,.der,application/x-x509-user-cert") ) KEY_UPLOAD_SELECTOR = FileSelector( FileSelectorConfig(accept=".pem,.key,.der,.pk8,application/pkcs8") ) # Subentry selectors SUBENTRY_PLATFORMS = [Platform.NOTIFY, Platform.SENSOR, Platform.SWITCH] SUBENTRY_PLATFORM_SELECTOR = SelectSelector( SelectSelectorConfig( options=[platform.value for platform in SUBENTRY_PLATFORMS], mode=SelectSelectorMode.DROPDOWN, translation_key=CONF_PLATFORM, ) ) TEMPLATE_SELECTOR = TemplateSelector(TemplateSelectorConfig()) SUBENTRY_AVAILABILITY_SCHEMA = vol.Schema( { vol.Optional(CONF_AVAILABILITY_TOPIC): TEXT_SELECTOR, vol.Optional(CONF_AVAILABILITY_TEMPLATE): TEMPLATE_SELECTOR, vol.Optional( CONF_PAYLOAD_AVAILABLE, default=DEFAULT_PAYLOAD_AVAILABLE ): TEXT_SELECTOR, vol.Optional( CONF_PAYLOAD_NOT_AVAILABLE, default=DEFAULT_PAYLOAD_NOT_AVAILABLE ): TEXT_SELECTOR, } ) # Sensor specific selectors SENSOR_DEVICE_CLASS_SELECTOR = SelectSelector( SelectSelectorConfig( options=[device_class.value for device_class in SensorDeviceClass], mode=SelectSelectorMode.DROPDOWN, translation_key="device_class_sensor", sort=True, ) ) SENSOR_STATE_CLASS_SELECTOR = SelectSelector( SelectSelectorConfig( options=[device_class.value for device_class in SensorStateClass], mode=SelectSelectorMode.DROPDOWN, translation_key=CONF_STATE_CLASS, ) ) OPTIONS_SELECTOR = SelectSelector( SelectSelectorConfig( options=[], custom_value=True, multiple=True, ) ) SUGGESTED_DISPLAY_PRECISION_SELECTOR = NumberSelector( NumberSelectorConfig(mode=NumberSelectorMode.BOX, min=0, max=9) ) EXPIRE_AFTER_SELECTOR = NumberSelector( NumberSelectorConfig(mode=NumberSelectorMode.BOX, min=0) ) # Switch specific selectors SWITCH_DEVICE_CLASS_SELECTOR = SelectSelector( SelectSelectorConfig( options=[device_class.value for device_class in SwitchDeviceClass], mode=SelectSelectorMode.DROPDOWN, translation_key="device_class_switch", ) ) @callback def validate_sensor_platform_config( config: dict[str, Any], ) -> dict[str, str]: """Validate the sensor options, state and device class config.""" errors: dict[str, str] = {} # Only allow `options` to be set for `enum` sensors # to limit the possible sensor values if config.get(CONF_OPTIONS) is not None: if config.get(CONF_STATE_CLASS) or config.get(CONF_UNIT_OF_MEASUREMENT): errors[CONF_OPTIONS] = "options_not_allowed_with_state_class_or_uom" if (device_class := config.get(CONF_DEVICE_CLASS)) != SensorDeviceClass.ENUM: errors[CONF_DEVICE_CLASS] = "options_device_class_enum" if ( (device_class := config.get(CONF_DEVICE_CLASS)) == SensorDeviceClass.ENUM and errors is not None and CONF_OPTIONS not in config ): errors[CONF_OPTIONS] = "options_with_enum_device_class" if ( device_class in DEVICE_CLASS_UNITS and (unit_of_measurement := config.get(CONF_UNIT_OF_MEASUREMENT)) is None and errors is not None ): # Do not allow an empty unit of measurement in a subentry data flow errors[CONF_UNIT_OF_MEASUREMENT] = "uom_required_for_device_class" return errors if ( device_class is not None and device_class in DEVICE_CLASS_UNITS and unit_of_measurement not in DEVICE_CLASS_UNITS[device_class] ): errors[CONF_UNIT_OF_MEASUREMENT] = "invalid_uom" return errors @dataclass(frozen=True, kw_only=True) class PlatformField: """Stores a platform config field schema, required flag and validator.""" selector: Selector[Any] | Callable[..., Selector[Any]] required: bool validator: Callable[..., Any] error: str | None = None default: str | int | vol.Undefined = vol.UNDEFINED exclude_from_reconfig: bool = False conditions: tuple[dict[str, Any], ...] | None = None custom_filtering: bool = False section: str | None = None @callback def unit_of_measurement_selector(user_data: dict[str, Any | None]) -> Selector: """Return a context based unit of measurement selector.""" if ( user_data is None or (device_class := user_data.get(CONF_DEVICE_CLASS)) is None or device_class not in DEVICE_CLASS_UNITS ): return TEXT_SELECTOR return SelectSelector( SelectSelectorConfig( options=[str(uom) for uom in DEVICE_CLASS_UNITS[device_class]], sort=True, custom_value=True, ) ) COMMON_ENTITY_FIELDS = { CONF_PLATFORM: PlatformField( selector=SUBENTRY_PLATFORM_SELECTOR, required=True, validator=str, exclude_from_reconfig=True, ), CONF_NAME: PlatformField( selector=TEXT_SELECTOR, required=False, validator=str, exclude_from_reconfig=True, ), CONF_ENTITY_PICTURE: PlatformField( selector=TEXT_SELECTOR, required=False, validator=cv.url, error="invalid_url" ), } PLATFORM_ENTITY_FIELDS = { Platform.NOTIFY.value: {}, Platform.SENSOR.value: { CONF_DEVICE_CLASS: PlatformField( selector=SENSOR_DEVICE_CLASS_SELECTOR, required=False, validator=str ), CONF_STATE_CLASS: PlatformField( selector=SENSOR_STATE_CLASS_SELECTOR, required=False, validator=str ), CONF_UNIT_OF_MEASUREMENT: PlatformField( selector=unit_of_measurement_selector, required=False, validator=str, custom_filtering=True, ), CONF_SUGGESTED_DISPLAY_PRECISION: PlatformField( selector=SUGGESTED_DISPLAY_PRECISION_SELECTOR, required=False, validator=cv.positive_int, section="advanced_settings", ), CONF_OPTIONS: PlatformField( selector=OPTIONS_SELECTOR, required=False, validator=cv.ensure_list, conditions=({"device_class": "enum"},), ), }, Platform.SWITCH.value: { CONF_DEVICE_CLASS: PlatformField( selector=SWITCH_DEVICE_CLASS_SELECTOR, required=False, validator=str ), }, } PLATFORM_MQTT_FIELDS = { Platform.NOTIFY.value: { CONF_COMMAND_TOPIC: PlatformField( selector=TEXT_SELECTOR, required=True, validator=valid_publish_topic, error="invalid_publish_topic", ), CONF_COMMAND_TEMPLATE: PlatformField( selector=TEMPLATE_SELECTOR, required=False, validator=cv.template, error="invalid_template", ), CONF_RETAIN: PlatformField( selector=BOOLEAN_SELECTOR, required=False, validator=bool ), }, Platform.SENSOR.value: { CONF_STATE_TOPIC: PlatformField( selector=TEXT_SELECTOR, required=True, validator=valid_subscribe_topic, error="invalid_subscribe_topic", ), CONF_VALUE_TEMPLATE: PlatformField( selector=TEMPLATE_SELECTOR, required=False, validator=cv.template, error="invalid_template", ), CONF_LAST_RESET_VALUE_TEMPLATE: PlatformField( selector=TEMPLATE_SELECTOR, required=False, validator=cv.template, error="invalid_template", conditions=({CONF_STATE_CLASS: "total"},), ), CONF_EXPIRE_AFTER: PlatformField( selector=EXPIRE_AFTER_SELECTOR, required=False, validator=cv.positive_int, section="advanced_settings", ), }, Platform.SWITCH.value: { CONF_COMMAND_TOPIC: PlatformField( selector=TEXT_SELECTOR, required=True, validator=valid_publish_topic, error="invalid_publish_topic", ), CONF_COMMAND_TEMPLATE: PlatformField( selector=TEMPLATE_SELECTOR, required=False, validator=cv.template, error="invalid_template", ), CONF_STATE_TOPIC: PlatformField( selector=TEXT_SELECTOR, required=False, validator=valid_subscribe_topic, error="invalid_subscribe_topic", ), CONF_VALUE_TEMPLATE: PlatformField( selector=TEMPLATE_SELECTOR, required=False, validator=cv.template, error="invalid_template", ), CONF_RETAIN: PlatformField( selector=BOOLEAN_SELECTOR, required=False, validator=bool ), CONF_OPTIMISTIC: PlatformField( selector=BOOLEAN_SELECTOR, required=False, validator=bool ), }, } ENTITY_CONFIG_VALIDATOR: dict[ str, Callable[[dict[str, Any]], dict[str, str]] | None, ] = { Platform.NOTIFY.value: None, Platform.SENSOR.value: validate_sensor_platform_config, Platform.SWITCH.value: None, } MQTT_DEVICE_PLATFORM_FIELDS = { ATTR_NAME: PlatformField(selector=TEXT_SELECTOR, required=False, validator=str), ATTR_SW_VERSION: PlatformField( selector=TEXT_SELECTOR, required=False, validator=str ), ATTR_HW_VERSION: PlatformField( selector=TEXT_SELECTOR, required=False, validator=str ), ATTR_MODEL: PlatformField(selector=TEXT_SELECTOR, required=False, validator=str), ATTR_MODEL_ID: PlatformField(selector=TEXT_SELECTOR, required=False, validator=str), ATTR_CONFIGURATION_URL: PlatformField( selector=TEXT_SELECTOR, required=False, validator=cv.url, error="invalid_url" ), CONF_QOS: PlatformField( selector=QOS_SELECTOR, required=False, validator=int, default=DEFAULT_QOS, section="mqtt_settings", ), } REAUTH_SCHEMA = vol.Schema( { vol.Required(CONF_USERNAME): TEXT_SELECTOR, vol.Required(CONF_PASSWORD): PASSWORD_SELECTOR, } ) PWD_NOT_CHANGED = "__**password_not_changed**__" @callback def update_password_from_user_input( entry_password: str | None, user_input: dict[str, Any] ) -> dict[str, Any]: """Update the password if the entry has been updated. As we want to avoid reflecting the stored password in the UI, we replace the suggested value in the UI with a sentitel, and we change it back here if it was changed. """ substituted_used_data = dict(user_input) # Take out the password submitted user_password: str | None = substituted_used_data.pop(CONF_PASSWORD, None) # Only add the password if it has changed. # If the sentinel password is submitted, we replace that with our current # password from the config entry data. password_changed = user_password is not None and user_password != PWD_NOT_CHANGED password = user_password if password_changed else entry_password if password is not None: substituted_used_data[CONF_PASSWORD] = password return substituted_used_data @callback def validate_field( field: str, validator: Callable[..., Any], user_input: dict[str, Any] | None, errors: dict[str, str], error: str, ) -> None: """Validate a single field.""" if user_input is None or field not in user_input: return try: validator(user_input[field]) except (ValueError, vol.Invalid): errors[field] = error @callback def _check_conditions( platform_field: PlatformField, component_data: dict[str, Any] | None = None ) -> bool: """Only include field if one of conditions match, or no conditions are set.""" if platform_field.conditions is None or component_data is None: return True return any( all(component_data.get(key) == value for key, value in condition.items()) for condition in platform_field.conditions ) @callback def calculate_merged_config( merged_user_input: dict[str, Any], data_schema_fields: dict[str, PlatformField], component_data: dict[str, Any], ) -> dict[str, Any]: """Calculate merged config.""" base_schema_fields = { key for key, platform_field in data_schema_fields.items() if _check_conditions(platform_field, component_data) } - set(merged_user_input) return { key: value for key, value in component_data.items() if key not in base_schema_fields } | merged_user_input @callback def validate_user_input( user_input: dict[str, Any], data_schema_fields: dict[str, PlatformField], *, component_data: dict[str, Any] | None = None, config_validator: Callable[[dict[str, Any]], dict[str, str]] | None = None, ) -> tuple[dict[str, Any], dict[str, str]]: """Validate user input.""" errors: dict[str, str] = {} # Merge sections merged_user_input: dict[str, Any] = {} for key, value in user_input.items(): if isinstance(value, dict): merged_user_input.update(value) else: merged_user_input[key] = value for field, value in merged_user_input.items(): validator = data_schema_fields[field].validator try: validator(value) except (ValueError, vol.Invalid): errors[field] = data_schema_fields[field].error or "invalid_input" if config_validator is not None: if TYPE_CHECKING: assert component_data is not None errors |= config_validator( calculate_merged_config( merged_user_input, data_schema_fields, component_data ), ) return merged_user_input, errors @callback def data_schema_from_fields( data_schema_fields: dict[str, PlatformField], reconfig: bool, component_data: dict[str, Any] | None = None, user_input: dict[str, Any] | None = None, device_data: MqttDeviceData | None = None, ) -> vol.Schema: """Generate custom data schema from platform fields or device data.""" if device_data is not None: component_data_with_user_input: dict[str, Any] | None = dict(device_data) if TYPE_CHECKING: assert component_data_with_user_input is not None component_data_with_user_input.update( component_data_with_user_input.pop("mqtt_settings", {}) ) else: component_data_with_user_input = deepcopy(component_data) if component_data_with_user_input is not None and user_input is not None: component_data_with_user_input |= user_input sections: dict[str | None, None] = { field_details.section: None for field_details in data_schema_fields.values() } data_schema: dict[Any, Any] = {} all_data_element_options: set[Any] = set() no_reconfig_options: set[Any] = set() for schema_section in sections: data_schema_element = { vol.Required(field_name, default=field_details.default) if field_details.required else vol.Optional( field_name, default=field_details.default ): field_details.selector(component_data_with_user_input) # type: ignore[operator] if field_details.custom_filtering else field_details.selector for field_name, field_details in data_schema_fields.items() if field_details.section == schema_section and (not field_details.exclude_from_reconfig or not reconfig) and _check_conditions(field_details, component_data_with_user_input) } data_element_options = set(data_schema_element) all_data_element_options |= data_element_options no_reconfig_options |= { field_name for field_name, field_details in data_schema_fields.items() if field_details.section == schema_section and field_details.exclude_from_reconfig } if schema_section is None: data_schema.update(data_schema_element) continue collapsed = ( not any( (default := data_schema_fields[str(option)].default) is vol.UNDEFINED or component_data_with_user_input[str(option)] != default for option in data_element_options if option in component_data_with_user_input ) if component_data_with_user_input is not None else True ) data_schema[vol.Optional(schema_section)] = section( vol.Schema(data_schema_element), SectionConfig({"collapsed": collapsed}) ) # Reset all fields from the component_data not in the schema if component_data: filtered_fields = ( set(data_schema_fields) - all_data_element_options - no_reconfig_options ) for field in filtered_fields: if field in component_data: del component_data[field] return vol.Schema(data_schema) class FlowHandler(ConfigFlow, domain=DOMAIN): """Handle a config flow.""" # Can be bumped to version 2.1 with HA Core 2026.1.0 VERSION = CONFIG_ENTRY_VERSION # 1 MINOR_VERSION = CONFIG_ENTRY_MINOR_VERSION # 2 _hassio_discovery: dict[str, Any] | None = None _addon_manager: AddonManager def __init__(self) -> None: """Set up flow instance.""" self.install_task: asyncio.Task | None = None self.start_task: asyncio.Task | None = None @classmethod @callback def async_get_supported_subentry_types( cls, config_entry: ConfigEntry ) -> dict[str, type[ConfigSubentryFlow]]: """Return subentries supported by this handler.""" return {CONF_DEVICE: MQTTSubentryFlowHandler} @staticmethod @callback def async_get_options_flow( config_entry: ConfigEntry, ) -> MQTTOptionsFlowHandler: """Get the options flow for this handler.""" return MQTTOptionsFlowHandler() async def _async_install_addon(self) -> None: """Install the Mosquitto Mqtt broker add-on.""" addon_manager: AddonManager = get_addon_manager(self.hass) await addon_manager.async_schedule_install_addon() async def async_step_install_failed( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Add-on installation failed.""" return self.async_abort( reason="addon_install_failed", description_placeholders={"addon": self._addon_manager.addon_name}, ) async def async_step_install_addon( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Install Mosquitto Broker add-on.""" if self.install_task is None: self.install_task = self.hass.async_create_task(self._async_install_addon()) if not self.install_task.done(): return self.async_show_progress( step_id="install_addon", progress_action="install_addon", progress_task=self.install_task, ) try: await self.install_task except AddonError as err: _LOGGER.error(err) return self.async_show_progress_done(next_step_id="install_failed") finally: self.install_task = None return self.async_show_progress_done(next_step_id="start_addon") async def async_step_start_failed( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Add-on start failed.""" return self.async_abort( reason="addon_start_failed", description_placeholders={"addon": self._addon_manager.addon_name}, ) async def async_step_start_addon( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Start Mosquitto Broker add-on.""" if not self.start_task: self.start_task = self.hass.async_create_task(self._async_start_addon()) if not self.start_task.done(): return self.async_show_progress( step_id="start_addon", progress_action="start_addon", progress_task=self.start_task, ) try: await self.start_task except AddonError as err: _LOGGER.error(err) return self.async_show_progress_done(next_step_id="start_failed") finally: self.start_task = None return self.async_show_progress_done(next_step_id="setup_entry_from_discovery") async def _async_get_config_and_try(self) -> dict[str, Any] | None: """Get the MQTT add-on discovery info and try the connection.""" if self._hassio_discovery is not None: return self._hassio_discovery addon_manager: AddonManager = get_addon_manager(self.hass) try: addon_discovery_config = ( await addon_manager.async_get_addon_discovery_info() ) config: dict[str, Any] = { CONF_BROKER: addon_discovery_config[CONF_HOST], CONF_PORT: addon_discovery_config[CONF_PORT], CONF_USERNAME: addon_discovery_config.get(CONF_USERNAME), CONF_PASSWORD: addon_discovery_config.get(CONF_PASSWORD), CONF_DISCOVERY: DEFAULT_DISCOVERY, } except AddonError: # We do not have discovery information yet return None if await self.hass.async_add_executor_job( try_connection, config, ): self._hassio_discovery = config return config return None async def _async_start_addon(self) -> None: """Start the Mosquitto Broker add-on.""" addon_manager: AddonManager = get_addon_manager(self.hass) await addon_manager.async_schedule_start_addon() # Sleep some seconds to let the add-on start properly before connecting. for _ in range(ADDON_SETUP_TIMEOUT_ROUNDS): await asyncio.sleep(ADDON_SETUP_TIMEOUT) # Finish setup using discovery info to test the connection if await self._async_get_config_and_try(): break else: raise AddonError( translation_domain=DOMAIN, translation_key="addon_start_failed", translation_placeholders={"addon": addon_manager.addon_name}, ) async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Handle a flow initialized by the user.""" if is_hassio(self.hass): # Offer to set up broker add-on if supervisor is available self._addon_manager = get_addon_manager(self.hass) return self.async_show_menu( step_id="user", menu_options=["addon", "broker"], description_placeholders={"addon": self._addon_manager.addon_name}, ) # Start up a flow for manual setup return await self.async_step_broker() async def async_step_setup_entry_from_discovery( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Set up mqtt entry from discovery info.""" if (config := await self._async_get_config_and_try()) is not None: return self.async_create_entry( title=self._addon_manager.addon_name, data=config, ) raise AbortFlow( "addon_connection_failed", description_placeholders={"addon": self._addon_manager.addon_name}, ) async def async_step_addon( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Install and start MQTT broker add-on.""" addon_manager = self._addon_manager try: addon_info = await addon_manager.async_get_addon_info() except AddonError as err: raise AbortFlow( "addon_info_failed", description_placeholders={"addon": self._addon_manager.addon_name}, ) from err if addon_info.state == AddonState.RUNNING: # Finish setup using discovery info return await self.async_step_setup_entry_from_discovery() if addon_info.state == AddonState.NOT_RUNNING: return await self.async_step_start_addon() # Install the add-on and start it return await self.async_step_install_addon() async def async_step_reauth( self, entry_data: Mapping[str, Any] ) -> ConfigFlowResult: """Handle re-authentication with MQTT broker.""" if is_hassio(self.hass): # Check if entry setup matches the add-on discovery config addon_manager = get_addon_manager(self.hass) try: addon_discovery_config = ( await addon_manager.async_get_addon_discovery_info() ) except AddonError: # Follow manual flow if we have an error pass else: # Check if the addon secrets need to be renewed. # This will repair the config entry, # in case the official Mosquitto Broker addon was re-installed. if ( entry_data[CONF_BROKER] == addon_discovery_config[CONF_HOST] and entry_data[CONF_PORT] == addon_discovery_config[CONF_PORT] and entry_data.get(CONF_USERNAME) == (username := addon_discovery_config.get(CONF_USERNAME)) and entry_data.get(CONF_PASSWORD) != (password := addon_discovery_config.get(CONF_PASSWORD)) ): _LOGGER.info( "Executing autorecovery %s add-on secrets", addon_manager.addon_name, ) return await self.async_step_reauth_confirm( user_input={CONF_USERNAME: username, CONF_PASSWORD: password} ) return await self.async_step_reauth_confirm() async def async_step_reauth_confirm( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Confirm re-authentication with MQTT broker.""" errors: dict[str, str] = {} reauth_entry = self._get_reauth_entry() if user_input: substituted_used_data = update_password_from_user_input( reauth_entry.data.get(CONF_PASSWORD), user_input ) new_entry_data = {**reauth_entry.data, **substituted_used_data} if await self.hass.async_add_executor_job( try_connection, new_entry_data, ): return self.async_update_reload_and_abort( reauth_entry, data=new_entry_data ) errors["base"] = "invalid_auth" schema = self.add_suggested_values_to_schema( REAUTH_SCHEMA, { CONF_USERNAME: reauth_entry.data.get(CONF_USERNAME), CONF_PASSWORD: PWD_NOT_CHANGED, }, ) return self.async_show_form( step_id="reauth_confirm", data_schema=schema, errors=errors, ) async def async_step_broker( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Confirm the setup.""" errors: dict[str, str] = {} fields: OrderedDict[Any, Any] = OrderedDict() validated_user_input: dict[str, Any] = {} if is_reconfigure := (self.source == SOURCE_RECONFIGURE): reconfigure_entry = self._get_reconfigure_entry() if await async_get_broker_settings( self, fields, reconfigure_entry.data if is_reconfigure else None, user_input, validated_user_input, errors, ): if is_reconfigure: validated_user_input = update_password_from_user_input( reconfigure_entry.data.get(CONF_PASSWORD), validated_user_input ) can_connect = await self.hass.async_add_executor_job( try_connection, validated_user_input, ) if can_connect: if is_reconfigure: return self.async_update_reload_and_abort( reconfigure_entry, data=validated_user_input, ) return self.async_create_entry( title=validated_user_input[CONF_BROKER], data=validated_user_input, ) errors["base"] = "cannot_connect" return self.async_show_form( step_id="broker", data_schema=vol.Schema(fields), errors=errors ) async def async_step_reconfigure( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Handle a reconfiguration flow initialized by the user.""" return await self.async_step_broker() async def async_step_hassio( self, discovery_info: HassioServiceInfo ) -> ConfigFlowResult: """Receive a Hass.io discovery or process setup after addon install.""" await self._async_handle_discovery_without_unique_id() self._hassio_discovery = discovery_info.config return await self.async_step_hassio_confirm() async def async_step_hassio_confirm( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Confirm a Hass.io discovery.""" errors: dict[str, str] = {} if TYPE_CHECKING: assert self._hassio_discovery if user_input is not None: data: dict[str, Any] = self._hassio_discovery.copy() data[CONF_BROKER] = data.pop(CONF_HOST) can_connect = await self.hass.async_add_executor_job( try_connection, data, ) if can_connect: return self.async_create_entry( title=data["addon"], data={ CONF_BROKER: data[CONF_BROKER], CONF_PORT: data[CONF_PORT], CONF_USERNAME: data.get(CONF_USERNAME), CONF_PASSWORD: data.get(CONF_PASSWORD), CONF_DISCOVERY: DEFAULT_DISCOVERY, }, ) errors["base"] = "cannot_connect" return self.async_show_form( step_id="hassio_confirm", description_placeholders={"addon": self._hassio_discovery["addon"]}, errors=errors, ) class MQTTOptionsFlowHandler(OptionsFlow): """Handle MQTT options.""" async def async_step_init(self, user_input: None = None) -> ConfigFlowResult: """Manage the MQTT options.""" return await self.async_step_options() async def async_step_options( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Manage the MQTT options.""" errors = {} options_config: dict[str, Any] = dict(self.config_entry.options) bad_input: bool = False def _birth_will(birt_or_will: str) -> dict[str, Any]: """Return the user input for birth or will.""" if TYPE_CHECKING: assert user_input return { ATTR_TOPIC: user_input[f"{birt_or_will}_topic"], ATTR_PAYLOAD: user_input.get(f"{birt_or_will}_payload", ""), ATTR_QOS: user_input[f"{birt_or_will}_qos"], ATTR_RETAIN: user_input[f"{birt_or_will}_retain"], } def _validate( field: str, values: dict[str, Any], error_code: str, schema: Callable[[Any], Any], ) -> None: """Validate the user input.""" nonlocal bad_input try: option_values = schema(values) options_config[field] = option_values except vol.Invalid: errors["base"] = error_code bad_input = True if user_input is not None: # validate input options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY] _validate( CONF_DISCOVERY_PREFIX, user_input[CONF_DISCOVERY_PREFIX], "bad_discovery_prefix", valid_publish_topic, ) if "birth_topic" in user_input: _validate( CONF_BIRTH_MESSAGE, _birth_will("birth"), "bad_birth", valid_birth_will, ) if not user_input["birth_enable"]: options_config[CONF_BIRTH_MESSAGE] = {} if "will_topic" in user_input: _validate( CONF_WILL_MESSAGE, _birth_will("will"), "bad_will", valid_birth_will, ) if not user_input["will_enable"]: options_config[CONF_WILL_MESSAGE] = {} if not bad_input: return self.async_create_entry(data=options_config) birth = { **DEFAULT_BIRTH, **options_config.get(CONF_BIRTH_MESSAGE, {}), } will = { **DEFAULT_WILL, **options_config.get(CONF_WILL_MESSAGE, {}), } discovery = options_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY) discovery_prefix = options_config.get(CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX) # build form fields: OrderedDict[vol.Marker, Any] = OrderedDict() fields[vol.Optional(CONF_DISCOVERY, default=discovery)] = BOOLEAN_SELECTOR fields[vol.Optional(CONF_DISCOVERY_PREFIX, default=discovery_prefix)] = ( PUBLISH_TOPIC_SELECTOR ) # Birth message is disabled if CONF_BIRTH_MESSAGE = {} fields[ vol.Optional( "birth_enable", default=CONF_BIRTH_MESSAGE not in options_config or options_config[CONF_BIRTH_MESSAGE] != {}, ) ] = BOOLEAN_SELECTOR fields[ vol.Optional( "birth_topic", description={"suggested_value": birth[ATTR_TOPIC]} ) ] = PUBLISH_TOPIC_SELECTOR fields[ vol.Optional( "birth_payload", description={"suggested_value": birth[CONF_PAYLOAD]} ) ] = TEXT_SELECTOR fields[vol.Optional("birth_qos", default=birth[ATTR_QOS])] = QOS_SELECTOR fields[vol.Optional("birth_retain", default=birth[ATTR_RETAIN])] = ( BOOLEAN_SELECTOR ) # Will message is disabled if CONF_WILL_MESSAGE = {} fields[ vol.Optional( "will_enable", default=CONF_WILL_MESSAGE not in options_config or options_config[CONF_WILL_MESSAGE] != {}, ) ] = BOOLEAN_SELECTOR fields[ vol.Optional( "will_topic", description={"suggested_value": will[ATTR_TOPIC]} ) ] = PUBLISH_TOPIC_SELECTOR fields[ vol.Optional( "will_payload", description={"suggested_value": will[CONF_PAYLOAD]} ) ] = TEXT_SELECTOR fields[vol.Optional("will_qos", default=will[ATTR_QOS])] = QOS_SELECTOR fields[vol.Optional("will_retain", default=will[ATTR_RETAIN])] = ( BOOLEAN_SELECTOR ) return self.async_show_form( step_id="options", data_schema=vol.Schema(fields), errors=errors, last_step=True, ) class MQTTSubentryFlowHandler(ConfigSubentryFlow): """Handle MQTT subentry flow.""" _subentry_data: MqttSubentryData _component_id: str | None = None @callback def update_component_fields( self, data_schema_fields: dict[str, PlatformField], merged_user_input: dict[str, Any], ) -> None: """Update the componment fields.""" if TYPE_CHECKING: assert self._component_id is not None component_data = self._subentry_data["components"][self._component_id] # Remove the fields from the component data # if they are not in the schema and not in the user input config = calculate_merged_config( merged_user_input, data_schema_fields, component_data ) for field in ( field for field, platform_field in data_schema_fields.items() if field in (set(component_data) - set(config)) and not platform_field.exclude_from_reconfig ): component_data.pop(field) component_data.update(merged_user_input) @callback def generate_names(self) -> tuple[str, str]: """Generate the device and full entity name.""" if TYPE_CHECKING: assert self._component_id is not None device_name = self._subentry_data[CONF_DEVICE][CONF_NAME] if entity_name := self._subentry_data["components"][self._component_id].get( CONF_NAME ): full_entity_name: str = f"{device_name} {entity_name}" else: full_entity_name = device_name return device_name, full_entity_name @callback def get_suggested_values_from_component( self, data_schema: vol.Schema ) -> dict[str, Any]: """Get suggestions from component data based on the data schema.""" if TYPE_CHECKING: assert self._component_id is not None component_data = self._subentry_data["components"][self._component_id] return { field_key: self.get_suggested_values_from_component(value.schema) if isinstance(value, section) else component_data.get(field_key) for field_key, value in data_schema.schema.items() } async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Add a subentry.""" self._subentry_data = MqttSubentryData(device=MqttDeviceData(), components={}) return await self.async_step_device() async def async_step_reconfigure( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Reconfigure a subentry.""" reconfigure_subentry = self._get_reconfigure_subentry() self._subentry_data = cast( MqttSubentryData, deepcopy(dict(reconfigure_subentry.data)) ) return await self.async_step_summary_menu() async def async_step_device( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Add a new MQTT device.""" errors: dict[str, Any] = {} device_data = self._subentry_data[CONF_DEVICE] data_schema = data_schema_from_fields( MQTT_DEVICE_PLATFORM_FIELDS, device_data=device_data, reconfig=True, ) if user_input is not None: _, errors = validate_user_input(user_input, MQTT_DEVICE_PLATFORM_FIELDS) if not errors: self._subentry_data[CONF_DEVICE] = cast(MqttDeviceData, user_input) if self.source == SOURCE_RECONFIGURE: return await self.async_step_summary_menu() return await self.async_step_entity() data_schema = self.add_suggested_values_to_schema( data_schema, device_data if user_input is None else user_input ) return self.async_show_form( step_id=CONF_DEVICE, data_schema=data_schema, errors=errors, last_step=False, ) async def async_step_entity( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Add or edit an mqtt entity.""" errors: dict[str, str] = {} data_schema_fields = COMMON_ENTITY_FIELDS entity_name_label: str = "" platform_label: str = "" component_data: dict[str, Any] | None = None if reconfig := (self._component_id is not None): component_data = self._subentry_data["components"][self._component_id] name: str | None = component_data.get(CONF_NAME) platform_label = f"{self._subentry_data['components'][self._component_id][CONF_PLATFORM]} " entity_name_label = f" ({name})" if name is not None else "" data_schema = data_schema_from_fields(data_schema_fields, reconfig=reconfig) if user_input is not None: merged_user_input, errors = validate_user_input( user_input, data_schema_fields, component_data=component_data ) if not errors: if self._component_id is None: self._component_id = uuid4().hex self._subentry_data["components"].setdefault(self._component_id, {}) self.update_component_fields(data_schema_fields, merged_user_input) return await self.async_step_entity_platform_config() data_schema = self.add_suggested_values_to_schema(data_schema, user_input) elif self.source == SOURCE_RECONFIGURE and self._component_id is not None: data_schema = self.add_suggested_values_to_schema( data_schema, self.get_suggested_values_from_component(data_schema), ) device_name = self._subentry_data[CONF_DEVICE][CONF_NAME] return self.async_show_form( step_id="entity", data_schema=data_schema, description_placeholders={ "mqtt_device": device_name, "entity_name_label": entity_name_label, "platform_label": platform_label, }, errors=errors, last_step=False, ) def _show_update_or_delete_form(self, step_id: str) -> SubentryFlowResult: """Help selecting an entity to update or delete.""" device_name = self._subentry_data[CONF_DEVICE][CONF_NAME] entities = [ SelectOptionDict( value=key, label=f"{device_name} {component_data.get(CONF_NAME, '-')}" f" ({component_data[CONF_PLATFORM]})", ) for key, component_data in self._subentry_data["components"].items() ] data_schema = vol.Schema( { vol.Required("component"): SelectSelector( SelectSelectorConfig( options=entities, mode=SelectSelectorMode.LIST, ) ) } ) return self.async_show_form( step_id=step_id, data_schema=data_schema, last_step=False ) async def async_step_update_entity( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Select the entity to update.""" if user_input: self._component_id = user_input["component"] return await self.async_step_entity() if len(self._subentry_data["components"]) == 1: # Return first key self._component_id = next(iter(self._subentry_data["components"])) return await self.async_step_entity() return self._show_update_or_delete_form("update_entity") async def async_step_delete_entity( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Select the entity to delete.""" if user_input: del self._subentry_data["components"][user_input["component"]] return await self.async_step_summary_menu() return self._show_update_or_delete_form("delete_entity") async def async_step_entity_platform_config( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Configure platform entity details.""" if TYPE_CHECKING: assert self._component_id is not None component_data = self._subentry_data["components"][self._component_id] platform = component_data[CONF_PLATFORM] data_schema_fields = PLATFORM_ENTITY_FIELDS[platform] errors: dict[str, str] = {} data_schema = data_schema_from_fields( data_schema_fields, reconfig=bool( {field for field in data_schema_fields if field in component_data} ), component_data=component_data, user_input=user_input, ) if not data_schema.schema: return await self.async_step_mqtt_platform_config() if user_input is not None: # Test entity fields against the validator merged_user_input, errors = validate_user_input( user_input, data_schema_fields, component_data=component_data, config_validator=ENTITY_CONFIG_VALIDATOR[platform], ) if not errors: self.update_component_fields(data_schema_fields, merged_user_input) return await self.async_step_mqtt_platform_config() data_schema = self.add_suggested_values_to_schema(data_schema, user_input) else: data_schema = self.add_suggested_values_to_schema( data_schema, self.get_suggested_values_from_component(data_schema), ) device_name, full_entity_name = self.generate_names() return self.async_show_form( step_id="entity_platform_config", data_schema=data_schema, description_placeholders={ "mqtt_device": device_name, CONF_PLATFORM: platform, "entity": full_entity_name, "url": learn_more_url(platform), } | (user_input or {}), errors=errors, last_step=False, ) async def async_step_mqtt_platform_config( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Configure entity platform MQTT details.""" errors: dict[str, str] = {} if TYPE_CHECKING: assert self._component_id is not None component_data = self._subentry_data["components"][self._component_id] platform = component_data[CONF_PLATFORM] data_schema_fields = PLATFORM_MQTT_FIELDS[platform] data_schema = data_schema_from_fields( data_schema_fields, reconfig=bool( {field for field in data_schema_fields if field in component_data} ), component_data=component_data, ) if user_input is not None: # Test entity fields against the validator merged_user_input, errors = validate_user_input( user_input, data_schema_fields, component_data=component_data, config_validator=ENTITY_CONFIG_VALIDATOR[platform], ) if not errors: self.update_component_fields(data_schema_fields, merged_user_input) self._component_id = None if self.source == SOURCE_RECONFIGURE: return await self.async_step_summary_menu() return self._async_create_subentry() data_schema = self.add_suggested_values_to_schema(data_schema, user_input) else: data_schema = self.add_suggested_values_to_schema( data_schema, self.get_suggested_values_from_component(data_schema), ) device_name, full_entity_name = self.generate_names() return self.async_show_form( step_id="mqtt_platform_config", data_schema=data_schema, description_placeholders={ "mqtt_device": device_name, CONF_PLATFORM: platform, "entity": full_entity_name, "url": learn_more_url(platform), }, errors=errors, last_step=False, ) @callback def _async_create_subentry( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Create a subentry for a new MQTT device.""" device_name = self._subentry_data[CONF_DEVICE][CONF_NAME] component_data: dict[str, Any] = next( iter(self._subentry_data["components"].values()) ) platform = component_data[CONF_PLATFORM] entity_name: str | None if entity_name := component_data.get(CONF_NAME): full_entity_name: str = f"{device_name} {entity_name}" else: full_entity_name = device_name return self.async_create_entry( data=self._subentry_data, title=self._subentry_data[CONF_DEVICE][CONF_NAME], description_placeholders={ "entity": full_entity_name, CONF_PLATFORM: platform, }, ) async def async_step_availability( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Configure availability options.""" errors: dict[str, str] = {} validate_field( "availability_topic", valid_subscribe_topic, user_input, errors, "invalid_subscribe_topic", ) validate_field( "availability_template", valid_subscribe_topic_template, user_input, errors, "invalid_template", ) if not errors and user_input is not None: self._subentry_data.setdefault("availability", MqttAvailabilityData()) self._subentry_data["availability"] = cast(MqttAvailabilityData, user_input) return await self.async_step_summary_menu() data_schema = SUBENTRY_AVAILABILITY_SCHEMA data_schema = self.add_suggested_values_to_schema( data_schema, dict(self._subentry_data.setdefault("availability", {})) if self.source == SOURCE_RECONFIGURE else user_input, ) return self.async_show_form( step_id="availability", data_schema=data_schema, errors=errors, last_step=False, ) async def async_step_summary_menu( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Show summary menu and decide to add more entities or to finish the flow.""" self._component_id = None mqtt_device = self._subentry_data[CONF_DEVICE][CONF_NAME] mqtt_items = ", ".join( f"{mqtt_device} {component_data.get(CONF_NAME, '-')} ({component_data[CONF_PLATFORM]})" for component_data in self._subentry_data["components"].values() ) menu_options = [ "entity", "update_entity", ] if len(self._subentry_data["components"]) > 1: menu_options.append("delete_entity") menu_options.extend(["device", "availability"]) menu_options.append( "save_changes" if self._subentry_data != self._get_reconfigure_subentry().data else "export" ) return self.async_show_menu( step_id="summary_menu", menu_options=menu_options, description_placeholders={ "mqtt_device": mqtt_device, "mqtt_items": mqtt_items, }, ) async def async_step_save_changes( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Save the changes made to the subentry.""" entry = self._get_entry() subentry = self._get_reconfigure_subentry() entity_registry = er.async_get(self.hass) # When a component is removed from the MQTT device, # And we save the changes to the subentry, # we need to clean up stale entity registry entries. # The component id is used as a part of the unique id of the entity. for unique_id, platform in [ ( f"{subentry.subentry_id}_{component_id}", subentry.data["components"][component_id][CONF_PLATFORM], ) for component_id in subentry.data["components"] if component_id not in self._subentry_data["components"] ]: if entity_id := entity_registry.async_get_entity_id( platform, DOMAIN, unique_id ): entity_registry.async_remove(entity_id) return self.async_update_and_abort( entry, subentry, data=self._subentry_data, title=self._subentry_data[CONF_DEVICE][CONF_NAME], ) async def async_step_export( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Export the MQTT device config as YAML or discovery payload.""" return self.async_show_menu( step_id="export", menu_options=["export_yaml", "export_discovery"], ) async def async_step_export_yaml( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Export the MQTT device config as YAML.""" if user_input is not None: return await self.async_step_summary_menu() subentry = self._get_reconfigure_subentry() mqtt_yaml_config_base: dict[str, list[dict[str, dict[str, Any]]]] = {DOMAIN: []} mqtt_yaml_config = mqtt_yaml_config_base[DOMAIN] for component_id, component_data in self._subentry_data["components"].items(): component_config: dict[str, Any] = component_data.copy() component_config[CONF_UNIQUE_ID] = f"{subentry.subentry_id}_{component_id}" component_config[CONF_DEVICE] = { key: value for key, value in self._subentry_data["device"].items() if key != "mqtt_settings" } | {"identifiers": [subentry.subentry_id]} platform = component_config.pop(CONF_PLATFORM) component_config.update(self._subentry_data.get("availability", {})) component_config.update( self._subentry_data["device"].get("mqtt_settings", {}).copy() ) mqtt_yaml_config.append({platform: component_config}) yaml_config = yaml.dump(mqtt_yaml_config_base) data_schema = vol.Schema( { vol.Optional("yaml"): TEMPLATE_SELECTOR, } ) data_schema = self.add_suggested_values_to_schema( data_schema=data_schema, suggested_values={"yaml": yaml_config}, ) return self.async_show_form( step_id="export_yaml", last_step=False, data_schema=data_schema, description_placeholders={ "url": "https://www.home-assistant.io/integrations/mqtt/" }, ) async def async_step_export_discovery( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Export the MQTT device config dor MQTT discovery.""" if user_input is not None: return await self.async_step_summary_menu() subentry = self._get_reconfigure_subentry() discovery_topic = f"homeassistant/device/{subentry.subentry_id}/config" discovery_payload: dict[str, Any] = {} discovery_payload.update(self._subentry_data.get("availability", {})) discovery_payload["dev"] = { key: value for key, value in self._subentry_data["device"].items() if key != "mqtt_settings" } | {"identifiers": [subentry.subentry_id]} discovery_payload["o"] = {"name": "MQTT subentry export"} discovery_payload["cmps"] = {} for component_id, component_data in self._subentry_data["components"].items(): component_config: dict[str, Any] = component_data.copy() component_config[CONF_UNIQUE_ID] = f"{subentry.subentry_id}_{component_id}" component_config.update(self._subentry_data.get("availability", {})) component_config.update( self._subentry_data["device"].get("mqtt_settings", {}).copy() ) discovery_payload["cmps"][component_id] = component_config data_schema = vol.Schema( { vol.Optional("discovery_topic"): TEXT_SELECTOR, vol.Optional("discovery_payload"): TEMPLATE_SELECTOR, } ) data_schema = self.add_suggested_values_to_schema( data_schema=data_schema, suggested_values={ "discovery_topic": discovery_topic, "discovery_payload": json.dumps(discovery_payload, indent=2), }, ) return self.async_show_form( step_id="export_discovery", last_step=False, data_schema=data_schema, description_placeholders={ "url": "https://www.home-assistant.io/integrations/mqtt/" }, ) @callback def async_is_pem_data(data: bytes) -> bool: """Return True if data is in PEM format.""" return ( b"-----BEGIN CERTIFICATE-----" in data or b"-----BEGIN PRIVATE KEY-----" in data or b"-----BEGIN RSA PRIVATE KEY-----" in data or b"-----BEGIN ENCRYPTED PRIVATE KEY-----" in data ) class PEMType(IntEnum): """Type of PEM data.""" CERTIFICATE = 1 PRIVATE_KEY = 2 @callback def async_convert_to_pem( data: bytes, pem_type: PEMType, password: str | None = None ) -> str | None: """Convert data to PEM format.""" try: if async_is_pem_data(data): if not password: # Assume unencrypted PEM encoded private key return data.decode(DEFAULT_ENCODING) # Return decrypted PEM encoded private key return ( load_pem_private_key(data, password=password.encode(DEFAULT_ENCODING)) .private_bytes( encoding=Encoding.PEM, format=PrivateFormat.TraditionalOpenSSL, encryption_algorithm=NoEncryption(), ) .decode(DEFAULT_ENCODING) ) # Convert from DER encoding to PEM if pem_type == PEMType.CERTIFICATE: return ( load_der_x509_certificate(data) .public_bytes( encoding=Encoding.PEM, ) .decode(DEFAULT_ENCODING) ) # Assume DER encoded private key pem_key_data: bytes = load_der_private_key( data, password.encode(DEFAULT_ENCODING) if password else None ).private_bytes( encoding=Encoding.PEM, format=PrivateFormat.TraditionalOpenSSL, encryption_algorithm=NoEncryption(), ) return pem_key_data.decode("utf-8") except (TypeError, ValueError, SSLError): _LOGGER.exception("Error converting %s file data to PEM format", pem_type.name) return None async def _get_uploaded_file(hass: HomeAssistant, id: str) -> bytes: """Get file content from uploaded certificate or key file.""" def _proces_uploaded_file() -> bytes: with process_uploaded_file(hass, id) as file_path: return file_path.read_bytes() return await hass.async_add_executor_job(_proces_uploaded_file) def _validate_pki_file( file_id: str | None, pem_data: str | None, errors: dict[str, str], error: str ) -> bool: """Return False if uploaded file could not be converted to PEM format.""" if file_id and not pem_data: errors["base"] = error return False return True async def async_get_broker_settings( # noqa: C901 flow: ConfigFlow | OptionsFlow, fields: OrderedDict[Any, Any], entry_config: MappingProxyType[str, Any] | None, user_input: dict[str, Any] | None, validated_user_input: dict[str, Any], errors: dict[str, str], ) -> bool: """Build the config flow schema to collect the broker settings. Shows advanced options if one or more are configured or when the advanced_broker_options checkbox was selected. Returns True when settings are collected successfully. """ hass = flow.hass advanced_broker_options: bool = False user_input_basic: dict[str, Any] = {} current_config: dict[str, Any] = ( entry_config.copy() if entry_config is not None else {} ) async def _async_validate_broker_settings( config: dict[str, Any], user_input: dict[str, Any], validated_user_input: dict[str, Any], errors: dict[str, str], ) -> bool: """Additional validation on broker settings for better error messages.""" # Get current certificate settings from config entry certificate: str | None = ( "auto" if user_input.get(SET_CA_CERT, "off") == "auto" else config.get(CONF_CERTIFICATE) if user_input.get(SET_CA_CERT, "off") == "custom" else None ) client_certificate: str | None = ( config.get(CONF_CLIENT_CERT) if user_input.get(SET_CLIENT_CERT) else None ) client_key: str | None = ( config.get(CONF_CLIENT_KEY) if user_input.get(SET_CLIENT_CERT) else None ) # Prepare entry update with uploaded files validated_user_input.update(user_input) client_certificate_id: str | None = user_input.get(CONF_CLIENT_CERT) client_key_id: str | None = user_input.get(CONF_CLIENT_KEY) # We do not store the private key password in the entry data client_key_password: str | None = validated_user_input.pop( CONF_CLIENT_KEY_PASSWORD, None ) if (client_certificate_id and not client_key_id) or ( not client_certificate_id and client_key_id ): errors["base"] = "invalid_inclusion" return False certificate_id: str | None = user_input.get(CONF_CERTIFICATE) if certificate_id: certificate_data_raw = await _get_uploaded_file(hass, certificate_id) certificate = async_convert_to_pem( certificate_data_raw, PEMType.CERTIFICATE ) if not _validate_pki_file( certificate_id, certificate, errors, "bad_certificate" ): return False # Return to form for file upload CA cert or client cert and key if ( ( not client_certificate and user_input.get(SET_CLIENT_CERT) and not client_certificate_id ) or ( not certificate and user_input.get(SET_CA_CERT, "off") == "custom" and not certificate_id ) or ( user_input.get(CONF_TRANSPORT) == TRANSPORT_WEBSOCKETS and CONF_WS_PATH not in user_input ) ): return False if client_certificate_id: client_certificate_data = await _get_uploaded_file( hass, client_certificate_id ) client_certificate = async_convert_to_pem( client_certificate_data, PEMType.CERTIFICATE ) if not _validate_pki_file( client_certificate_id, client_certificate, errors, "bad_client_cert" ): return False if client_key_id: client_key_data = await _get_uploaded_file(hass, client_key_id) client_key = async_convert_to_pem( client_key_data, PEMType.PRIVATE_KEY, password=client_key_password ) if not _validate_pki_file( client_key_id, client_key, errors, "client_key_error" ): return False certificate_data: dict[str, Any] = {} if certificate: certificate_data[CONF_CERTIFICATE] = certificate if client_certificate: certificate_data[CONF_CLIENT_CERT] = client_certificate certificate_data[CONF_CLIENT_KEY] = client_key validated_user_input.update(certificate_data) await async_create_certificate_temp_files(hass, certificate_data) if error := await hass.async_add_executor_job( check_certicate_chain, ): errors["base"] = error return False if SET_CA_CERT in validated_user_input: del validated_user_input[SET_CA_CERT] if SET_CLIENT_CERT in validated_user_input: del validated_user_input[SET_CLIENT_CERT] if validated_user_input.get(CONF_TRANSPORT, TRANSPORT_TCP) == TRANSPORT_TCP: if CONF_WS_PATH in validated_user_input: del validated_user_input[CONF_WS_PATH] if CONF_WS_HEADERS in validated_user_input: del validated_user_input[CONF_WS_HEADERS] return True try: validated_user_input[CONF_WS_HEADERS] = json_loads( validated_user_input.get(CONF_WS_HEADERS, "{}") ) schema = vol.Schema({cv.string: cv.template}) schema(validated_user_input[CONF_WS_HEADERS]) except (*JSON_DECODE_EXCEPTIONS, vol.MultipleInvalid): errors["base"] = "bad_ws_headers" return False return True if user_input: user_input_basic = user_input.copy() advanced_broker_options = user_input_basic.get(ADVANCED_OPTIONS, False) if ADVANCED_OPTIONS not in user_input or advanced_broker_options is False: if await _async_validate_broker_settings( current_config, user_input_basic, validated_user_input, errors, ): return True # Get defaults settings from previous post current_broker = user_input_basic.get(CONF_BROKER) current_port = user_input_basic.get(CONF_PORT, DEFAULT_PORT) current_user = user_input_basic.get(CONF_USERNAME) current_pass = user_input_basic.get(CONF_PASSWORD) else: # Get default settings from entry (if any) current_broker = current_config.get(CONF_BROKER) current_port = current_config.get(CONF_PORT, DEFAULT_PORT) current_user = current_config.get(CONF_USERNAME) # Return the sentinel password to avoid exposure current_entry_pass = current_config.get(CONF_PASSWORD) current_pass = PWD_NOT_CHANGED if current_entry_pass else None # Treat the previous post as an update of the current settings # (if there was a basic broker setup step) current_config.update(user_input_basic) # Get default settings for advanced broker options current_client_id = current_config.get(CONF_CLIENT_ID) current_keepalive = current_config.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE) current_ca_certificate = current_config.get(CONF_CERTIFICATE) current_client_certificate = current_config.get(CONF_CLIENT_CERT) current_client_key = current_config.get(CONF_CLIENT_KEY) current_tls_insecure = current_config.get(CONF_TLS_INSECURE, False) current_protocol = current_config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL) current_transport = current_config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT) current_ws_path = current_config.get(CONF_WS_PATH, DEFAULT_WS_PATH) current_ws_headers = ( json_dumps(current_config.get(CONF_WS_HEADERS)) if CONF_WS_HEADERS in current_config else None ) advanced_broker_options |= bool( current_client_id or current_keepalive != DEFAULT_KEEPALIVE or current_ca_certificate or current_client_certificate or current_client_key or current_tls_insecure or current_protocol != DEFAULT_PROTOCOL or current_config.get(SET_CA_CERT, "off") != "off" or current_config.get(SET_CLIENT_CERT) or current_transport == TRANSPORT_WEBSOCKETS ) # Build form fields[vol.Required(CONF_BROKER, default=current_broker)] = TEXT_SELECTOR fields[vol.Required(CONF_PORT, default=current_port)] = PORT_SELECTOR fields[ vol.Optional( CONF_USERNAME, description={"suggested_value": current_user}, ) ] = TEXT_SELECTOR fields[ vol.Optional( CONF_PASSWORD, description={"suggested_value": current_pass}, ) ] = PASSWORD_SELECTOR # show advanced options checkbox if requested and # advanced options are enabled # or when the defaults of advanced options are overridden if not advanced_broker_options: if not flow.show_advanced_options: return False fields[ vol.Optional( ADVANCED_OPTIONS, ) ] = BOOLEAN_SELECTOR return False fields[ vol.Optional( CONF_CLIENT_ID, description={"suggested_value": current_client_id}, ) ] = TEXT_SELECTOR fields[ vol.Optional( CONF_KEEPALIVE, description={"suggested_value": current_keepalive}, ) ] = KEEPALIVE_SELECTOR fields[ vol.Optional( SET_CLIENT_CERT, default=current_client_certificate is not None or current_config.get(SET_CLIENT_CERT) is True, ) ] = BOOLEAN_SELECTOR if ( current_client_certificate is not None or current_config.get(SET_CLIENT_CERT) is True ): fields[ vol.Optional( CONF_CLIENT_CERT, description={"suggested_value": user_input_basic.get(CONF_CLIENT_CERT)}, ) ] = CERT_UPLOAD_SELECTOR fields[ vol.Optional( CONF_CLIENT_KEY, description={"suggested_value": user_input_basic.get(CONF_CLIENT_KEY)}, ) ] = KEY_UPLOAD_SELECTOR fields[ vol.Optional( CONF_CLIENT_KEY_PASSWORD, description={ "suggested_value": user_input_basic.get(CONF_CLIENT_KEY_PASSWORD) }, ) ] = PASSWORD_SELECTOR verification_mode = current_config.get(SET_CA_CERT) or ( "off" if current_ca_certificate is None else "auto" if current_ca_certificate == "auto" else "custom" ) fields[ vol.Optional( SET_CA_CERT, default=verification_mode, ) ] = BROKER_VERIFICATION_SELECTOR if current_ca_certificate is not None or verification_mode == "custom": fields[ vol.Optional( CONF_CERTIFICATE, user_input_basic.get(CONF_CERTIFICATE), ) ] = CA_CERT_UPLOAD_SELECTOR fields[ vol.Optional( CONF_TLS_INSECURE, description={"suggested_value": current_tls_insecure}, ) ] = BOOLEAN_SELECTOR fields[ vol.Optional( CONF_PROTOCOL, description={"suggested_value": current_protocol}, ) ] = PROTOCOL_SELECTOR fields[ vol.Optional( CONF_TRANSPORT, description={"suggested_value": current_transport}, ) ] = TRANSPORT_SELECTOR if current_transport == TRANSPORT_WEBSOCKETS: fields[ vol.Optional(CONF_WS_PATH, description={"suggested_value": current_ws_path}) ] = TEXT_SELECTOR fields[ vol.Optional( CONF_WS_HEADERS, description={"suggested_value": current_ws_headers} ) ] = WS_HEADERS_SELECTOR # Show form return False def try_connection( user_input: dict[str, Any], ) -> bool: """Test if we can connect to an MQTT broker.""" # We don't import on the top because some integrations # should be able to optionally rely on MQTT. import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel mqtt_client_setup = MqttClientSetup(user_input) mqtt_client_setup.setup() client = mqtt_client_setup.client result: queue.Queue[bool] = queue.Queue(maxsize=1) def on_connect( _mqttc: mqtt.Client, _userdata: None, _connect_flags: mqtt.ConnectFlags, reason_code: mqtt.ReasonCode, _properties: mqtt.Properties | None = None, ) -> None: """Handle connection result.""" result.put(not reason_code.is_failure) client.on_connect = on_connect client.connect_async(user_input[CONF_BROKER], user_input[CONF_PORT]) client.loop_start() try: return result.get(timeout=MQTT_TIMEOUT) except queue.Empty: return False finally: client.disconnect() client.loop_stop() def check_certicate_chain() -> str | None: """Check the MQTT certificates.""" if client_certificate := get_file_path(CONF_CLIENT_CERT): try: with open(client_certificate, "rb") as client_certificate_file: load_pem_x509_certificate(client_certificate_file.read()) except ValueError: return "bad_client_cert" # Check we can serialize the private key file if private_key := get_file_path(CONF_CLIENT_KEY): try: with open(private_key, "rb") as client_key_file: load_pem_private_key(client_key_file.read(), password=None) except (TypeError, ValueError): return "client_key_error" # Check the certificate chain context = SSLContext(PROTOCOL_TLS_CLIENT) if client_certificate and private_key: try: context.load_cert_chain(client_certificate, private_key) except SSLError: return "bad_client_cert_key" # try to load the custom CA file if (ca_cert := get_file_path(CONF_CERTIFICATE)) is None: return None try: context.load_verify_locations(ca_cert) except SSLError: return "bad_certificate" return None