Improve certificate handling in MQTT config flow (#137234)

* Improve mqtt broker certificate handling in config flow

* Expand test cases
This commit is contained in:
Jan Bouwhuis 2025-03-01 21:14:08 +01:00 committed by GitHub
parent dd21d48ae4
commit 913a4ee9ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 415 additions and 35 deletions

View File

@ -5,14 +5,21 @@ from __future__ import annotations
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from enum import IntEnum
import logging import logging
import queue import queue
from ssl import PROTOCOL_TLS_CLIENT, SSLContext, SSLError from ssl import PROTOCOL_TLS_CLIENT, SSLContext, SSLError
from types import MappingProxyType from types import MappingProxyType
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from cryptography.hazmat.primitives.serialization import load_pem_private_key from cryptography.hazmat.primitives.serialization import (
from cryptography.x509 import load_pem_x509_certificate Encoding,
NoEncryption,
PrivateFormat,
load_der_private_key,
load_pem_private_key,
)
from cryptography.x509 import load_der_x509_certificate, load_pem_x509_certificate
import voluptuous as vol import voluptuous as vol
from homeassistant.components.file_upload import process_uploaded_file from homeassistant.components.file_upload import process_uploaded_file
@ -105,6 +112,8 @@ _LOGGER = logging.getLogger(__name__)
ADDON_SETUP_TIMEOUT = 5 ADDON_SETUP_TIMEOUT = 5
ADDON_SETUP_TIMEOUT_ROUNDS = 5 ADDON_SETUP_TIMEOUT_ROUNDS = 5
CONF_CLIENT_KEY_PASSWORD = "client_key_password"
MQTT_TIMEOUT = 5 MQTT_TIMEOUT = 5
ADVANCED_OPTIONS = "advanced_options" ADVANCED_OPTIONS = "advanced_options"
@ -165,12 +174,14 @@ BROKER_VERIFICATION_SELECTOR = SelectSelector(
# mime configuration from https://pki-tutorial.readthedocs.io/en/latest/mime.html # mime configuration from https://pki-tutorial.readthedocs.io/en/latest/mime.html
CA_CERT_UPLOAD_SELECTOR = FileSelector( CA_CERT_UPLOAD_SELECTOR = FileSelector(
FileSelectorConfig(accept=".crt,application/x-x509-ca-cert") FileSelectorConfig(accept=".pem,.crt,.cer,.der,application/x-x509-ca-cert")
) )
CERT_UPLOAD_SELECTOR = FileSelector( CERT_UPLOAD_SELECTOR = FileSelector(
FileSelectorConfig(accept=".crt,application/x-x509-user-cert") FileSelectorConfig(accept=".pem,.crt,.cer,.der,application/x-x509-user-cert")
)
KEY_UPLOAD_SELECTOR = FileSelector(
FileSelectorConfig(accept=".pem,.key,.der,.pk8,application/pkcs8")
) )
KEY_UPLOAD_SELECTOR = FileSelector(FileSelectorConfig(accept=".key,application/pkcs8"))
REAUTH_SCHEMA = vol.Schema( REAUTH_SCHEMA = vol.Schema(
{ {
@ -710,17 +721,88 @@ class MQTTOptionsFlowHandler(OptionsFlow):
) )
async def _get_uploaded_file(hass: HomeAssistant, id: str) -> str: @callback
"""Get file content from uploaded file.""" def async_is_pem_data(data: bytes) -> bool:
"""Return True if data is in PEM format."""
return (
b"-----BEGIN CERTIFICATE-----" in data
or b"-----BEGIN PRIVATE KEY-----" in data
or b"-----BEGIN RSA PRIVATE KEY-----" in data
or b"-----BEGIN ENCRYPTED PRIVATE KEY-----" in data
)
def _proces_uploaded_file() -> str:
class PEMType(IntEnum):
"""Type of PEM data."""
CERTIFICATE = 1
PRIVATE_KEY = 2
@callback
def async_convert_to_pem(
data: bytes, pem_type: PEMType, password: str | None = None
) -> str | None:
"""Convert data to PEM format."""
try:
if async_is_pem_data(data):
if not password:
# Assume unencrypted PEM encoded private key
return data.decode(DEFAULT_ENCODING)
# Return decrypted PEM encoded private key
return (
load_pem_private_key(data, password=password.encode(DEFAULT_ENCODING))
.private_bytes(
encoding=Encoding.PEM,
format=PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=NoEncryption(),
)
.decode(DEFAULT_ENCODING)
)
# Convert from DER encoding to PEM
if pem_type == PEMType.CERTIFICATE:
return (
load_der_x509_certificate(data)
.public_bytes(
encoding=Encoding.PEM,
)
.decode(DEFAULT_ENCODING)
)
# Assume DER encoded private key
pem_key_data: bytes = load_der_private_key(
data, password.encode(DEFAULT_ENCODING) if password else None
).private_bytes(
encoding=Encoding.PEM,
format=PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=NoEncryption(),
)
return pem_key_data.decode("utf-8")
except (TypeError, ValueError, SSLError):
_LOGGER.exception("Error converting %s file data to PEM format", pem_type.name)
return None
async def _get_uploaded_file(hass: HomeAssistant, id: str) -> bytes:
"""Get file content from uploaded certificate or key file."""
def _proces_uploaded_file() -> bytes:
with process_uploaded_file(hass, id) as file_path: with process_uploaded_file(hass, id) as file_path:
return file_path.read_text(encoding=DEFAULT_ENCODING) return file_path.read_bytes()
return await hass.async_add_executor_job(_proces_uploaded_file) return await hass.async_add_executor_job(_proces_uploaded_file)
async def async_get_broker_settings( def _validate_pki_file(
file_id: str | None, pem_data: str | None, errors: dict[str, str], error: str
) -> bool:
"""Return False if uploaded file could not be converted to PEM format."""
if file_id and not pem_data:
errors["base"] = error
return False
return True
async def async_get_broker_settings( # noqa: C901
flow: ConfigFlow | OptionsFlow, flow: ConfigFlow | OptionsFlow,
fields: OrderedDict[Any, Any], fields: OrderedDict[Any, Any],
entry_config: MappingProxyType[str, Any] | None, entry_config: MappingProxyType[str, Any] | None,
@ -768,6 +850,10 @@ async def async_get_broker_settings(
validated_user_input.update(user_input) validated_user_input.update(user_input)
client_certificate_id: str | None = user_input.get(CONF_CLIENT_CERT) client_certificate_id: str | None = user_input.get(CONF_CLIENT_CERT)
client_key_id: str | None = user_input.get(CONF_CLIENT_KEY) client_key_id: str | None = user_input.get(CONF_CLIENT_KEY)
# We do not store the private key password in the entry data
client_key_password: str | None = validated_user_input.pop(
CONF_CLIENT_KEY_PASSWORD, None
)
if (client_certificate_id and not client_key_id) or ( if (client_certificate_id and not client_key_id) or (
not client_certificate_id and client_key_id not client_certificate_id and client_key_id
): ):
@ -775,7 +861,14 @@ async def async_get_broker_settings(
return False return False
certificate_id: str | None = user_input.get(CONF_CERTIFICATE) certificate_id: str | None = user_input.get(CONF_CERTIFICATE)
if certificate_id: if certificate_id:
certificate = await _get_uploaded_file(hass, certificate_id) certificate_data_raw = await _get_uploaded_file(hass, certificate_id)
certificate = async_convert_to_pem(
certificate_data_raw, PEMType.CERTIFICATE
)
if not _validate_pki_file(
certificate_id, certificate, errors, "bad_certificate"
):
return False
# Return to form for file upload CA cert or client cert and key # Return to form for file upload CA cert or client cert and key
if ( if (
@ -797,9 +890,26 @@ async def async_get_broker_settings(
return False return False
if client_certificate_id: if client_certificate_id:
client_certificate = await _get_uploaded_file(hass, client_certificate_id) client_certificate_data = await _get_uploaded_file(
hass, client_certificate_id
)
client_certificate = async_convert_to_pem(
client_certificate_data, PEMType.CERTIFICATE
)
if not _validate_pki_file(
client_certificate_id, client_certificate, errors, "bad_client_cert"
):
return False
if client_key_id: if client_key_id:
client_key = await _get_uploaded_file(hass, client_key_id) client_key_data = await _get_uploaded_file(hass, client_key_id)
client_key = async_convert_to_pem(
client_key_data, PEMType.PRIVATE_KEY, password=client_key_password
)
if not _validate_pki_file(
client_key_id, client_key, errors, "client_key_error"
):
return False
certificate_data: dict[str, Any] = {} certificate_data: dict[str, Any] = {}
if certificate: if certificate:
@ -956,6 +1066,14 @@ async def async_get_broker_settings(
description={"suggested_value": user_input_basic.get(CONF_CLIENT_KEY)}, description={"suggested_value": user_input_basic.get(CONF_CLIENT_KEY)},
) )
] = KEY_UPLOAD_SELECTOR ] = KEY_UPLOAD_SELECTOR
fields[
vol.Optional(
CONF_CLIENT_KEY_PASSWORD,
description={
"suggested_value": user_input_basic.get(CONF_CLIENT_KEY_PASSWORD)
},
)
] = PASSWORD_SELECTOR
verification_mode = current_config.get(SET_CA_CERT) or ( verification_mode = current_config.get(SET_CA_CERT) or (
"off" "off"
if current_ca_certificate is None if current_ca_certificate is None
@ -1060,7 +1178,7 @@ def check_certicate_chain() -> str | None:
with open(private_key, "rb") as client_key_file: with open(private_key, "rb") as client_key_file:
load_pem_private_key(client_key_file.read(), password=None) load_pem_private_key(client_key_file.read(), password=None)
except (TypeError, ValueError): except (TypeError, ValueError):
return "bad_client_key" return "client_key_error"
# Check the certificate chain # Check the certificate chain
context = SSLContext(PROTOCOL_TLS_CLIENT) context = SSLContext(PROTOCOL_TLS_CLIENT)
if client_certificate and private_key: if client_certificate and private_key:

View File

@ -26,6 +26,7 @@
"client_id": "Client ID (leave empty to randomly generated one)", "client_id": "Client ID (leave empty to randomly generated one)",
"client_cert": "Upload client certificate file", "client_cert": "Upload client certificate file",
"client_key": "Upload private key file", "client_key": "Upload private key file",
"client_key_password": "[%key:common::config_flow::data::password%]",
"keepalive": "The time between sending keep alive messages", "keepalive": "The time between sending keep alive messages",
"tls_insecure": "Ignore broker certificate validation", "tls_insecure": "Ignore broker certificate validation",
"protocol": "MQTT protocol", "protocol": "MQTT protocol",
@ -45,6 +46,7 @@
"client_id": "The unique ID to identify the Home Assistant MQTT API as MQTT client. It is recommended to leave this option blank.", "client_id": "The unique ID to identify the Home Assistant MQTT API as MQTT client. It is recommended to leave this option blank.",
"client_cert": "The client certificate to authenticate against your MQTT broker.", "client_cert": "The client certificate to authenticate against your MQTT broker.",
"client_key": "The private key file that belongs to your client certificate.", "client_key": "The private key file that belongs to your client certificate.",
"client_key_password": "The password for the private key file (if set).",
"keepalive": "A value less than 90 seconds is advised.", "keepalive": "A value less than 90 seconds is advised.",
"tls_insecure": "Option to ignore validation of your MQTT broker's certificate.", "tls_insecure": "Option to ignore validation of your MQTT broker's certificate.",
"protocol": "The MQTT protocol your broker operates at. For example 3.1.1.", "protocol": "The MQTT protocol your broker operates at. For example 3.1.1.",
@ -93,8 +95,8 @@
"bad_will": "Invalid will topic", "bad_will": "Invalid will topic",
"bad_discovery_prefix": "Invalid discovery prefix", "bad_discovery_prefix": "Invalid discovery prefix",
"bad_certificate": "The CA certificate is invalid", "bad_certificate": "The CA certificate is invalid",
"bad_client_cert": "Invalid client certificate, ensure a PEM coded file is supplied", "bad_client_cert": "Invalid client certificate, ensure a valid file is supplied",
"bad_client_key": "Invalid private key, ensure a PEM coded file is supplied without password", "client_key_error": "Invalid private key file or invalid password supplied",
"bad_client_cert_key": "Client certificate and private key are not a valid pair", "bad_client_cert_key": "Client certificate and private key are not a valid pair",
"bad_ws_headers": "Supply valid HTTP headers as a JSON object", "bad_ws_headers": "Supply valid HTTP headers as a JSON object",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
@ -207,7 +209,7 @@
"bad_discovery_prefix": "[%key:component::mqtt::config::error::bad_discovery_prefix%]", "bad_discovery_prefix": "[%key:component::mqtt::config::error::bad_discovery_prefix%]",
"bad_certificate": "[%key:component::mqtt::config::error::bad_certificate%]", "bad_certificate": "[%key:component::mqtt::config::error::bad_certificate%]",
"bad_client_cert": "[%key:component::mqtt::config::error::bad_client_cert%]", "bad_client_cert": "[%key:component::mqtt::config::error::bad_client_cert%]",
"bad_client_key": "[%key:component::mqtt::config::error::bad_client_key%]", "client_key_error": "[%key:component::mqtt::config::error::client_key_error%]",
"bad_client_cert_key": "[%key:component::mqtt::config::error::bad_client_cert_key%]", "bad_client_cert_key": "[%key:component::mqtt::config::error::bad_client_cert_key%]",
"bad_ws_headers": "[%key:component::mqtt::config::error::bad_ws_headers%]", "bad_ws_headers": "[%key:component::mqtt::config::error::bad_ws_headers%]",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",

View File

@ -40,8 +40,37 @@ ADD_ON_DISCOVERY_INFO = {
"protocol": "3.1.1", "protocol": "3.1.1",
"ssl": False, "ssl": False,
} }
MOCK_CLIENT_CERT = b"## mock client certificate file ##"
MOCK_CLIENT_KEY = b"## mock key file ##" MOCK_CA_CERT = (
b"-----BEGIN CERTIFICATE-----\n"
b"## mock CA certificate file ##"
b"\n-----END CERTIFICATE-----\n"
)
MOCK_GENERIC_CERT = (
b"-----BEGIN CERTIFICATE-----\n"
b"## mock generic certificate file ##"
b"\n-----END CERTIFICATE-----\n"
)
MOCK_CA_CERT_DER = b"## mock DER formatted CA certificate file ##\n"
MOCK_CLIENT_CERT = (
b"-----BEGIN CERTIFICATE-----\n"
b"## mock client certificate file ##"
b"\n-----END CERTIFICATE-----\n"
)
MOCK_CLIENT_CERT_DER = b"## mock DER formatted client certificate file ##\n"
MOCK_CLIENT_KEY = (
b"-----BEGIN PRIVATE KEY-----\n"
b"## mock client key file ##"
b"\n-----END PRIVATE KEY-----"
)
MOCK_ENCRYPTED_CLIENT_KEY = (
b"-----BEGIN ENCRYPTED PRIVATE KEY-----\n"
b"## mock client key file ##\n"
b"-----END ENCRYPTED PRIVATE KEY-----"
)
MOCK_CLIENT_KEY_DER = b"## mock DER formatted key file ##\n"
MOCK_ENCRYPTED_CLIENT_KEY_DER = b"## mock DER formatted encrypted key file ##\n"
MOCK_ENTRY_DATA = { MOCK_ENTRY_DATA = {
mqtt.CONF_BROKER: "test-broker", mqtt.CONF_BROKER: "test-broker",
@ -102,15 +131,27 @@ def mock_ssl_context() -> Generator[dict[str, MagicMock]]:
patch("homeassistant.components.mqtt.config_flow.SSLContext") as mock_context, patch("homeassistant.components.mqtt.config_flow.SSLContext") as mock_context,
patch( patch(
"homeassistant.components.mqtt.config_flow.load_pem_private_key" "homeassistant.components.mqtt.config_flow.load_pem_private_key"
) as mock_key_check, ) as mock_pem_key_check,
patch(
"homeassistant.components.mqtt.config_flow.load_der_private_key"
) as mock_der_key_check,
patch( patch(
"homeassistant.components.mqtt.config_flow.load_pem_x509_certificate" "homeassistant.components.mqtt.config_flow.load_pem_x509_certificate"
) as mock_cert_check, ) as mock_pem_cert_check,
patch(
"homeassistant.components.mqtt.config_flow.load_der_x509_certificate"
) as mock_der_cert_check,
): ):
mock_pem_key_check().private_bytes.return_value = MOCK_CLIENT_KEY
mock_pem_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT
mock_der_key_check().private_bytes.return_value = MOCK_CLIENT_KEY
mock_der_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT
yield { yield {
"context": mock_context, "context": mock_context,
"load_pem_x509_certificate": mock_cert_check, "load_der_private_key": mock_der_key_check,
"load_pem_private_key": mock_key_check, "load_der_x509_certificate": mock_der_cert_check,
"load_pem_private_key": mock_pem_key_check,
"load_pem_x509_certificate": mock_pem_cert_check,
} }
@ -180,9 +221,31 @@ def mock_try_connection_time_out() -> Generator[MagicMock]:
yield mock_client() yield mock_client()
@pytest.fixture
def mock_ca_cert() -> bytes:
"""Mock the CA certificate."""
return MOCK_CA_CERT
@pytest.fixture
def mock_client_cert() -> bytes:
"""Mock the client certificate."""
return MOCK_CLIENT_CERT
@pytest.fixture
def mock_client_key() -> bytes:
"""Mock the client key."""
return MOCK_CLIENT_KEY
@pytest.fixture @pytest.fixture
def mock_process_uploaded_file( def mock_process_uploaded_file(
tmp_path: Path, mock_temp_dir: str tmp_path: Path,
mock_ca_cert: bytes,
mock_client_cert: bytes,
mock_client_key: bytes,
mock_temp_dir: str,
) -> Generator[MagicMock]: ) -> Generator[MagicMock]:
"""Mock upload certificate files.""" """Mock upload certificate files."""
file_id_ca = str(uuid4()) file_id_ca = str(uuid4())
@ -195,15 +258,15 @@ def mock_process_uploaded_file(
) -> Iterator[Path | None]: ) -> Iterator[Path | None]:
if file_id == file_id_ca: if file_id == file_id_ca:
with open(tmp_path / "ca.crt", "wb") as cafile: with open(tmp_path / "ca.crt", "wb") as cafile:
cafile.write(b"## mock CA certificate file ##") cafile.write(mock_ca_cert)
yield tmp_path / "ca.crt" yield tmp_path / "ca.crt"
elif file_id == file_id_cert: elif file_id == file_id_cert:
with open(tmp_path / "client.crt", "wb") as certfile: with open(tmp_path / "client.crt", "wb") as certfile:
certfile.write(b"## mock client certificate file ##") certfile.write(mock_client_cert)
yield tmp_path / "client.crt" yield tmp_path / "client.crt"
elif file_id == file_id_key: elif file_id == file_id_key:
with open(tmp_path / "client.key", "wb") as keyfile: with open(tmp_path / "client.key", "wb") as keyfile:
keyfile.write(b"## mock key file ##") keyfile.write(mock_client_key)
yield tmp_path / "client.key" yield tmp_path / "client.key"
else: else:
pytest.fail(f"Unexpected file_id: {file_id}") pytest.fail(f"Unexpected file_id: {file_id}")
@ -1024,12 +1087,37 @@ async def test_option_flow(
assert yaml_mock.await_count assert yaml_mock.await_count
@pytest.mark.parametrize(
("mock_ca_cert", "mock_client_cert", "mock_client_key", "client_key_password"),
[
(MOCK_GENERIC_CERT, MOCK_GENERIC_CERT, MOCK_CLIENT_KEY, ""),
(
MOCK_GENERIC_CERT,
MOCK_GENERIC_CERT,
MOCK_ENCRYPTED_CLIENT_KEY,
"very*secret",
),
(MOCK_CA_CERT_DER, MOCK_CLIENT_CERT_DER, MOCK_CLIENT_KEY_DER, ""),
(
MOCK_CA_CERT_DER,
MOCK_CLIENT_CERT_DER,
MOCK_ENCRYPTED_CLIENT_KEY_DER,
"very*secret",
),
],
ids=[
"pem_certs_private_key_no_password",
"pem_certs_private_key_with_password",
"der_certs_private_key_no_password",
"der_certs_private_key_with_password",
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_error", "test_error",
[ [
"bad_certificate", "bad_certificate",
"bad_client_cert", "bad_client_cert",
"bad_client_key", "client_key_error",
"bad_client_cert_key", "bad_client_cert_key",
"invalid_inclusion", "invalid_inclusion",
None, None,
@ -1042,31 +1130,54 @@ async def test_bad_certificate(
mock_ssl_context: dict[str, MagicMock], mock_ssl_context: dict[str, MagicMock],
mock_process_uploaded_file: MagicMock, mock_process_uploaded_file: MagicMock,
test_error: str | None, test_error: str | None,
client_key_password: str,
mock_ca_cert: bytes,
) -> None: ) -> None:
"""Test bad certificate tests.""" """Test bad certificate tests."""
def _side_effect_on_client_cert(data: bytes) -> MagicMock:
"""Raise on client cert only.
The function is called twice, once for the CA chain
and once for the client cert. We only want to raise on a client cert.
"""
if data == MOCK_CLIENT_CERT_DER:
raise ValueError
mock_certificate_side_effect = MagicMock()
mock_certificate_side_effect().public_bytes.return_value = MOCK_GENERIC_CERT
return mock_certificate_side_effect
# Mock certificate files # Mock certificate files
file_id = mock_process_uploaded_file.file_id file_id = mock_process_uploaded_file.file_id
set_ca_cert = "custom"
set_client_cert = True
tls_insecure = False
test_input = { test_input = {
mqtt.CONF_BROKER: "another-broker", mqtt.CONF_BROKER: "another-broker",
CONF_PORT: 2345, CONF_PORT: 2345,
mqtt.CONF_CERTIFICATE: file_id[mqtt.CONF_CERTIFICATE], mqtt.CONF_CERTIFICATE: file_id[mqtt.CONF_CERTIFICATE],
mqtt.CONF_CLIENT_CERT: file_id[mqtt.CONF_CLIENT_CERT], mqtt.CONF_CLIENT_CERT: file_id[mqtt.CONF_CLIENT_CERT],
mqtt.CONF_CLIENT_KEY: file_id[mqtt.CONF_CLIENT_KEY], mqtt.CONF_CLIENT_KEY: file_id[mqtt.CONF_CLIENT_KEY],
"set_ca_cert": True, "client_key_password": client_key_password,
"set_ca_cert": set_ca_cert,
"set_client_cert": True, "set_client_cert": True,
} }
set_client_cert = True
set_ca_cert = "custom"
tls_insecure = False
if test_error == "bad_certificate": if test_error == "bad_certificate":
# CA chain is not loading # CA chain is not loading
mock_ssl_context["context"]().load_verify_locations.side_effect = SSLError mock_ssl_context["context"]().load_verify_locations.side_effect = SSLError
# Fail on the CA cert if DER encoded
mock_ssl_context["load_der_x509_certificate"].side_effect = ValueError
elif test_error == "bad_client_cert": elif test_error == "bad_client_cert":
# Client certificate is invalid # Client certificate is invalid
mock_ssl_context["load_pem_x509_certificate"].side_effect = ValueError mock_ssl_context["load_pem_x509_certificate"].side_effect = ValueError
elif test_error == "bad_client_key": # Fail on the client cert if DER encoded
mock_ssl_context[
"load_der_x509_certificate"
].side_effect = _side_effect_on_client_cert
elif test_error == "client_key_error":
# Client key file is invalid # Client key file is invalid
mock_ssl_context["load_pem_private_key"].side_effect = ValueError mock_ssl_context["load_pem_private_key"].side_effect = ValueError
mock_ssl_context["load_der_private_key"].side_effect = ValueError
elif test_error == "bad_client_cert_key": elif test_error == "bad_client_cert_key":
# Client key file file and certificate do not pair # Client key file file and certificate do not pair
mock_ssl_context["context"]().load_cert_chain.side_effect = SSLError mock_ssl_context["context"]().load_cert_chain.side_effect = SSLError
@ -2078,8 +2189,8 @@ async def test_setup_with_advanced_settings(
CONF_USERNAME: "user", CONF_USERNAME: "user",
CONF_PASSWORD: "secret", CONF_PASSWORD: "secret",
mqtt.CONF_KEEPALIVE: 30, mqtt.CONF_KEEPALIVE: 30,
mqtt.CONF_CLIENT_CERT: "## mock client certificate file ##", mqtt.CONF_CLIENT_CERT: MOCK_CLIENT_CERT.decode(encoding="utf-8"),
mqtt.CONF_CLIENT_KEY: "## mock key file ##", mqtt.CONF_CLIENT_KEY: MOCK_CLIENT_KEY.decode(encoding="utf-8"),
"tls_insecure": True, "tls_insecure": True,
mqtt.CONF_TRANSPORT: "websockets", mqtt.CONF_TRANSPORT: "websockets",
mqtt.CONF_WS_PATH: "/custom_path/", mqtt.CONF_WS_PATH: "/custom_path/",
@ -2091,6 +2202,155 @@ async def test_setup_with_advanced_settings(
} }
@pytest.mark.usefixtures("mock_ssl_context")
@pytest.mark.parametrize(
("mock_ca_cert", "mock_client_cert", "mock_client_key", "client_key_password"),
[
(MOCK_GENERIC_CERT, MOCK_GENERIC_CERT, MOCK_CLIENT_KEY, ""),
(
MOCK_GENERIC_CERT,
MOCK_GENERIC_CERT,
MOCK_ENCRYPTED_CLIENT_KEY,
"very*secret",
),
(MOCK_CA_CERT_DER, MOCK_CLIENT_CERT_DER, MOCK_CLIENT_KEY_DER, ""),
(
MOCK_CA_CERT_DER,
MOCK_CLIENT_CERT_DER,
MOCK_ENCRYPTED_CLIENT_KEY_DER,
"very*secret",
),
],
ids=[
"pem_certs_private_key_no_password",
"pem_certs_private_key_with_password",
"der_certs_private_key_no_password",
"der_certs_private_key_with_password",
],
)
async def test_setup_with_certificates(
hass: HomeAssistant,
mock_try_connection: MagicMock,
mock_process_uploaded_file: MagicMock,
client_key_password: str,
) -> None:
"""Test config flow setup with PEM and DER encoded certificates."""
file_id = mock_process_uploaded_file.file_id
config_entry = MockConfigEntry(
domain=mqtt.DOMAIN,
version=mqtt.CONFIG_ENTRY_VERSION,
minor_version=mqtt.CONFIG_ENTRY_MINOR_VERSION,
)
config_entry.add_to_hass(hass)
hass.config_entries.async_update_entry(
config_entry,
data={
mqtt.CONF_BROKER: "test-broker",
CONF_PORT: 1234,
},
)
mock_try_connection.return_value = True
result = await config_entry.start_reconfigure_flow(hass, show_advanced_options=True)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "broker"
assert result["data_schema"].schema["advanced_options"]
# first iteration, basic settings
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_BROKER: "test-broker",
CONF_PORT: 2345,
CONF_USERNAME: "user",
CONF_PASSWORD: "secret",
"advanced_options": True,
},
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "broker"
assert "advanced_options" not in result["data_schema"].schema
assert result["data_schema"].schema[CONF_CLIENT_ID]
assert result["data_schema"].schema[mqtt.CONF_KEEPALIVE]
assert result["data_schema"].schema["set_client_cert"]
assert result["data_schema"].schema["set_ca_cert"]
assert result["data_schema"].schema[mqtt.CONF_TLS_INSECURE]
assert result["data_schema"].schema[CONF_PROTOCOL]
assert result["data_schema"].schema[mqtt.CONF_TRANSPORT]
assert mqtt.CONF_CLIENT_CERT not in result["data_schema"].schema
assert mqtt.CONF_CLIENT_KEY not in result["data_schema"].schema
# second iteration, advanced settings with request for client cert
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_BROKER: "test-broker",
CONF_PORT: 2345,
CONF_USERNAME: "user",
CONF_PASSWORD: "secret",
mqtt.CONF_KEEPALIVE: 30,
"set_ca_cert": "custom",
"set_client_cert": True,
mqtt.CONF_TLS_INSECURE: False,
CONF_PROTOCOL: "3.1.1",
mqtt.CONF_TRANSPORT: "tcp",
},
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "broker"
assert "advanced_options" not in result["data_schema"].schema
assert result["data_schema"].schema[CONF_CLIENT_ID]
assert result["data_schema"].schema[mqtt.CONF_KEEPALIVE]
assert result["data_schema"].schema["set_client_cert"]
assert result["data_schema"].schema["set_ca_cert"]
assert result["data_schema"].schema["client_key_password"]
assert result["data_schema"].schema[mqtt.CONF_TLS_INSECURE]
assert result["data_schema"].schema[CONF_PROTOCOL]
assert result["data_schema"].schema[mqtt.CONF_CERTIFICATE]
assert result["data_schema"].schema[mqtt.CONF_CLIENT_CERT]
assert result["data_schema"].schema[mqtt.CONF_CLIENT_KEY]
assert result["data_schema"].schema[mqtt.CONF_TRANSPORT]
# third iteration, advanced settings with client cert and key and CA certificate
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input={
mqtt.CONF_BROKER: "test-broker",
CONF_PORT: 2345,
CONF_USERNAME: "user",
CONF_PASSWORD: "secret",
mqtt.CONF_KEEPALIVE: 30,
"set_ca_cert": "custom",
"set_client_cert": True,
"client_key_password": client_key_password,
mqtt.CONF_CERTIFICATE: file_id[mqtt.CONF_CERTIFICATE],
mqtt.CONF_CLIENT_CERT: file_id[mqtt.CONF_CLIENT_CERT],
mqtt.CONF_CLIENT_KEY: file_id[mqtt.CONF_CLIENT_KEY],
mqtt.CONF_TLS_INSECURE: False,
mqtt.CONF_TRANSPORT: "tcp",
},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "reconfigure_successful"
# Check config entry result
assert config_entry.data == {
mqtt.CONF_BROKER: "test-broker",
CONF_PORT: 2345,
CONF_USERNAME: "user",
CONF_PASSWORD: "secret",
mqtt.CONF_KEEPALIVE: 30,
mqtt.CONF_CLIENT_CERT: MOCK_GENERIC_CERT.decode(encoding="utf-8"),
mqtt.CONF_CLIENT_KEY: MOCK_CLIENT_KEY.decode(encoding="utf-8"),
"tls_insecure": False,
mqtt.CONF_TRANSPORT: "tcp",
mqtt.CONF_CERTIFICATE: MOCK_GENERIC_CERT.decode(encoding="utf-8"),
}
@pytest.mark.usefixtures("mock_ssl_context", "mock_process_uploaded_file") @pytest.mark.usefixtures("mock_ssl_context", "mock_process_uploaded_file")
async def test_change_websockets_transport_to_tcp( async def test_change_websockets_transport_to_tcp(
hass: HomeAssistant, mock_try_connection: MagicMock hass: HomeAssistant, mock_try_connection: MagicMock