Pieter Viljoen d65392a374
ConfigSubEntryFlow _get_reconfigure_entry() -> _get_entry() (#141017)
* ConfigSubEntryFlow _get_reconfigure_entry() -> _get_entry()

* Update MQTT test

* Fix test_config_entries

* Minimize changes to keep existing tests working

* Re-revert and update negative test instead
2025-03-24 09:24:43 +01:00

1682 lines
58 KiB
Python

"""Config flow for MQTT."""
from __future__ import annotations
import asyncio
from collections import OrderedDict
from collections.abc import Callable, Mapping
from copy import deepcopy
from dataclasses import dataclass
from enum import IntEnum
import logging
import queue
from ssl import PROTOCOL_TLS_CLIENT, SSLContext, SSLError
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4
from cryptography.hazmat.primitives.serialization import (
Encoding,
NoEncryption,
PrivateFormat,
load_der_private_key,
load_pem_private_key,
)
from cryptography.x509 import load_der_x509_certificate, load_pem_x509_certificate
import voluptuous as vol
from homeassistant.components.file_upload import process_uploaded_file
from homeassistant.components.hassio import AddonError, AddonManager, AddonState
from homeassistant.config_entries import (
SOURCE_RECONFIGURE,
ConfigEntry,
ConfigFlow,
ConfigFlowResult,
ConfigSubentryFlow,
OptionsFlow,
SubentryFlowResult,
)
from homeassistant.const import (
ATTR_CONFIGURATION_URL,
ATTR_HW_VERSION,
ATTR_MODEL,
ATTR_MODEL_ID,
ATTR_NAME,
ATTR_SW_VERSION,
CONF_CLIENT_ID,
CONF_DEVICE,
CONF_DISCOVERY,
CONF_HOST,
CONF_NAME,
CONF_PASSWORD,
CONF_PAYLOAD,
CONF_PLATFORM,
CONF_PORT,
CONF_PROTOCOL,
CONF_USERNAME,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import AbortFlow
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.hassio import is_hassio
from homeassistant.helpers.json import json_dumps
from homeassistant.helpers.selector import (
BooleanSelector,
FileSelector,
FileSelectorConfig,
NumberSelector,
NumberSelectorConfig,
NumberSelectorMode,
SelectOptionDict,
Selector,
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
TemplateSelector,
TemplateSelectorConfig,
TextSelector,
TextSelectorConfig,
TextSelectorType,
)
from homeassistant.helpers.service_info.hassio import HassioServiceInfo
from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads
from .addon import get_addon_manager
from .client import MqttClientSetup
from .const import (
ATTR_PAYLOAD,
ATTR_QOS,
ATTR_RETAIN,
ATTR_TOPIC,
CONF_AVAILABILITY_TEMPLATE,
CONF_AVAILABILITY_TOPIC,
CONF_BIRTH_MESSAGE,
CONF_BROKER,
CONF_CERTIFICATE,
CONF_CLIENT_CERT,
CONF_CLIENT_KEY,
CONF_COMMAND_TEMPLATE,
CONF_COMMAND_TOPIC,
CONF_DISCOVERY_PREFIX,
CONF_ENTITY_PICTURE,
CONF_KEEPALIVE,
CONF_PAYLOAD_AVAILABLE,
CONF_PAYLOAD_NOT_AVAILABLE,
CONF_QOS,
CONF_RETAIN,
CONF_TLS_INSECURE,
CONF_TRANSPORT,
CONF_WILL_MESSAGE,
CONF_WS_HEADERS,
CONF_WS_PATH,
CONFIG_ENTRY_MINOR_VERSION,
CONFIG_ENTRY_VERSION,
DEFAULT_BIRTH,
DEFAULT_DISCOVERY,
DEFAULT_ENCODING,
DEFAULT_KEEPALIVE,
DEFAULT_PAYLOAD_AVAILABLE,
DEFAULT_PAYLOAD_NOT_AVAILABLE,
DEFAULT_PORT,
DEFAULT_PREFIX,
DEFAULT_PROTOCOL,
DEFAULT_TRANSPORT,
DEFAULT_WILL,
DEFAULT_WS_PATH,
DOMAIN,
SUPPORTED_PROTOCOLS,
TRANSPORT_TCP,
TRANSPORT_WEBSOCKETS,
Platform,
)
from .models import MqttAvailabilityData, MqttDeviceData, MqttSubentryData
from .util import (
async_create_certificate_temp_files,
get_file_path,
valid_birth_will,
valid_publish_topic,
valid_qos_schema,
valid_subscribe_topic,
valid_subscribe_topic_template,
)
_LOGGER = logging.getLogger(__name__)

# Add-on start polling: _async_start_addon sleeps ADDON_SETUP_TIMEOUT seconds
# per round and gives up after ADDON_SETUP_TIMEOUT_ROUNDS rounds.
ADDON_SETUP_TIMEOUT = 5
ADDON_SETUP_TIMEOUT_ROUNDS = 5

CONF_CLIENT_KEY_PASSWORD = "client_key_password"
# NOTE(review): presumably the broker connection-test timeout in seconds used
# by code outside this view (e.g. try_connection) — confirm.
MQTT_TIMEOUT = 5

# Keys for the advanced broker-settings options; SET_CA_CERT is also used as
# the translation key of BROKER_VERIFICATION_SELECTOR below.
ADVANCED_OPTIONS = "advanced_options"
SET_CA_CERT = "set_ca_cert"
SET_CLIENT_CERT = "set_client_cert"

# Shared UI selectors used by the broker, options and subentry forms.
BOOLEAN_SELECTOR = BooleanSelector()
TEXT_SELECTOR = TextSelector(TextSelectorConfig(type=TextSelectorType.TEXT))
PUBLISH_TOPIC_SELECTOR = TextSelector(TextSelectorConfig(type=TextSelectorType.TEXT))
# Port entered in a number box, limited to 1-65535 and coerced to int.
PORT_SELECTOR = vol.All(
    NumberSelector(NumberSelectorConfig(mode=NumberSelectorMode.BOX, min=1, max=65535)),
    vol.Coerce(int),
)
PASSWORD_SELECTOR = TextSelector(TextSelectorConfig(type=TextSelectorType.PASSWORD))
# QoS level entered in a number box, limited to 0-2.
QOS_SELECTOR = NumberSelector(
    NumberSelectorConfig(mode=NumberSelectorMode.BOX, min=0, max=2)
)
# QOS_SELECTOR input additionally checked by valid_qos_schema.
QOS_DATA_SCHEMA = vol.All(QOS_SELECTOR, valid_qos_schema)
# Keepalive in seconds (minimum 15) entered in a number box, coerced to int.
KEEPALIVE_SELECTOR = vol.All(
    NumberSelector(
        NumberSelectorConfig(
            mode=NumberSelectorMode.BOX, min=15, step="any", unit_of_measurement="sec"
        )
    ),
    vol.Coerce(int),
)
PROTOCOL_SELECTOR = SelectSelector(
    SelectSelectorConfig(
        options=SUPPORTED_PROTOCOLS,
        mode=SelectSelectorMode.DROPDOWN,
    )
)
# Transport choices shown in the broker form (value -> display label).
SUPPORTED_TRANSPORTS = [
    SelectOptionDict(value=TRANSPORT_TCP, label="TCP"),
    SelectOptionDict(value=TRANSPORT_WEBSOCKETS, label="WebSocket"),
]
TRANSPORT_SELECTOR = SelectSelector(
    SelectSelectorConfig(
        options=SUPPORTED_TRANSPORTS,
        mode=SelectSelectorMode.DROPDOWN,
    )
)
# Multiline text input for the WebSocket headers.
WS_HEADERS_SELECTOR = TextSelector(
    TextSelectorConfig(type=TextSelectorType.TEXT, multiline=True)
)
# Broker certificate verification modes; option labels are translated under
# the SET_CA_CERT translation key.
CA_VERIFICATION_MODES = [
    "off",
    "auto",
    "custom",
]
BROKER_VERIFICATION_SELECTOR = SelectSelector(
    SelectSelectorConfig(
        options=CA_VERIFICATION_MODES,
        mode=SelectSelectorMode.DROPDOWN,
        translation_key=SET_CA_CERT,
    )
)
# File upload selectors for the CA certificate, client certificate and key.
# mime configuration from https://pki-tutorial.readthedocs.io/en/latest/mime.html
CA_CERT_UPLOAD_SELECTOR = FileSelector(
    FileSelectorConfig(accept=".pem,.crt,.cer,.der,application/x-x509-ca-cert")
)
CERT_UPLOAD_SELECTOR = FileSelector(
    FileSelectorConfig(accept=".pem,.crt,.cer,.der,application/x-x509-user-cert")
)
KEY_UPLOAD_SELECTOR = FileSelector(
    FileSelectorConfig(accept=".pem,.key,.der,.pk8,application/pkcs8")
)

# Subentry selectors
# Platforms that can be chosen for an entity of an MQTT device subentry.
SUBENTRY_PLATFORMS = [Platform.NOTIFY]
SUBENTRY_PLATFORM_SELECTOR = SelectSelector(
    SelectSelectorConfig(
        options=[platform.value for platform in SUBENTRY_PLATFORMS],
        mode=SelectSelectorMode.DROPDOWN,
        translation_key=CONF_PLATFORM,
    )
)
TEMPLATE_SELECTOR = TemplateSelector(TemplateSelectorConfig())

# Form schema of the subentry "availability" step; both payload fields fall
# back to DEFAULT_PAYLOAD_AVAILABLE / DEFAULT_PAYLOAD_NOT_AVAILABLE.
SUBENTRY_AVAILABILITY_SCHEMA = vol.Schema(
    {
        vol.Optional(CONF_AVAILABILITY_TOPIC): TEXT_SELECTOR,
        vol.Optional(CONF_AVAILABILITY_TEMPLATE): TEMPLATE_SELECTOR,
        vol.Optional(
            CONF_PAYLOAD_AVAILABLE, default=DEFAULT_PAYLOAD_AVAILABLE
        ): TEXT_SELECTOR,
        vol.Optional(
            CONF_PAYLOAD_NOT_AVAILABLE, default=DEFAULT_PAYLOAD_NOT_AVAILABLE
        ): TEXT_SELECTOR,
    }
)
@dataclass(frozen=True)
class PlatformField:
    """Stores a platform config field schema, required flag and validator.

    Instances are turned into form schemas by data_schema_from_fields and
    checked against submitted values by validate_user_input.
    """

    # Selector rendered for this field in the form.
    selector: Selector
    # True -> the field becomes vol.Required, False -> vol.Optional.
    required: bool
    # Called with the submitted value; raising ValueError or vol.Invalid marks
    # the field as invalid.
    validator: Callable[..., Any]
    # Error key reported on validation failure; "invalid_input" is used when None.
    error: str | None = None
    # Default passed to the vol.Required / vol.Optional marker.
    default: str | int | vol.Undefined = vol.UNDEFINED
    # When True the field is left out of the form while reconfiguring.
    exclude_from_reconfig: bool = False
# Fields asked in the subentry "entity" step. Platform and name are only
# requested when a component is first added (exclude_from_reconfig=True).
COMMON_ENTITY_FIELDS = {
    CONF_PLATFORM: PlatformField(
        SUBENTRY_PLATFORM_SELECTOR, True, str, exclude_from_reconfig=True
    ),
    CONF_NAME: PlatformField(TEXT_SELECTOR, False, str, exclude_from_reconfig=True),
    CONF_ENTITY_PICTURE: PlatformField(TEXT_SELECTOR, False, cv.url, "invalid_url"),
}

# MQTT fields shared by every platform in the "mqtt_platform_config" step.
COMMON_MQTT_FIELDS = {
    CONF_QOS: PlatformField(QOS_SELECTOR, False, valid_qos_schema, default=0),
    CONF_RETAIN: PlatformField(BOOLEAN_SELECTOR, False, bool),
}

# Platform-specific MQTT fields, keyed by platform value; merged with
# COMMON_MQTT_FIELDS in the "mqtt_platform_config" step.
PLATFORM_MQTT_FIELDS = {
    Platform.NOTIFY.value: {
        CONF_COMMAND_TOPIC: PlatformField(
            TEXT_SELECTOR, True, valid_publish_topic, "invalid_publish_topic"
        ),
        CONF_COMMAND_TEMPLATE: PlatformField(
            TEMPLATE_SELECTOR, False, cv.template, "invalid_template"
        ),
    },
}

# Form schema of the subentry "device" step; only the device name is required.
MQTT_DEVICE_SCHEMA = vol.Schema(
    {
        vol.Required(ATTR_NAME): TEXT_SELECTOR,
        vol.Optional(ATTR_SW_VERSION): TEXT_SELECTOR,
        vol.Optional(ATTR_HW_VERSION): TEXT_SELECTOR,
        vol.Optional(ATTR_MODEL): TEXT_SELECTOR,
        vol.Optional(ATTR_MODEL_ID): TEXT_SELECTOR,
        vol.Optional(ATTR_CONFIGURATION_URL): TEXT_SELECTOR,
    }
)

# Form schema of the "reauth_confirm" step.
REAUTH_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_USERNAME): TEXT_SELECTOR,
        vol.Required(CONF_PASSWORD): PASSWORD_SELECTOR,
    }
)

# Sentinel shown in password fields instead of the stored password; see
# update_password_from_user_input.
PWD_NOT_CHANGED = "__**password_not_changed**__"
@callback
def update_password_from_user_input(
    entry_password: str | None, user_input: dict[str, Any]
) -> dict[str, Any]:
    """Return a copy of ``user_input`` with the effective password restored.

    The stored password is never reflected in the UI; the form shows the
    PWD_NOT_CHANGED sentinel instead. A submitted sentinel (or no password at
    all) falls back to ``entry_password``; any other submitted value wins.
    The password key is appended last, and omitted when neither source has one.
    """
    # Copy everything except the password; it is re-added below if needed.
    result = {
        key: value for key, value in user_input.items() if key != CONF_PASSWORD
    }
    submitted = user_input.get(CONF_PASSWORD)
    if submitted is None or submitted == PWD_NOT_CHANGED:
        effective = entry_password
    else:
        effective = submitted
    if effective is not None:
        result[CONF_PASSWORD] = effective
    return result
@callback
def validate_field(
    field: str,
    validator: Callable[..., Any],
    user_input: dict[str, Any] | None,
    errors: dict[str, str],
    error: str,
) -> None:
    """Store ``error`` under ``field`` in ``errors`` when the value is rejected.

    Does nothing when there is no user input or the field was not submitted.
    The validator rejects a value by raising ValueError or vol.Invalid.
    """
    if user_input is not None and field in user_input:
        try:
            validator(user_input[field])
        except (ValueError, vol.Invalid):
            errors[field] = error
@callback
def validate_user_input(
    user_input: dict[str, Any],
    data_schema_fields: dict[str, PlatformField],
    errors: dict[str, str],
) -> None:
    """Check every submitted value against its PlatformField validator.

    Each rejected field (ValueError or vol.Invalid) is recorded in ``errors``
    with the field's own error key, or "invalid_input" when it has none.
    A submitted key missing from ``data_schema_fields`` raises KeyError.
    """
    for field_name in user_input:
        platform_field = data_schema_fields[field_name]
        try:
            platform_field.validator(user_input[field_name])
        except (ValueError, vol.Invalid):
            errors[field_name] = platform_field.error or "invalid_input"
@callback
def data_schema_from_fields(
    data_schema_fields: dict[str, PlatformField],
    reconfig: bool,
) -> vol.Schema:
    """Build a form schema from PlatformField definitions.

    Required fields become vol.Required markers and the rest vol.Optional, each
    carrying the field's default and mapped to its selector. When ``reconfig``
    is True, fields flagged ``exclude_from_reconfig`` are skipped. Field order
    follows ``data_schema_fields``.
    """
    schema: dict[Any, Any] = {}
    for name, details in data_schema_fields.items():
        if reconfig and details.exclude_from_reconfig:
            continue
        marker_cls = vol.Required if details.required else vol.Optional
        schema[marker_cls(name, default=details.default)] = details.selector
    return vol.Schema(schema)
class FlowHandler(ConfigFlow, domain=DOMAIN):
    """Handle a config flow.

    Entry points: the user step (manual broker setup, or a menu offering the
    broker add-on when a supervisor is available), Hass.io discovery,
    re-authentication and reconfiguration.
    """

    # Can be bumped to version 2.1 with HA Core 2026.1.0
    VERSION = CONFIG_ENTRY_VERSION  # 1
    MINOR_VERSION = CONFIG_ENTRY_MINOR_VERSION  # 2

    # Broker config from a Hass.io discovery (async_step_hassio), or from the
    # add-on discovery info after a successful connection test
    # (_async_get_config_and_try), which then returns it as a cache.
    _hassio_discovery: dict[str, Any] | None = None
    # Assigned in async_step_user when running under a supervisor.
    _addon_manager: AddonManager

    def __init__(self) -> None:
        """Set up flow instance."""
        # Background tasks backing the install/start add-on progress steps.
        self.install_task: asyncio.Task | None = None
        self.start_task: asyncio.Task | None = None

    @classmethod
    @callback
    def async_get_supported_subentry_types(
        cls, config_entry: ConfigEntry
    ) -> dict[str, type[ConfigSubentryFlow]]:
        """Return subentries supported by this handler."""
        return {CONF_DEVICE: MQTTSubentryFlowHandler}

    @staticmethod
    @callback
    def async_get_options_flow(
        config_entry: ConfigEntry,
    ) -> MQTTOptionsFlowHandler:
        """Get the options flow for this handler."""
        return MQTTOptionsFlowHandler()

    async def _async_install_addon(self) -> None:
        """Install the Mosquitto Mqtt broker add-on."""
        addon_manager: AddonManager = get_addon_manager(self.hass)
        await addon_manager.async_schedule_install_addon()

    async def async_step_install_failed(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Add-on installation failed."""
        return self.async_abort(
            reason="addon_install_failed",
            description_placeholders={"addon": self._addon_manager.addon_name},
        )

    async def async_step_install_addon(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Install Mosquitto Broker add-on.

        Starts the install task on the first call and shows a progress step
        while it runs. Once the task is done, continues with "start_addon", or
        with "install_failed" when the task raised AddonError.
        """
        if self.install_task is None:
            self.install_task = self.hass.async_create_task(self._async_install_addon())

        if not self.install_task.done():
            return self.async_show_progress(
                step_id="install_addon",
                progress_action="install_addon",
                progress_task=self.install_task,
            )

        try:
            await self.install_task
        except AddonError as err:
            _LOGGER.error(err)
            return self.async_show_progress_done(next_step_id="install_failed")
        finally:
            # Clear the finished task so a later call would start a new one.
            self.install_task = None

        return self.async_show_progress_done(next_step_id="start_addon")

    async def async_step_start_failed(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Add-on start failed."""
        return self.async_abort(
            reason="addon_start_failed",
            description_placeholders={"addon": self._addon_manager.addon_name},
        )

    async def async_step_start_addon(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Start Mosquitto Broker add-on.

        Mirrors async_step_install_addon: shows progress while the start task
        runs, then continues with "setup_entry_from_discovery", or with
        "start_failed" when the task raised AddonError.
        """
        if not self.start_task:
            self.start_task = self.hass.async_create_task(self._async_start_addon())
        if not self.start_task.done():
            return self.async_show_progress(
                step_id="start_addon",
                progress_action="start_addon",
                progress_task=self.start_task,
            )
        try:
            await self.start_task
        except AddonError as err:
            _LOGGER.error(err)
            return self.async_show_progress_done(next_step_id="start_failed")
        finally:
            self.start_task = None

        return self.async_show_progress_done(next_step_id="setup_entry_from_discovery")

    async def _async_get_config_and_try(self) -> dict[str, Any] | None:
        """Get the MQTT add-on discovery info and try the connection.

        Returns a previously stored config without testing it again. Otherwise
        builds a config from the add-on discovery info and stores and returns it
        only if try_connection succeeds. Returns None when no discovery info is
        available yet (AddonError) or the connection test fails.
        """
        if self._hassio_discovery is not None:
            return self._hassio_discovery
        addon_manager: AddonManager = get_addon_manager(self.hass)
        try:
            addon_discovery_config = (
                await addon_manager.async_get_addon_discovery_info()
            )
            config: dict[str, Any] = {
                CONF_BROKER: addon_discovery_config[CONF_HOST],
                CONF_PORT: addon_discovery_config[CONF_PORT],
                CONF_USERNAME: addon_discovery_config.get(CONF_USERNAME),
                CONF_PASSWORD: addon_discovery_config.get(CONF_PASSWORD),
                CONF_DISCOVERY: DEFAULT_DISCOVERY,
            }
        except AddonError:
            # We do not have discovery information yet
            return None
        # try_connection (defined outside this view) runs in the executor,
        # presumably because it blocks — confirm.
        if await self.hass.async_add_executor_job(
            try_connection,
            config,
        ):
            self._hassio_discovery = config
            return config
        return None

    async def _async_start_addon(self) -> None:
        """Start the Mosquitto Broker add-on.

        Raises AddonError when no successful connection could be made within
        ADDON_SETUP_TIMEOUT_ROUNDS rounds of ADDON_SETUP_TIMEOUT seconds.
        """
        addon_manager: AddonManager = get_addon_manager(self.hass)
        await addon_manager.async_schedule_start_addon()

        # Sleep some seconds to let the add-on start properly before connecting.
        for _ in range(ADDON_SETUP_TIMEOUT_ROUNDS):
            await asyncio.sleep(ADDON_SETUP_TIMEOUT)
            # Finish setup using discovery info to test the connection
            if await self._async_get_config_and_try():
                break
        else:
            # for-else: reached only when no round hit `break`.
            raise AddonError(
                translation_domain=DOMAIN,
                translation_key="addon_start_failed",
                translation_placeholders={"addon": addon_manager.addon_name},
            )

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Handle a flow initialized by the user."""
        if is_hassio(self.hass):
            # Offer to set up broker add-on if supervisor is available
            self._addon_manager = get_addon_manager(self.hass)
            return self.async_show_menu(
                step_id="user",
                menu_options=["addon", "broker"],
                description_placeholders={"addon": self._addon_manager.addon_name},
            )

        # Start up a flow for manual setup
        return await self.async_step_broker()

    async def async_step_setup_entry_from_discovery(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Set up mqtt entry from discovery info.

        Creates the entry titled with the add-on name, or aborts with
        "addon_connection_failed" when no working config is available.
        """
        if (config := await self._async_get_config_and_try()) is not None:
            return self.async_create_entry(
                title=self._addon_manager.addon_name,
                data=config,
            )

        raise AbortFlow(
            "addon_connection_failed",
            description_placeholders={"addon": self._addon_manager.addon_name},
        )

    async def async_step_addon(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Install and start MQTT broker add-on.

        Dispatches on the add-on state: running -> set up from discovery,
        not running -> start it, any other state -> install it.
        """
        addon_manager = self._addon_manager

        try:
            addon_info = await addon_manager.async_get_addon_info()
        except AddonError as err:
            raise AbortFlow(
                "addon_info_failed",
                description_placeholders={"addon": self._addon_manager.addon_name},
            ) from err

        if addon_info.state == AddonState.RUNNING:
            # Finish setup using discovery info
            return await self.async_step_setup_entry_from_discovery()

        if addon_info.state == AddonState.NOT_RUNNING:
            return await self.async_step_start_addon()

        # Install the add-on and start it
        return await self.async_step_install_addon()

    async def async_step_reauth(
        self, entry_data: Mapping[str, Any]
    ) -> ConfigFlowResult:
        """Handle re-authentication with MQTT broker.

        Under a supervisor, when host, port and username match the add-on
        discovery info but the password differs, the add-on credentials are
        submitted to reauth_confirm directly. Otherwise the user is asked.
        """
        if is_hassio(self.hass):
            # Check if entry setup matches the add-on discovery config
            addon_manager = get_addon_manager(self.hass)
            try:
                addon_discovery_config = (
                    await addon_manager.async_get_addon_discovery_info()
                )
            except AddonError:
                # Follow manual flow if we have an error
                pass
            else:
                # Check if the addon secrets need to be renewed.
                # This will repair the config entry,
                # in case the official Mosquitto Broker addon was re-installed.
                # `username` and `password` are bound by the walrus operators
                # only when evaluation reaches them; both are set whenever the
                # whole condition is true.
                if (
                    entry_data[CONF_BROKER] == addon_discovery_config[CONF_HOST]
                    and entry_data[CONF_PORT] == addon_discovery_config[CONF_PORT]
                    and entry_data.get(CONF_USERNAME)
                    == (username := addon_discovery_config.get(CONF_USERNAME))
                    and entry_data.get(CONF_PASSWORD)
                    != (password := addon_discovery_config.get(CONF_PASSWORD))
                ):
                    _LOGGER.info(
                        "Executing autorecovery %s add-on secrets",
                        addon_manager.addon_name,
                    )
                    return await self.async_step_reauth_confirm(
                        user_input={CONF_USERNAME: username, CONF_PASSWORD: password}
                    )

        return await self.async_step_reauth_confirm()

    async def async_step_reauth_confirm(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Confirm re-authentication with MQTT broker.

        Merges the submitted credentials into the entry data (the sentinel
        password keeps the stored one) and updates/reloads the entry when the
        connection test succeeds. Otherwise shows the form with "invalid_auth".
        """
        errors: dict[str, str] = {}

        reauth_entry = self._get_reauth_entry()
        if user_input:
            substituted_used_data = update_password_from_user_input(
                reauth_entry.data.get(CONF_PASSWORD), user_input
            )
            new_entry_data = {**reauth_entry.data, **substituted_used_data}
            if await self.hass.async_add_executor_job(
                try_connection,
                new_entry_data,
            ):
                return self.async_update_reload_and_abort(
                    reauth_entry, data=new_entry_data
                )

            errors["base"] = "invalid_auth"

        # Never pre-fill the real password; show the sentinel instead.
        schema = self.add_suggested_values_to_schema(
            REAUTH_SCHEMA,
            {
                CONF_USERNAME: reauth_entry.data.get(CONF_USERNAME),
                CONF_PASSWORD: PWD_NOT_CHANGED,
            },
        )
        return self.async_show_form(
            step_id="reauth_confirm",
            data_schema=schema,
            errors=errors,
        )

    async def async_step_broker(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Confirm the setup.

        Shared by the user and reconfigure flows. async_get_broker_settings
        (defined outside this view) fills `fields`, `validated_user_input` and
        `errors`; its truthy result is taken to mean the input is complete.
        """
        errors: dict[str, str] = {}
        fields: OrderedDict[Any, Any] = OrderedDict()
        validated_user_input: dict[str, Any] = {}
        if is_reconfigure := (self.source == SOURCE_RECONFIGURE):
            reconfigure_entry = self._get_reconfigure_entry()
        # reconfigure_entry is only read when is_reconfigure is True.
        if await async_get_broker_settings(
            self,
            fields,
            reconfigure_entry.data if is_reconfigure else None,
            user_input,
            validated_user_input,
            errors,
        ):
            if is_reconfigure:
                # Restore the stored password if the sentinel was submitted.
                validated_user_input = update_password_from_user_input(
                    reconfigure_entry.data.get(CONF_PASSWORD), validated_user_input
                )

            can_connect = await self.hass.async_add_executor_job(
                try_connection,
                validated_user_input,
            )

            if can_connect:
                if is_reconfigure:
                    return self.async_update_reload_and_abort(
                        reconfigure_entry,
                        data=validated_user_input,
                    )
                return self.async_create_entry(
                    title=validated_user_input[CONF_BROKER],
                    data=validated_user_input,
                )

            errors["base"] = "cannot_connect"

        return self.async_show_form(
            step_id="broker", data_schema=vol.Schema(fields), errors=errors
        )

    async def async_step_reconfigure(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Handle a reconfiguration flow initialized by the user."""
        return await self.async_step_broker()

    async def async_step_hassio(
        self, discovery_info: HassioServiceInfo
    ) -> ConfigFlowResult:
        """Receive a Hass.io discovery or process setup after addon install."""
        await self._async_handle_discovery_without_unique_id()

        self._hassio_discovery = discovery_info.config

        return await self.async_step_hassio_confirm()

    async def async_step_hassio_confirm(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Confirm a Hass.io discovery.

        The discovery config read here carries CONF_HOST (renamed to
        CONF_BROKER) and an "addon" key used as the entry title.
        """
        errors: dict[str, str] = {}
        if TYPE_CHECKING:
            assert self._hassio_discovery

        if user_input is not None:
            data: dict[str, Any] = self._hassio_discovery.copy()
            data[CONF_BROKER] = data.pop(CONF_HOST)
            can_connect = await self.hass.async_add_executor_job(
                try_connection,
                data,
            )

            if can_connect:
                return self.async_create_entry(
                    title=data["addon"],
                    data={
                        CONF_BROKER: data[CONF_BROKER],
                        CONF_PORT: data[CONF_PORT],
                        CONF_USERNAME: data.get(CONF_USERNAME),
                        CONF_PASSWORD: data.get(CONF_PASSWORD),
                        CONF_DISCOVERY: DEFAULT_DISCOVERY,
                    },
                )

            errors["base"] = "cannot_connect"

        return self.async_show_form(
            step_id="hassio_confirm",
            description_placeholders={"addon": self._hassio_discovery["addon"]},
            errors=errors,
        )
class MQTTOptionsFlowHandler(OptionsFlow):
    """Options flow for MQTT discovery and the birth/will messages."""

    async def async_step_init(self, user_input: None = None) -> ConfigFlowResult:
        """Manage the MQTT options."""
        return await self.async_step_options()

    async def async_step_options(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Show, validate and store the MQTT options.

        On submit, the discovery prefix and (when a topic was given) the birth
        and will messages are validated; a disabled message is stored as an
        empty dict. Without errors the options entry is created. Otherwise the
        form is shown again, pre-filled from the current (partially updated)
        options, with the last validation error under "base".
        """
        errors: dict[str, str] = {}
        options_config: dict[str, Any] = dict(self.config_entry.options)
        # (form prefix, option key, error key, defaults) per message type;
        # birth is always handled before will.
        message_kinds = (
            ("birth", CONF_BIRTH_MESSAGE, "bad_birth", DEFAULT_BIRTH),
            ("will", CONF_WILL_MESSAGE, "bad_will", DEFAULT_WILL),
        )

        def _store_if_valid(
            field: str,
            values: Any,
            error_code: str,
            schema: Callable[[Any], Any],
        ) -> bool:
            """Store the schema-validated value; record the error and return False on failure."""
            try:
                validated = schema(values)
            except vol.Invalid:
                errors["base"] = error_code
                return False
            options_config[field] = validated
            return True

        if user_input is not None:
            options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
            input_ok = _store_if_valid(
                CONF_DISCOVERY_PREFIX,
                user_input[CONF_DISCOVERY_PREFIX],
                "bad_discovery_prefix",
                valid_publish_topic,
            )
            for prefix, option_key, error_key, _defaults in message_kinds:
                if f"{prefix}_topic" in user_input:
                    message = {
                        ATTR_TOPIC: user_input[f"{prefix}_topic"],
                        ATTR_PAYLOAD: user_input.get(f"{prefix}_payload", ""),
                        ATTR_QOS: user_input[f"{prefix}_qos"],
                        ATTR_RETAIN: user_input[f"{prefix}_retain"],
                    }
                    if not _store_if_valid(
                        option_key, message, error_key, valid_birth_will
                    ):
                        input_ok = False
                # An empty dict marks the message as disabled.
                if not user_input[f"{prefix}_enable"]:
                    options_config[option_key] = {}
            if input_ok:
                return self.async_create_entry(data=options_config)

        fields: OrderedDict[vol.Marker, Any] = OrderedDict()
        fields[
            vol.Optional(
                CONF_DISCOVERY,
                default=options_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY),
            )
        ] = BOOLEAN_SELECTOR
        fields[
            vol.Optional(
                CONF_DISCOVERY_PREFIX,
                default=options_config.get(CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX),
            )
        ] = PUBLISH_TOPIC_SELECTOR

        for prefix, option_key, _error_key, defaults in message_kinds:
            current = {**defaults, **options_config.get(option_key, {})}
            # The message is enabled unless it was stored as an empty dict.
            enabled = option_key not in options_config or options_config[option_key] != {}
            fields[vol.Optional(f"{prefix}_enable", default=enabled)] = BOOLEAN_SELECTOR
            fields[
                vol.Optional(
                    f"{prefix}_topic",
                    description={"suggested_value": current[ATTR_TOPIC]},
                )
            ] = PUBLISH_TOPIC_SELECTOR
            fields[
                vol.Optional(
                    f"{prefix}_payload",
                    description={"suggested_value": current[CONF_PAYLOAD]},
                )
            ] = TEXT_SELECTOR
            fields[vol.Optional(f"{prefix}_qos", default=current[ATTR_QOS])] = (
                QOS_DATA_SCHEMA
            )
            fields[vol.Optional(f"{prefix}_retain", default=current[ATTR_RETAIN])] = (
                BOOLEAN_SELECTOR
            )

        return self.async_show_form(
            step_id="options",
            data_schema=vol.Schema(fields),
            errors=errors,
            last_step=True,
        )
class MQTTSubentryFlowHandler(ConfigSubentryFlow):
"""Handle MQTT subentry flow."""
_subentry_data: MqttSubentryData
_component_id: str | None = None
@callback
def update_component_fields(
self, data_schema: vol.Schema, user_input: dict[str, Any]
) -> None:
"""Update the componment fields."""
if TYPE_CHECKING:
assert self._component_id is not None
component_data = self._subentry_data["components"][self._component_id]
# Remove the fields from the component data if they are not in the user input
for field in [
form_field
for form_field in data_schema.schema
if form_field in component_data and form_field not in user_input
]:
component_data.pop(field)
component_data.update(user_input)
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Add a subentry."""
self._subentry_data = MqttSubentryData(device=MqttDeviceData(), components={})
return await self.async_step_device()
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Reconfigure a subentry."""
reconfigure_subentry = self._get_reconfigure_subentry()
self._subentry_data = cast(
MqttSubentryData, deepcopy(dict(reconfigure_subentry.data))
)
return await self.async_step_summary_menu()
async def async_step_device(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Add a new MQTT device."""
errors: dict[str, str] = {}
validate_field("configuration_url", cv.url, user_input, errors, "invalid_url")
if not errors and user_input is not None:
self._subentry_data[CONF_DEVICE] = cast(MqttDeviceData, user_input)
if self.source == SOURCE_RECONFIGURE:
return await self.async_step_summary_menu()
return await self.async_step_entity()
data_schema = self.add_suggested_values_to_schema(
MQTT_DEVICE_SCHEMA,
self._subentry_data[CONF_DEVICE] if user_input is None else user_input,
)
return self.async_show_form(
step_id=CONF_DEVICE,
data_schema=data_schema,
errors=errors,
last_step=False,
)
async def async_step_entity(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Add or edit an mqtt entity."""
errors: dict[str, str] = {}
data_schema_fields = COMMON_ENTITY_FIELDS
entity_name_label: str = ""
platform_label: str = ""
if reconfig := (self._component_id is not None):
name: str | None = self._subentry_data["components"][
self._component_id
].get(CONF_NAME)
platform_label = f"{self._subentry_data['components'][self._component_id][CONF_PLATFORM]} "
entity_name_label = f" ({name})" if name is not None else ""
data_schema = data_schema_from_fields(data_schema_fields, reconfig=reconfig)
if user_input is not None:
validate_user_input(user_input, data_schema_fields, errors)
if not errors:
if self._component_id is None:
self._component_id = uuid4().hex
self._subentry_data["components"].setdefault(self._component_id, {})
self.update_component_fields(data_schema, user_input)
return await self.async_step_mqtt_platform_config()
data_schema = self.add_suggested_values_to_schema(data_schema, user_input)
elif self.source == SOURCE_RECONFIGURE and self._component_id is not None:
data_schema = self.add_suggested_values_to_schema(
data_schema, self._subentry_data["components"][self._component_id]
)
device_name = self._subentry_data[CONF_DEVICE][CONF_NAME]
return self.async_show_form(
step_id="entity",
data_schema=data_schema,
description_placeholders={
"mqtt_device": device_name,
"entity_name_label": entity_name_label,
"platform_label": platform_label,
},
errors=errors,
last_step=False,
)
def _show_update_or_delete_form(self, step_id: str) -> SubentryFlowResult:
"""Help selecting an entity to update or delete."""
device_name = self._subentry_data[CONF_DEVICE][CONF_NAME]
entities = [
SelectOptionDict(
value=key, label=f"{device_name} {component.get(CONF_NAME, '-')}"
)
for key, component in self._subentry_data["components"].items()
]
data_schema = vol.Schema(
{
vol.Required("component"): SelectSelector(
SelectSelectorConfig(
options=entities,
mode=SelectSelectorMode.LIST,
)
)
}
)
return self.async_show_form(
step_id=step_id, data_schema=data_schema, last_step=False
)
async def async_step_update_entity(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Select the entity to update."""
if user_input:
self._component_id = user_input["component"]
return await self.async_step_entity()
if len(self._subentry_data["components"]) == 1:
# Return first key
self._component_id = next(iter(self._subentry_data["components"]))
return await self.async_step_entity()
return self._show_update_or_delete_form("update_entity")
async def async_step_delete_entity(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Select the entity to delete."""
if user_input:
del self._subentry_data["components"][user_input["component"]]
return await self.async_step_summary_menu()
return self._show_update_or_delete_form("delete_entity")
async def async_step_mqtt_platform_config(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Configure entity platform MQTT details."""
errors: dict[str, str] = {}
if TYPE_CHECKING:
assert self._component_id is not None
platform = self._subentry_data["components"][self._component_id][CONF_PLATFORM]
data_schema_fields = PLATFORM_MQTT_FIELDS[platform] | COMMON_MQTT_FIELDS
data_schema = data_schema_from_fields(
data_schema_fields, reconfig=self._component_id is not None
)
if user_input is not None:
# Test entity fields against the validator
validate_user_input(user_input, data_schema_fields, errors)
if not errors:
self.update_component_fields(data_schema, user_input)
self._component_id = None
if self.source == SOURCE_RECONFIGURE:
return await self.async_step_summary_menu()
return self._async_create_subentry()
data_schema = self.add_suggested_values_to_schema(data_schema, user_input)
else:
data_schema = self.add_suggested_values_to_schema(
data_schema, self._subentry_data["components"][self._component_id]
)
device_name = self._subentry_data[CONF_DEVICE][CONF_NAME]
entity_name: str | None
if entity_name := self._subentry_data["components"][self._component_id].get(
CONF_NAME
):
full_entity_name: str = f"{device_name} {entity_name}"
else:
full_entity_name = device_name
return self.async_show_form(
step_id="mqtt_platform_config",
data_schema=data_schema,
description_placeholders={
"mqtt_device": device_name,
CONF_PLATFORM: platform,
"entity": full_entity_name,
},
errors=errors,
last_step=False,
)
@callback
def _async_create_subentry(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Create a subentry for a new MQTT device."""
device_name = self._subentry_data[CONF_DEVICE][CONF_NAME]
component: dict[str, Any] = next(
iter(self._subentry_data["components"].values())
)
platform = component[CONF_PLATFORM]
entity_name: str | None
if entity_name := component.get(CONF_NAME):
full_entity_name: str = f"{device_name} {entity_name}"
else:
full_entity_name = device_name
return self.async_create_entry(
data=self._subentry_data,
title=self._subentry_data[CONF_DEVICE][CONF_NAME],
description_placeholders={
"entity": full_entity_name,
CONF_PLATFORM: platform,
},
)
async def async_step_availability(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Configure availability options."""
errors: dict[str, str] = {}
validate_field(
"availability_topic",
valid_subscribe_topic,
user_input,
errors,
"invalid_subscribe_topic",
)
validate_field(
"availability_template",
valid_subscribe_topic_template,
user_input,
errors,
"invalid_template",
)
if not errors and user_input is not None:
self._subentry_data.setdefault("availability", MqttAvailabilityData())
self._subentry_data["availability"] = cast(MqttAvailabilityData, user_input)
return await self.async_step_summary_menu()
data_schema = SUBENTRY_AVAILABILITY_SCHEMA
data_schema = self.add_suggested_values_to_schema(
data_schema,
dict(self._subentry_data.setdefault("availability", {}))
if self.source == SOURCE_RECONFIGURE
else user_input,
)
return self.async_show_form(
step_id="availability",
data_schema=data_schema,
errors=errors,
last_step=False,
)
async def async_step_summary_menu(
    self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
    """Present the summary menu of the MQTT device being reconfigured.

    Entities can always be added or updated; deleting is only offered when
    more than one entity exists. Saving is only offered when the working
    data differs from the stored subentry data.
    """
    # No entity is selected while the menu is shown
    self._component_id = None
    device_name = self._subentry_data[CONF_DEVICE][CONF_NAME]
    components = self._subentry_data["components"]
    item_labels = [
        f"{device_name} {component.get(CONF_NAME, '-')}"
        for component in components.values()
    ]
    menu_options = ["entity", "update_entity"]
    if len(components) > 1:
        menu_options.append("delete_entity")
    menu_options += ["device", "availability"]
    if self._subentry_data != self._get_reconfigure_subentry().data:
        menu_options.append("save_changes")
    mqtt_items = ", ".join(item_labels)
    return self.async_show_menu(
        step_id="summary_menu",
        menu_options=menu_options,
        description_placeholders={
            "mqtt_device": device_name,
            "mqtt_items": mqtt_items,
        },
    )
async def async_step_save_changes(
    self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
    """Write the reconfigured subentry data and finish the flow.

    Entity registry entries of components that were removed during the
    reconfiguration are deleted first.
    """
    entry = self._get_entry()
    subentry = self._get_reconfigure_subentry()
    entity_registry = er.async_get(self.hass)
    # Entity unique ids have the form "<subentry_id>_<component_id>", so each
    # component that no longer exists leaves a stale registry entry behind.
    remaining_components = self._subentry_data["components"]
    stale_entities = [
        (f"{subentry.subentry_id}_{component_id}", component[CONF_PLATFORM])
        for component_id, component in subentry.data["components"].items()
        if component_id not in remaining_components
    ]
    for unique_id, platform in stale_entities:
        entity_id = entity_registry.async_get_entity_id(platform, DOMAIN, unique_id)
        if entity_id:
            entity_registry.async_remove(entity_id)
    return self.async_update_and_abort(
        entry,
        subentry,
        data=self._subentry_data,
        title=self._subentry_data[CONF_DEVICE][CONF_NAME],
    )
@callback
def async_is_pem_data(data: bytes) -> bool:
    """Return True if data is in PEM format.

    Recognizes PEM armor for X.509 certificates, PKCS#8 private keys (plain
    and encrypted) and the traditional OpenSSL private key formats for RSA,
    EC and DSA keys. Data without one of these headers is treated as DER by
    the conversion code.
    """
    pem_headers = (
        b"-----BEGIN CERTIFICATE-----",
        b"-----BEGIN PRIVATE KEY-----",
        b"-----BEGIN ENCRYPTED PRIVATE KEY-----",
        b"-----BEGIN RSA PRIVATE KEY-----",
        # Traditional OpenSSL EC/DSA keys; EC keys are also what the
        # TraditionalOpenSSL re-encoding produces for EC key material.
        b"-----BEGIN EC PRIVATE KEY-----",
        b"-----BEGIN DSA PRIVATE KEY-----",
    )
    return any(header in data for header in pem_headers)
class PEMType(IntEnum):
    """Kind of key material passed to the PEM conversion helper."""

    # An X.509 certificate (used for both the CA and the client certificate)
    CERTIFICATE = 1
    # A private key, optionally protected by a password
    PRIVATE_KEY = 2
@callback
def async_convert_to_pem(
    data: bytes, pem_type: PEMType, password: str | None = None
) -> str | None:
    """Convert certificate or private key file data to PEM format.

    Args:
        data: The raw file content, PEM or DER encoded.
        pem_type: Whether ``data`` holds a certificate or a private key.
        password: Optional password to decrypt a private key. It is not
            used for certificates.

    Returns:
        The PEM encoded data as text, or None if the data could not be
        converted (the error is logged). PEM data is returned unchanged
        unless a private key password is given; decrypted and DER encoded
        private keys are re-encoded unencrypted in traditional OpenSSL format.
    """
    try:
        if async_is_pem_data(data):
            if not password or pem_type == PEMType.CERTIFICATE:
                # Already PEM: certificates, and private keys without a
                # password (assumed to be unencrypted), are passed through
                return data.decode(DEFAULT_ENCODING)
            # Return decrypted PEM encoded private key
            return (
                load_pem_private_key(data, password=password.encode(DEFAULT_ENCODING))
                .private_bytes(
                    encoding=Encoding.PEM,
                    format=PrivateFormat.TraditionalOpenSSL,
                    encryption_algorithm=NoEncryption(),
                )
                .decode(DEFAULT_ENCODING)
            )
        # Convert from DER encoding to PEM
        if pem_type == PEMType.CERTIFICATE:
            return (
                load_der_x509_certificate(data)
                .public_bytes(encoding=Encoding.PEM)
                .decode(DEFAULT_ENCODING)
            )
        # Assume DER encoded private key
        pem_key_data: bytes = load_der_private_key(
            data, password.encode(DEFAULT_ENCODING) if password else None
        ).private_bytes(
            encoding=Encoding.PEM,
            format=PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=NoEncryption(),
        )
        # Decode with the same encoding as every other branch
        return pem_key_data.decode(DEFAULT_ENCODING)
    except (TypeError, ValueError, SSLError):
        _LOGGER.exception("Error converting %s file data to PEM format", pem_type.name)
        return None
async def _get_uploaded_file(hass: HomeAssistant, id: str) -> bytes:
    """Return the raw content of an uploaded certificate or key file.

    The blocking file access is done in the executor.
    """

    def _read_uploaded_file() -> bytes:
        with process_uploaded_file(hass, id) as uploaded_path:
            content = uploaded_path.read_bytes()
        return content

    return await hass.async_add_executor_job(_read_uploaded_file)
def _validate_pki_file(
file_id: str | None, pem_data: str | None, errors: dict[str, str], error: str
) -> bool:
"""Return False if uploaded file could not be converted to PEM format."""
if file_id and not pem_data:
errors["base"] = error
return False
return True
async def async_get_broker_settings(  # noqa: C901
    flow: ConfigFlow | OptionsFlow,
    fields: OrderedDict[Any, Any],
    entry_config: MappingProxyType[str, Any] | None,
    user_input: dict[str, Any] | None,
    validated_user_input: dict[str, Any],
    errors: dict[str, str],
) -> bool:
    """Build the config flow schema to collect the broker settings.

    Shows advanced options if one or more are configured
    or when the advanced_broker_options checkbox was selected.

    Args:
        flow: The config or options flow asking for the broker settings.
        fields: Ordered schema fields; the form fields are added in place.
        entry_config: Current entry settings used as defaults, if any.
        user_input: The previously posted form data, if any.
        validated_user_input: Receives the validated settings, including the
            converted PEM certificate data.
        errors: Receives error keys (under ``"base"``) for the form.

    Returns:
        True when settings are collected successfully, False when the form
        built in ``fields`` needs to be shown (again).
    """
    hass = flow.hass
    advanced_broker_options: bool = False
    user_input_basic: dict[str, Any] = {}
    current_config: dict[str, Any] = (
        entry_config.copy() if entry_config is not None else {}
    )

    async def _async_validate_broker_settings(
        config: dict[str, Any],
        user_input: dict[str, Any],
        validated_user_input: dict[str, Any],
        errors: dict[str, str],
    ) -> bool:
        """Additional validation on broker settings for better error messages."""

        # Get current certificate settings from config entry
        certificate: str | None = (
            "auto"
            if user_input.get(SET_CA_CERT, "off") == "auto"
            else config.get(CONF_CERTIFICATE)
            if user_input.get(SET_CA_CERT, "off") == "custom"
            else None
        )
        client_certificate: str | None = (
            config.get(CONF_CLIENT_CERT) if user_input.get(SET_CLIENT_CERT) else None
        )
        client_key: str | None = (
            config.get(CONF_CLIENT_KEY) if user_input.get(SET_CLIENT_CERT) else None
        )

        # Prepare entry update with uploaded files
        validated_user_input.update(user_input)
        client_certificate_id: str | None = user_input.get(CONF_CLIENT_CERT)
        client_key_id: str | None = user_input.get(CONF_CLIENT_KEY)
        # We do not store the private key password in the entry data
        client_key_password: str | None = validated_user_input.pop(
            CONF_CLIENT_KEY_PASSWORD, None
        )
        # A client certificate and its key must be uploaded together
        if (client_certificate_id and not client_key_id) or (
            not client_certificate_id and client_key_id
        ):
            errors["base"] = "invalid_inclusion"
            return False
        certificate_id: str | None = user_input.get(CONF_CERTIFICATE)
        if certificate_id:
            certificate_data_raw = await _get_uploaded_file(hass, certificate_id)
            certificate = async_convert_to_pem(
                certificate_data_raw, PEMType.CERTIFICATE
            )
        if not _validate_pki_file(
            certificate_id, certificate, errors, "bad_certificate"
        ):
            return False

        # Return to form for file upload CA cert or client cert and key
        if (
            (
                not client_certificate
                and user_input.get(SET_CLIENT_CERT)
                and not client_certificate_id
            )
            or (
                not certificate
                and user_input.get(SET_CA_CERT, "off") == "custom"
                and not certificate_id
            )
            or (
                user_input.get(CONF_TRANSPORT) == TRANSPORT_WEBSOCKETS
                and CONF_WS_PATH not in user_input
            )
        ):
            return False

        if client_certificate_id:
            client_certificate_data = await _get_uploaded_file(
                hass, client_certificate_id
            )
            client_certificate = async_convert_to_pem(
                client_certificate_data, PEMType.CERTIFICATE
            )
        if not _validate_pki_file(
            client_certificate_id, client_certificate, errors, "bad_client_cert"
        ):
            return False

        if client_key_id:
            client_key_data = await _get_uploaded_file(hass, client_key_id)
            client_key = async_convert_to_pem(
                client_key_data, PEMType.PRIVATE_KEY, password=client_key_password
            )
        if not _validate_pki_file(
            client_key_id, client_key, errors, "client_key_error"
        ):
            return False

        certificate_data: dict[str, Any] = {}
        if certificate:
            certificate_data[CONF_CERTIFICATE] = certificate
        if client_certificate:
            certificate_data[CONF_CLIENT_CERT] = client_certificate
            certificate_data[CONF_CLIENT_KEY] = client_key

        validated_user_input.update(certificate_data)
        await async_create_certificate_temp_files(hass, certificate_data)
        if error := await hass.async_add_executor_job(
            check_certicate_chain,
        ):
            errors["base"] = error
            return False

        # The selector helper fields are not part of the entry data
        if SET_CA_CERT in validated_user_input:
            del validated_user_input[SET_CA_CERT]
        if SET_CLIENT_CERT in validated_user_input:
            del validated_user_input[SET_CLIENT_CERT]
        if validated_user_input.get(CONF_TRANSPORT, TRANSPORT_TCP) == TRANSPORT_TCP:
            # Websocket settings do not apply to a TCP transport
            if CONF_WS_PATH in validated_user_input:
                del validated_user_input[CONF_WS_PATH]
            if CONF_WS_HEADERS in validated_user_input:
                del validated_user_input[CONF_WS_HEADERS]
            return True
        try:
            validated_user_input[CONF_WS_HEADERS] = json_loads(
                validated_user_input.get(CONF_WS_HEADERS, "{}")
            )
            schema = vol.Schema({cv.string: cv.template})
            schema(validated_user_input[CONF_WS_HEADERS])
        except (*JSON_DECODE_EXCEPTIONS, vol.MultipleInvalid):
            errors["base"] = "bad_ws_headers"
            return False
        return True

    if user_input:
        user_input_basic = user_input.copy()
        advanced_broker_options = user_input_basic.get(ADVANCED_OPTIONS, False)
        if ADVANCED_OPTIONS not in user_input or advanced_broker_options is False:
            if await _async_validate_broker_settings(
                current_config,
                user_input_basic,
                validated_user_input,
                errors,
            ):
                return True
        # Get defaults settings from previous post
        current_broker = user_input_basic.get(CONF_BROKER)
        current_port = user_input_basic.get(CONF_PORT, DEFAULT_PORT)
        current_user = user_input_basic.get(CONF_USERNAME)
        current_pass = user_input_basic.get(CONF_PASSWORD)
    else:
        # Get default settings from entry (if any)
        current_broker = current_config.get(CONF_BROKER)
        current_port = current_config.get(CONF_PORT, DEFAULT_PORT)
        current_user = current_config.get(CONF_USERNAME)
        # Return the sentinel password to avoid exposure
        current_entry_pass = current_config.get(CONF_PASSWORD)
        current_pass = PWD_NOT_CHANGED if current_entry_pass else None

    # Treat the previous post as an update of the current settings
    # (if there was a basic broker setup step)
    current_config.update(user_input_basic)

    # Get default settings for advanced broker options
    current_client_id = current_config.get(CONF_CLIENT_ID)
    current_keepalive = current_config.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE)
    current_ca_certificate = current_config.get(CONF_CERTIFICATE)
    current_client_certificate = current_config.get(CONF_CLIENT_CERT)
    current_client_key = current_config.get(CONF_CLIENT_KEY)
    current_tls_insecure = current_config.get(CONF_TLS_INSECURE, False)
    current_protocol = current_config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL)
    current_transport = current_config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT)
    current_ws_path = current_config.get(CONF_WS_PATH, DEFAULT_WS_PATH)
    # Validated headers are stored as a mapping and must be serialized for the
    # text field. After a failed post they are still the raw JSON text the
    # user entered; serializing that again would wrap it in quotes and make
    # every resubmission fail with "bad_ws_headers".
    current_ws_headers: str | None = None
    if CONF_WS_HEADERS in current_config:
        ws_headers = current_config[CONF_WS_HEADERS]
        current_ws_headers = (
            ws_headers if isinstance(ws_headers, str) else json_dumps(ws_headers)
        )
    advanced_broker_options |= bool(
        current_client_id
        or current_keepalive != DEFAULT_KEEPALIVE
        or current_ca_certificate
        or current_client_certificate
        or current_client_key
        or current_tls_insecure
        or current_protocol != DEFAULT_PROTOCOL
        or current_config.get(SET_CA_CERT, "off") != "off"
        or current_config.get(SET_CLIENT_CERT)
        or current_transport == TRANSPORT_WEBSOCKETS
    )

    # Build form
    fields[vol.Required(CONF_BROKER, default=current_broker)] = TEXT_SELECTOR
    fields[vol.Required(CONF_PORT, default=current_port)] = PORT_SELECTOR
    fields[
        vol.Optional(
            CONF_USERNAME,
            description={"suggested_value": current_user},
        )
    ] = TEXT_SELECTOR
    fields[
        vol.Optional(
            CONF_PASSWORD,
            description={"suggested_value": current_pass},
        )
    ] = PASSWORD_SELECTOR
    # show advanced options checkbox if requested and
    # advanced options are enabled
    # or when the defaults of advanced options are overridden
    if not advanced_broker_options:
        if not flow.show_advanced_options:
            return False
        fields[
            vol.Optional(
                ADVANCED_OPTIONS,
            )
        ] = BOOLEAN_SELECTOR
        return False
    fields[
        vol.Optional(
            CONF_CLIENT_ID,
            description={"suggested_value": current_client_id},
        )
    ] = TEXT_SELECTOR
    fields[
        vol.Optional(
            CONF_KEEPALIVE,
            description={"suggested_value": current_keepalive},
        )
    ] = KEEPALIVE_SELECTOR
    fields[
        vol.Optional(
            SET_CLIENT_CERT,
            default=current_client_certificate is not None
            or current_config.get(SET_CLIENT_CERT) is True,
        )
    ] = BOOLEAN_SELECTOR
    if (
        current_client_certificate is not None
        or current_config.get(SET_CLIENT_CERT) is True
    ):
        fields[
            vol.Optional(
                CONF_CLIENT_CERT,
                description={"suggested_value": user_input_basic.get(CONF_CLIENT_CERT)},
            )
        ] = CERT_UPLOAD_SELECTOR
        fields[
            vol.Optional(
                CONF_CLIENT_KEY,
                description={"suggested_value": user_input_basic.get(CONF_CLIENT_KEY)},
            )
        ] = KEY_UPLOAD_SELECTOR
        fields[
            vol.Optional(
                CONF_CLIENT_KEY_PASSWORD,
                description={
                    "suggested_value": user_input_basic.get(CONF_CLIENT_KEY_PASSWORD)
                },
            )
        ] = PASSWORD_SELECTOR
    verification_mode = current_config.get(SET_CA_CERT) or (
        "off"
        if current_ca_certificate is None
        else "auto"
        if current_ca_certificate == "auto"
        else "custom"
    )
    fields[
        vol.Optional(
            SET_CA_CERT,
            default=verification_mode,
        )
    ] = BROKER_VERIFICATION_SELECTOR
    if current_ca_certificate is not None or verification_mode == "custom":
        # The second positional argument of vol.Optional is the validation
        # error message, not a default, so no uploaded file id is passed here.
        fields[vol.Optional(CONF_CERTIFICATE)] = CA_CERT_UPLOAD_SELECTOR
    fields[
        vol.Optional(
            CONF_TLS_INSECURE,
            description={"suggested_value": current_tls_insecure},
        )
    ] = BOOLEAN_SELECTOR
    fields[
        vol.Optional(
            CONF_PROTOCOL,
            description={"suggested_value": current_protocol},
        )
    ] = PROTOCOL_SELECTOR
    fields[
        vol.Optional(
            CONF_TRANSPORT,
            description={"suggested_value": current_transport},
        )
    ] = TRANSPORT_SELECTOR
    if current_transport == TRANSPORT_WEBSOCKETS:
        fields[
            vol.Optional(CONF_WS_PATH, description={"suggested_value": current_ws_path})
        ] = TEXT_SELECTOR
        fields[
            vol.Optional(
                CONF_WS_HEADERS, description={"suggested_value": current_ws_headers}
            )
        ] = WS_HEADERS_SELECTOR

    # Show form
    return False
def try_connection(
    user_input: dict[str, Any],
) -> bool:
    """Try to connect to the MQTT broker and report whether it succeeded.

    Waits up to MQTT_TIMEOUT seconds for the connect callback; a timeout is
    reported as a failure. The client is always disconnected and its network
    loop stopped before returning.
    """
    # Imported here rather than at module level because some integrations
    # should be able to optionally rely on MQTT.
    import paho.mqtt.client as mqtt  # pylint: disable=import-outside-toplevel

    client_setup = MqttClientSetup(user_input)
    client_setup.setup()
    mqtt_client = client_setup.client

    outcome: queue.Queue[bool] = queue.Queue(maxsize=1)

    def _on_connect(
        _mqttc: mqtt.Client,
        _userdata: None,
        _connect_flags: mqtt.ConnectFlags,
        reason_code: mqtt.ReasonCode,
        _properties: mqtt.Properties | None = None,
    ) -> None:
        """Record whether the broker accepted the connection."""
        outcome.put(not reason_code.is_failure)

    mqtt_client.on_connect = _on_connect
    mqtt_client.connect_async(user_input[CONF_BROKER], user_input[CONF_PORT])
    mqtt_client.loop_start()
    try:
        connected = outcome.get(timeout=MQTT_TIMEOUT)
    except queue.Empty:
        connected = False
    finally:
        mqtt_client.disconnect()
        mqtt_client.loop_stop()
    return connected
def check_certicate_chain() -> str | None:
    """Validate the certificate files prepared for the MQTT client.

    Returns None when every configured file loads, otherwise the error key
    of the first problem found: "bad_client_cert", "client_key_error",
    "bad_client_cert_key" or "bad_certificate".
    """
    client_cert_path = get_file_path(CONF_CLIENT_CERT)
    if client_cert_path:
        try:
            with open(client_cert_path, "rb") as cert_file:
                load_pem_x509_certificate(cert_file.read())
        except ValueError:
            return "bad_client_cert"

    # The private key must be loadable without a password
    key_path = get_file_path(CONF_CLIENT_KEY)
    if key_path:
        try:
            with open(key_path, "rb") as key_file:
                load_pem_private_key(key_file.read(), password=None)
        except (TypeError, ValueError):
            return "client_key_error"

    context = SSLContext(PROTOCOL_TLS_CLIENT)
    # The client certificate and key must load together as one chain
    if client_cert_path and key_path:
        try:
            context.load_cert_chain(client_cert_path, key_path)
        except SSLError:
            return "bad_client_cert_key"

    # Finally try to load the custom CA file, if there is one
    ca_cert_path = get_file_path(CONF_CERTIFICATE)
    if ca_cert_path is None:
        return None
    try:
        context.load_verify_locations(ca_cert_path)
    except SSLError:
        return "bad_certificate"
    return None