diff --git a/esphome/components/mqtt/__init__.py b/esphome/components/mqtt/__init__.py index 99f8ad76d8..63d8da5788 100644 --- a/esphome/components/mqtt/__init__.py +++ b/esphome/components/mqtt/__init__.py @@ -41,6 +41,7 @@ from esphome.const import ( CONF_REBOOT_TIMEOUT, CONF_RETAIN, CONF_SHUTDOWN_MESSAGE, + CONF_SKIP_CERT_CN_CHECK, CONF_SSL_FINGERPRINTS, CONF_STATE_TOPIC, CONF_SUBSCRIBE_QOS, @@ -67,7 +68,6 @@ def AUTO_LOAD(): CONF_DISCOVER_IP = "discover_ip" CONF_IDF_SEND_ASYNC = "idf_send_async" -CONF_SKIP_CERT_CN_CHECK = "skip_cert_cn_check" def validate_message_just_topic(value): diff --git a/esphome/const.py b/esphome/const.py index ffa5de2de3..21cf7367de 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -800,6 +800,7 @@ CONF_SHUTDOWN_MESSAGE = "shutdown_message" CONF_SIGNAL_STRENGTH = "signal_strength" CONF_SINGLE_LIGHT_ID = "single_light_id" CONF_SIZE = "size" +CONF_SKIP_CERT_CN_CHECK = "skip_cert_cn_check" CONF_SLEEP_DURATION = "sleep_duration" CONF_SLEEP_PIN = "sleep_pin" CONF_SLEEP_WHEN_DONE = "sleep_when_done" diff --git a/esphome/mqtt.py b/esphome/mqtt.py index 2f90c49025..2403a4a1d9 100644 --- a/esphome/mqtt.py +++ b/esphome/mqtt.py @@ -3,6 +3,7 @@ import hashlib import json import logging import ssl +import tempfile import time import paho.mqtt.client as mqtt @@ -10,6 +11,8 @@ import paho.mqtt.client as mqtt from esphome.const import ( CONF_BROKER, CONF_CERTIFICATE_AUTHORITY, + CONF_CLIENT_CERTIFICATE, + CONF_CLIENT_CERTIFICATE_KEY, CONF_DISCOVERY_PREFIX, CONF_ESPHOME, CONF_LOG_TOPIC, @@ -17,6 +20,7 @@ from esphome.const import ( CONF_NAME, CONF_PASSWORD, CONF_PORT, + CONF_SKIP_CERT_CN_CHECK, CONF_SSL_FINGERPRINTS, CONF_TOPIC, CONF_TOPIC_PREFIX, @@ -102,15 +106,24 @@ def prepare( if config[CONF_MQTT].get(CONF_SSL_FINGERPRINTS) or config[CONF_MQTT].get( CONF_CERTIFICATE_AUTHORITY ): - tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member - client.tls_set( - ca_certs=None, - certfile=None, - keyfile=None, - cert_reqs=ssl.CERT_REQUIRED, - tls_version=tls_version, - ciphers=None, + context = ssl.create_default_context( + cadata=config[CONF_MQTT].get(CONF_CERTIFICATE_AUTHORITY) ) + if config[CONF_MQTT].get(CONF_SKIP_CERT_CN_CHECK): + context.check_hostname = False + if config[CONF_MQTT].get(CONF_CLIENT_CERTIFICATE) and config[CONF_MQTT].get( + CONF_CLIENT_CERTIFICATE_KEY + ): + with ( + tempfile.NamedTemporaryFile(mode="w+") as cert_file, + tempfile.NamedTemporaryFile(mode="w+") as key_file, + ): + cert_file.write(config[CONF_MQTT].get(CONF_CLIENT_CERTIFICATE)) + cert_file.flush() + key_file.write(config[CONF_MQTT].get(CONF_CLIENT_CERTIFICATE_KEY)) + key_file.flush() + context.load_cert_chain(cert_file, key_file) + client.tls_set_context(context) try: host = str(config[CONF_MQTT][CONF_BROKER])