mirror of
https://github.com/home-assistant/core.git
synced 2025-05-02 05:07:52 +00:00
Improve certificate handling in MQTT config flow (#137234)
* Improve mqtt broker certificate handling in config flow * Expand test cases
This commit is contained in:
parent
dd21d48ae4
commit
913a4ee9ba
@ -5,14 +5,21 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable, Mapping
|
||||||
|
from enum import IntEnum
|
||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
from ssl import PROTOCOL_TLS_CLIENT, SSLContext, SSLError
|
from ssl import PROTOCOL_TLS_CLIENT, SSLContext, SSLError
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
from cryptography.hazmat.primitives.serialization import (
|
||||||
from cryptography.x509 import load_pem_x509_certificate
|
Encoding,
|
||||||
|
NoEncryption,
|
||||||
|
PrivateFormat,
|
||||||
|
load_der_private_key,
|
||||||
|
load_pem_private_key,
|
||||||
|
)
|
||||||
|
from cryptography.x509 import load_der_x509_certificate, load_pem_x509_certificate
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.file_upload import process_uploaded_file
|
from homeassistant.components.file_upload import process_uploaded_file
|
||||||
@ -105,6 +112,8 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
ADDON_SETUP_TIMEOUT = 5
|
ADDON_SETUP_TIMEOUT = 5
|
||||||
ADDON_SETUP_TIMEOUT_ROUNDS = 5
|
ADDON_SETUP_TIMEOUT_ROUNDS = 5
|
||||||
|
|
||||||
|
CONF_CLIENT_KEY_PASSWORD = "client_key_password"
|
||||||
|
|
||||||
MQTT_TIMEOUT = 5
|
MQTT_TIMEOUT = 5
|
||||||
|
|
||||||
ADVANCED_OPTIONS = "advanced_options"
|
ADVANCED_OPTIONS = "advanced_options"
|
||||||
@ -165,12 +174,14 @@ BROKER_VERIFICATION_SELECTOR = SelectSelector(
|
|||||||
|
|
||||||
# mime configuration from https://pki-tutorial.readthedocs.io/en/latest/mime.html
|
# mime configuration from https://pki-tutorial.readthedocs.io/en/latest/mime.html
|
||||||
CA_CERT_UPLOAD_SELECTOR = FileSelector(
|
CA_CERT_UPLOAD_SELECTOR = FileSelector(
|
||||||
FileSelectorConfig(accept=".crt,application/x-x509-ca-cert")
|
FileSelectorConfig(accept=".pem,.crt,.cer,.der,application/x-x509-ca-cert")
|
||||||
)
|
)
|
||||||
CERT_UPLOAD_SELECTOR = FileSelector(
|
CERT_UPLOAD_SELECTOR = FileSelector(
|
||||||
FileSelectorConfig(accept=".crt,application/x-x509-user-cert")
|
FileSelectorConfig(accept=".pem,.crt,.cer,.der,application/x-x509-user-cert")
|
||||||
|
)
|
||||||
|
KEY_UPLOAD_SELECTOR = FileSelector(
|
||||||
|
FileSelectorConfig(accept=".pem,.key,.der,.pk8,application/pkcs8")
|
||||||
)
|
)
|
||||||
KEY_UPLOAD_SELECTOR = FileSelector(FileSelectorConfig(accept=".key,application/pkcs8"))
|
|
||||||
|
|
||||||
REAUTH_SCHEMA = vol.Schema(
|
REAUTH_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
@ -710,17 +721,88 @@ class MQTTOptionsFlowHandler(OptionsFlow):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _get_uploaded_file(hass: HomeAssistant, id: str) -> str:
|
@callback
|
||||||
"""Get file content from uploaded file."""
|
def async_is_pem_data(data: bytes) -> bool:
|
||||||
|
"""Return True if data is in PEM format."""
|
||||||
|
return (
|
||||||
|
b"-----BEGIN CERTIFICATE-----" in data
|
||||||
|
or b"-----BEGIN PRIVATE KEY-----" in data
|
||||||
|
or b"-----BEGIN RSA PRIVATE KEY-----" in data
|
||||||
|
or b"-----BEGIN ENCRYPTED PRIVATE KEY-----" in data
|
||||||
|
)
|
||||||
|
|
||||||
def _proces_uploaded_file() -> str:
|
|
||||||
|
class PEMType(IntEnum):
|
||||||
|
"""Type of PEM data."""
|
||||||
|
|
||||||
|
CERTIFICATE = 1
|
||||||
|
PRIVATE_KEY = 2
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_convert_to_pem(
|
||||||
|
data: bytes, pem_type: PEMType, password: str | None = None
|
||||||
|
) -> str | None:
|
||||||
|
"""Convert data to PEM format."""
|
||||||
|
try:
|
||||||
|
if async_is_pem_data(data):
|
||||||
|
if not password:
|
||||||
|
# Assume unencrypted PEM encoded private key
|
||||||
|
return data.decode(DEFAULT_ENCODING)
|
||||||
|
# Return decrypted PEM encoded private key
|
||||||
|
return (
|
||||||
|
load_pem_private_key(data, password=password.encode(DEFAULT_ENCODING))
|
||||||
|
.private_bytes(
|
||||||
|
encoding=Encoding.PEM,
|
||||||
|
format=PrivateFormat.TraditionalOpenSSL,
|
||||||
|
encryption_algorithm=NoEncryption(),
|
||||||
|
)
|
||||||
|
.decode(DEFAULT_ENCODING)
|
||||||
|
)
|
||||||
|
# Convert from DER encoding to PEM
|
||||||
|
if pem_type == PEMType.CERTIFICATE:
|
||||||
|
return (
|
||||||
|
load_der_x509_certificate(data)
|
||||||
|
.public_bytes(
|
||||||
|
encoding=Encoding.PEM,
|
||||||
|
)
|
||||||
|
.decode(DEFAULT_ENCODING)
|
||||||
|
)
|
||||||
|
# Assume DER encoded private key
|
||||||
|
pem_key_data: bytes = load_der_private_key(
|
||||||
|
data, password.encode(DEFAULT_ENCODING) if password else None
|
||||||
|
).private_bytes(
|
||||||
|
encoding=Encoding.PEM,
|
||||||
|
format=PrivateFormat.TraditionalOpenSSL,
|
||||||
|
encryption_algorithm=NoEncryption(),
|
||||||
|
)
|
||||||
|
return pem_key_data.decode("utf-8")
|
||||||
|
except (TypeError, ValueError, SSLError):
|
||||||
|
_LOGGER.exception("Error converting %s file data to PEM format", pem_type.name)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_uploaded_file(hass: HomeAssistant, id: str) -> bytes:
|
||||||
|
"""Get file content from uploaded certificate or key file."""
|
||||||
|
|
||||||
|
def _proces_uploaded_file() -> bytes:
|
||||||
with process_uploaded_file(hass, id) as file_path:
|
with process_uploaded_file(hass, id) as file_path:
|
||||||
return file_path.read_text(encoding=DEFAULT_ENCODING)
|
return file_path.read_bytes()
|
||||||
|
|
||||||
return await hass.async_add_executor_job(_proces_uploaded_file)
|
return await hass.async_add_executor_job(_proces_uploaded_file)
|
||||||
|
|
||||||
|
|
||||||
async def async_get_broker_settings(
|
def _validate_pki_file(
|
||||||
|
file_id: str | None, pem_data: str | None, errors: dict[str, str], error: str
|
||||||
|
) -> bool:
|
||||||
|
"""Return False if uploaded file could not be converted to PEM format."""
|
||||||
|
if file_id and not pem_data:
|
||||||
|
errors["base"] = error
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_broker_settings( # noqa: C901
|
||||||
flow: ConfigFlow | OptionsFlow,
|
flow: ConfigFlow | OptionsFlow,
|
||||||
fields: OrderedDict[Any, Any],
|
fields: OrderedDict[Any, Any],
|
||||||
entry_config: MappingProxyType[str, Any] | None,
|
entry_config: MappingProxyType[str, Any] | None,
|
||||||
@ -768,6 +850,10 @@ async def async_get_broker_settings(
|
|||||||
validated_user_input.update(user_input)
|
validated_user_input.update(user_input)
|
||||||
client_certificate_id: str | None = user_input.get(CONF_CLIENT_CERT)
|
client_certificate_id: str | None = user_input.get(CONF_CLIENT_CERT)
|
||||||
client_key_id: str | None = user_input.get(CONF_CLIENT_KEY)
|
client_key_id: str | None = user_input.get(CONF_CLIENT_KEY)
|
||||||
|
# We do not store the private key password in the entry data
|
||||||
|
client_key_password: str | None = validated_user_input.pop(
|
||||||
|
CONF_CLIENT_KEY_PASSWORD, None
|
||||||
|
)
|
||||||
if (client_certificate_id and not client_key_id) or (
|
if (client_certificate_id and not client_key_id) or (
|
||||||
not client_certificate_id and client_key_id
|
not client_certificate_id and client_key_id
|
||||||
):
|
):
|
||||||
@ -775,7 +861,14 @@ async def async_get_broker_settings(
|
|||||||
return False
|
return False
|
||||||
certificate_id: str | None = user_input.get(CONF_CERTIFICATE)
|
certificate_id: str | None = user_input.get(CONF_CERTIFICATE)
|
||||||
if certificate_id:
|
if certificate_id:
|
||||||
certificate = await _get_uploaded_file(hass, certificate_id)
|
certificate_data_raw = await _get_uploaded_file(hass, certificate_id)
|
||||||
|
certificate = async_convert_to_pem(
|
||||||
|
certificate_data_raw, PEMType.CERTIFICATE
|
||||||
|
)
|
||||||
|
if not _validate_pki_file(
|
||||||
|
certificate_id, certificate, errors, "bad_certificate"
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
# Return to form for file upload CA cert or client cert and key
|
# Return to form for file upload CA cert or client cert and key
|
||||||
if (
|
if (
|
||||||
@ -797,9 +890,26 @@ async def async_get_broker_settings(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
if client_certificate_id:
|
if client_certificate_id:
|
||||||
client_certificate = await _get_uploaded_file(hass, client_certificate_id)
|
client_certificate_data = await _get_uploaded_file(
|
||||||
|
hass, client_certificate_id
|
||||||
|
)
|
||||||
|
client_certificate = async_convert_to_pem(
|
||||||
|
client_certificate_data, PEMType.CERTIFICATE
|
||||||
|
)
|
||||||
|
if not _validate_pki_file(
|
||||||
|
client_certificate_id, client_certificate, errors, "bad_client_cert"
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
if client_key_id:
|
if client_key_id:
|
||||||
client_key = await _get_uploaded_file(hass, client_key_id)
|
client_key_data = await _get_uploaded_file(hass, client_key_id)
|
||||||
|
client_key = async_convert_to_pem(
|
||||||
|
client_key_data, PEMType.PRIVATE_KEY, password=client_key_password
|
||||||
|
)
|
||||||
|
if not _validate_pki_file(
|
||||||
|
client_key_id, client_key, errors, "client_key_error"
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
certificate_data: dict[str, Any] = {}
|
certificate_data: dict[str, Any] = {}
|
||||||
if certificate:
|
if certificate:
|
||||||
@ -956,6 +1066,14 @@ async def async_get_broker_settings(
|
|||||||
description={"suggested_value": user_input_basic.get(CONF_CLIENT_KEY)},
|
description={"suggested_value": user_input_basic.get(CONF_CLIENT_KEY)},
|
||||||
)
|
)
|
||||||
] = KEY_UPLOAD_SELECTOR
|
] = KEY_UPLOAD_SELECTOR
|
||||||
|
fields[
|
||||||
|
vol.Optional(
|
||||||
|
CONF_CLIENT_KEY_PASSWORD,
|
||||||
|
description={
|
||||||
|
"suggested_value": user_input_basic.get(CONF_CLIENT_KEY_PASSWORD)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
] = PASSWORD_SELECTOR
|
||||||
verification_mode = current_config.get(SET_CA_CERT) or (
|
verification_mode = current_config.get(SET_CA_CERT) or (
|
||||||
"off"
|
"off"
|
||||||
if current_ca_certificate is None
|
if current_ca_certificate is None
|
||||||
@ -1060,7 +1178,7 @@ def check_certicate_chain() -> str | None:
|
|||||||
with open(private_key, "rb") as client_key_file:
|
with open(private_key, "rb") as client_key_file:
|
||||||
load_pem_private_key(client_key_file.read(), password=None)
|
load_pem_private_key(client_key_file.read(), password=None)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return "bad_client_key"
|
return "client_key_error"
|
||||||
# Check the certificate chain
|
# Check the certificate chain
|
||||||
context = SSLContext(PROTOCOL_TLS_CLIENT)
|
context = SSLContext(PROTOCOL_TLS_CLIENT)
|
||||||
if client_certificate and private_key:
|
if client_certificate and private_key:
|
||||||
|
@ -26,6 +26,7 @@
|
|||||||
"client_id": "Client ID (leave empty to randomly generated one)",
|
"client_id": "Client ID (leave empty to randomly generated one)",
|
||||||
"client_cert": "Upload client certificate file",
|
"client_cert": "Upload client certificate file",
|
||||||
"client_key": "Upload private key file",
|
"client_key": "Upload private key file",
|
||||||
|
"client_key_password": "[%key:common::config_flow::data::password%]",
|
||||||
"keepalive": "The time between sending keep alive messages",
|
"keepalive": "The time between sending keep alive messages",
|
||||||
"tls_insecure": "Ignore broker certificate validation",
|
"tls_insecure": "Ignore broker certificate validation",
|
||||||
"protocol": "MQTT protocol",
|
"protocol": "MQTT protocol",
|
||||||
@ -45,6 +46,7 @@
|
|||||||
"client_id": "The unique ID to identify the Home Assistant MQTT API as MQTT client. It is recommended to leave this option blank.",
|
"client_id": "The unique ID to identify the Home Assistant MQTT API as MQTT client. It is recommended to leave this option blank.",
|
||||||
"client_cert": "The client certificate to authenticate against your MQTT broker.",
|
"client_cert": "The client certificate to authenticate against your MQTT broker.",
|
||||||
"client_key": "The private key file that belongs to your client certificate.",
|
"client_key": "The private key file that belongs to your client certificate.",
|
||||||
|
"client_key_password": "The password for the private key file (if set).",
|
||||||
"keepalive": "A value less than 90 seconds is advised.",
|
"keepalive": "A value less than 90 seconds is advised.",
|
||||||
"tls_insecure": "Option to ignore validation of your MQTT broker's certificate.",
|
"tls_insecure": "Option to ignore validation of your MQTT broker's certificate.",
|
||||||
"protocol": "The MQTT protocol your broker operates at. For example 3.1.1.",
|
"protocol": "The MQTT protocol your broker operates at. For example 3.1.1.",
|
||||||
@ -93,8 +95,8 @@
|
|||||||
"bad_will": "Invalid will topic",
|
"bad_will": "Invalid will topic",
|
||||||
"bad_discovery_prefix": "Invalid discovery prefix",
|
"bad_discovery_prefix": "Invalid discovery prefix",
|
||||||
"bad_certificate": "The CA certificate is invalid",
|
"bad_certificate": "The CA certificate is invalid",
|
||||||
"bad_client_cert": "Invalid client certificate, ensure a PEM coded file is supplied",
|
"bad_client_cert": "Invalid client certificate, ensure a valid file is supplied",
|
||||||
"bad_client_key": "Invalid private key, ensure a PEM coded file is supplied without password",
|
"client_key_error": "Invalid private key file or invalid password supplied",
|
||||||
"bad_client_cert_key": "Client certificate and private key are not a valid pair",
|
"bad_client_cert_key": "Client certificate and private key are not a valid pair",
|
||||||
"bad_ws_headers": "Supply valid HTTP headers as a JSON object",
|
"bad_ws_headers": "Supply valid HTTP headers as a JSON object",
|
||||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||||
@ -207,7 +209,7 @@
|
|||||||
"bad_discovery_prefix": "[%key:component::mqtt::config::error::bad_discovery_prefix%]",
|
"bad_discovery_prefix": "[%key:component::mqtt::config::error::bad_discovery_prefix%]",
|
||||||
"bad_certificate": "[%key:component::mqtt::config::error::bad_certificate%]",
|
"bad_certificate": "[%key:component::mqtt::config::error::bad_certificate%]",
|
||||||
"bad_client_cert": "[%key:component::mqtt::config::error::bad_client_cert%]",
|
"bad_client_cert": "[%key:component::mqtt::config::error::bad_client_cert%]",
|
||||||
"bad_client_key": "[%key:component::mqtt::config::error::bad_client_key%]",
|
"client_key_error": "[%key:component::mqtt::config::error::client_key_error%]",
|
||||||
"bad_client_cert_key": "[%key:component::mqtt::config::error::bad_client_cert_key%]",
|
"bad_client_cert_key": "[%key:component::mqtt::config::error::bad_client_cert_key%]",
|
||||||
"bad_ws_headers": "[%key:component::mqtt::config::error::bad_ws_headers%]",
|
"bad_ws_headers": "[%key:component::mqtt::config::error::bad_ws_headers%]",
|
||||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||||
|
@ -40,8 +40,37 @@ ADD_ON_DISCOVERY_INFO = {
|
|||||||
"protocol": "3.1.1",
|
"protocol": "3.1.1",
|
||||||
"ssl": False,
|
"ssl": False,
|
||||||
}
|
}
|
||||||
MOCK_CLIENT_CERT = b"## mock client certificate file ##"
|
|
||||||
MOCK_CLIENT_KEY = b"## mock key file ##"
|
MOCK_CA_CERT = (
|
||||||
|
b"-----BEGIN CERTIFICATE-----\n"
|
||||||
|
b"## mock CA certificate file ##"
|
||||||
|
b"\n-----END CERTIFICATE-----\n"
|
||||||
|
)
|
||||||
|
MOCK_GENERIC_CERT = (
|
||||||
|
b"-----BEGIN CERTIFICATE-----\n"
|
||||||
|
b"## mock generic certificate file ##"
|
||||||
|
b"\n-----END CERTIFICATE-----\n"
|
||||||
|
)
|
||||||
|
MOCK_CA_CERT_DER = b"## mock DER formatted CA certificate file ##\n"
|
||||||
|
MOCK_CLIENT_CERT = (
|
||||||
|
b"-----BEGIN CERTIFICATE-----\n"
|
||||||
|
b"## mock client certificate file ##"
|
||||||
|
b"\n-----END CERTIFICATE-----\n"
|
||||||
|
)
|
||||||
|
MOCK_CLIENT_CERT_DER = b"## mock DER formatted client certificate file ##\n"
|
||||||
|
MOCK_CLIENT_KEY = (
|
||||||
|
b"-----BEGIN PRIVATE KEY-----\n"
|
||||||
|
b"## mock client key file ##"
|
||||||
|
b"\n-----END PRIVATE KEY-----"
|
||||||
|
)
|
||||||
|
MOCK_ENCRYPTED_CLIENT_KEY = (
|
||||||
|
b"-----BEGIN ENCRYPTED PRIVATE KEY-----\n"
|
||||||
|
b"## mock client key file ##\n"
|
||||||
|
b"-----END ENCRYPTED PRIVATE KEY-----"
|
||||||
|
)
|
||||||
|
MOCK_CLIENT_KEY_DER = b"## mock DER formatted key file ##\n"
|
||||||
|
MOCK_ENCRYPTED_CLIENT_KEY_DER = b"## mock DER formatted encrypted key file ##\n"
|
||||||
|
|
||||||
|
|
||||||
MOCK_ENTRY_DATA = {
|
MOCK_ENTRY_DATA = {
|
||||||
mqtt.CONF_BROKER: "test-broker",
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
@ -102,15 +131,27 @@ def mock_ssl_context() -> Generator[dict[str, MagicMock]]:
|
|||||||
patch("homeassistant.components.mqtt.config_flow.SSLContext") as mock_context,
|
patch("homeassistant.components.mqtt.config_flow.SSLContext") as mock_context,
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.mqtt.config_flow.load_pem_private_key"
|
"homeassistant.components.mqtt.config_flow.load_pem_private_key"
|
||||||
) as mock_key_check,
|
) as mock_pem_key_check,
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.mqtt.config_flow.load_der_private_key"
|
||||||
|
) as mock_der_key_check,
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.mqtt.config_flow.load_pem_x509_certificate"
|
"homeassistant.components.mqtt.config_flow.load_pem_x509_certificate"
|
||||||
) as mock_cert_check,
|
) as mock_pem_cert_check,
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.mqtt.config_flow.load_der_x509_certificate"
|
||||||
|
) as mock_der_cert_check,
|
||||||
):
|
):
|
||||||
|
mock_pem_key_check().private_bytes.return_value = MOCK_CLIENT_KEY
|
||||||
|
mock_pem_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT
|
||||||
|
mock_der_key_check().private_bytes.return_value = MOCK_CLIENT_KEY
|
||||||
|
mock_der_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT
|
||||||
yield {
|
yield {
|
||||||
"context": mock_context,
|
"context": mock_context,
|
||||||
"load_pem_x509_certificate": mock_cert_check,
|
"load_der_private_key": mock_der_key_check,
|
||||||
"load_pem_private_key": mock_key_check,
|
"load_der_x509_certificate": mock_der_cert_check,
|
||||||
|
"load_pem_private_key": mock_pem_key_check,
|
||||||
|
"load_pem_x509_certificate": mock_pem_cert_check,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -180,9 +221,31 @@ def mock_try_connection_time_out() -> Generator[MagicMock]:
|
|||||||
yield mock_client()
|
yield mock_client()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ca_cert() -> bytes:
|
||||||
|
"""Mock the CA certificate."""
|
||||||
|
return MOCK_CA_CERT
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_client_cert() -> bytes:
|
||||||
|
"""Mock the client certificate."""
|
||||||
|
return MOCK_CLIENT_CERT
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_client_key() -> bytes:
|
||||||
|
"""Mock the client key."""
|
||||||
|
return MOCK_CLIENT_KEY
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_process_uploaded_file(
|
def mock_process_uploaded_file(
|
||||||
tmp_path: Path, mock_temp_dir: str
|
tmp_path: Path,
|
||||||
|
mock_ca_cert: bytes,
|
||||||
|
mock_client_cert: bytes,
|
||||||
|
mock_client_key: bytes,
|
||||||
|
mock_temp_dir: str,
|
||||||
) -> Generator[MagicMock]:
|
) -> Generator[MagicMock]:
|
||||||
"""Mock upload certificate files."""
|
"""Mock upload certificate files."""
|
||||||
file_id_ca = str(uuid4())
|
file_id_ca = str(uuid4())
|
||||||
@ -195,15 +258,15 @@ def mock_process_uploaded_file(
|
|||||||
) -> Iterator[Path | None]:
|
) -> Iterator[Path | None]:
|
||||||
if file_id == file_id_ca:
|
if file_id == file_id_ca:
|
||||||
with open(tmp_path / "ca.crt", "wb") as cafile:
|
with open(tmp_path / "ca.crt", "wb") as cafile:
|
||||||
cafile.write(b"## mock CA certificate file ##")
|
cafile.write(mock_ca_cert)
|
||||||
yield tmp_path / "ca.crt"
|
yield tmp_path / "ca.crt"
|
||||||
elif file_id == file_id_cert:
|
elif file_id == file_id_cert:
|
||||||
with open(tmp_path / "client.crt", "wb") as certfile:
|
with open(tmp_path / "client.crt", "wb") as certfile:
|
||||||
certfile.write(b"## mock client certificate file ##")
|
certfile.write(mock_client_cert)
|
||||||
yield tmp_path / "client.crt"
|
yield tmp_path / "client.crt"
|
||||||
elif file_id == file_id_key:
|
elif file_id == file_id_key:
|
||||||
with open(tmp_path / "client.key", "wb") as keyfile:
|
with open(tmp_path / "client.key", "wb") as keyfile:
|
||||||
keyfile.write(b"## mock key file ##")
|
keyfile.write(mock_client_key)
|
||||||
yield tmp_path / "client.key"
|
yield tmp_path / "client.key"
|
||||||
else:
|
else:
|
||||||
pytest.fail(f"Unexpected file_id: {file_id}")
|
pytest.fail(f"Unexpected file_id: {file_id}")
|
||||||
@ -1024,12 +1087,37 @@ async def test_option_flow(
|
|||||||
assert yaml_mock.await_count
|
assert yaml_mock.await_count
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("mock_ca_cert", "mock_client_cert", "mock_client_key", "client_key_password"),
|
||||||
|
[
|
||||||
|
(MOCK_GENERIC_CERT, MOCK_GENERIC_CERT, MOCK_CLIENT_KEY, ""),
|
||||||
|
(
|
||||||
|
MOCK_GENERIC_CERT,
|
||||||
|
MOCK_GENERIC_CERT,
|
||||||
|
MOCK_ENCRYPTED_CLIENT_KEY,
|
||||||
|
"very*secret",
|
||||||
|
),
|
||||||
|
(MOCK_CA_CERT_DER, MOCK_CLIENT_CERT_DER, MOCK_CLIENT_KEY_DER, ""),
|
||||||
|
(
|
||||||
|
MOCK_CA_CERT_DER,
|
||||||
|
MOCK_CLIENT_CERT_DER,
|
||||||
|
MOCK_ENCRYPTED_CLIENT_KEY_DER,
|
||||||
|
"very*secret",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"pem_certs_private_key_no_password",
|
||||||
|
"pem_certs_private_key_with_password",
|
||||||
|
"der_certs_private_key_no_password",
|
||||||
|
"der_certs_private_key_with_password",
|
||||||
|
],
|
||||||
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_error",
|
"test_error",
|
||||||
[
|
[
|
||||||
"bad_certificate",
|
"bad_certificate",
|
||||||
"bad_client_cert",
|
"bad_client_cert",
|
||||||
"bad_client_key",
|
"client_key_error",
|
||||||
"bad_client_cert_key",
|
"bad_client_cert_key",
|
||||||
"invalid_inclusion",
|
"invalid_inclusion",
|
||||||
None,
|
None,
|
||||||
@ -1042,31 +1130,54 @@ async def test_bad_certificate(
|
|||||||
mock_ssl_context: dict[str, MagicMock],
|
mock_ssl_context: dict[str, MagicMock],
|
||||||
mock_process_uploaded_file: MagicMock,
|
mock_process_uploaded_file: MagicMock,
|
||||||
test_error: str | None,
|
test_error: str | None,
|
||||||
|
client_key_password: str,
|
||||||
|
mock_ca_cert: bytes,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test bad certificate tests."""
|
"""Test bad certificate tests."""
|
||||||
|
|
||||||
|
def _side_effect_on_client_cert(data: bytes) -> MagicMock:
|
||||||
|
"""Raise on client cert only.
|
||||||
|
|
||||||
|
The function is called twice, once for the CA chain
|
||||||
|
and once for the client cert. We only want to raise on a client cert.
|
||||||
|
"""
|
||||||
|
if data == MOCK_CLIENT_CERT_DER:
|
||||||
|
raise ValueError
|
||||||
|
mock_certificate_side_effect = MagicMock()
|
||||||
|
mock_certificate_side_effect().public_bytes.return_value = MOCK_GENERIC_CERT
|
||||||
|
return mock_certificate_side_effect
|
||||||
|
|
||||||
# Mock certificate files
|
# Mock certificate files
|
||||||
file_id = mock_process_uploaded_file.file_id
|
file_id = mock_process_uploaded_file.file_id
|
||||||
|
set_ca_cert = "custom"
|
||||||
|
set_client_cert = True
|
||||||
|
tls_insecure = False
|
||||||
test_input = {
|
test_input = {
|
||||||
mqtt.CONF_BROKER: "another-broker",
|
mqtt.CONF_BROKER: "another-broker",
|
||||||
CONF_PORT: 2345,
|
CONF_PORT: 2345,
|
||||||
mqtt.CONF_CERTIFICATE: file_id[mqtt.CONF_CERTIFICATE],
|
mqtt.CONF_CERTIFICATE: file_id[mqtt.CONF_CERTIFICATE],
|
||||||
mqtt.CONF_CLIENT_CERT: file_id[mqtt.CONF_CLIENT_CERT],
|
mqtt.CONF_CLIENT_CERT: file_id[mqtt.CONF_CLIENT_CERT],
|
||||||
mqtt.CONF_CLIENT_KEY: file_id[mqtt.CONF_CLIENT_KEY],
|
mqtt.CONF_CLIENT_KEY: file_id[mqtt.CONF_CLIENT_KEY],
|
||||||
"set_ca_cert": True,
|
"client_key_password": client_key_password,
|
||||||
|
"set_ca_cert": set_ca_cert,
|
||||||
"set_client_cert": True,
|
"set_client_cert": True,
|
||||||
}
|
}
|
||||||
set_client_cert = True
|
|
||||||
set_ca_cert = "custom"
|
|
||||||
tls_insecure = False
|
|
||||||
if test_error == "bad_certificate":
|
if test_error == "bad_certificate":
|
||||||
# CA chain is not loading
|
# CA chain is not loading
|
||||||
mock_ssl_context["context"]().load_verify_locations.side_effect = SSLError
|
mock_ssl_context["context"]().load_verify_locations.side_effect = SSLError
|
||||||
|
# Fail on the CA cert if DER encoded
|
||||||
|
mock_ssl_context["load_der_x509_certificate"].side_effect = ValueError
|
||||||
elif test_error == "bad_client_cert":
|
elif test_error == "bad_client_cert":
|
||||||
# Client certificate is invalid
|
# Client certificate is invalid
|
||||||
mock_ssl_context["load_pem_x509_certificate"].side_effect = ValueError
|
mock_ssl_context["load_pem_x509_certificate"].side_effect = ValueError
|
||||||
elif test_error == "bad_client_key":
|
# Fail on the client cert if DER encoded
|
||||||
|
mock_ssl_context[
|
||||||
|
"load_der_x509_certificate"
|
||||||
|
].side_effect = _side_effect_on_client_cert
|
||||||
|
elif test_error == "client_key_error":
|
||||||
# Client key file is invalid
|
# Client key file is invalid
|
||||||
mock_ssl_context["load_pem_private_key"].side_effect = ValueError
|
mock_ssl_context["load_pem_private_key"].side_effect = ValueError
|
||||||
|
mock_ssl_context["load_der_private_key"].side_effect = ValueError
|
||||||
elif test_error == "bad_client_cert_key":
|
elif test_error == "bad_client_cert_key":
|
||||||
# Client key file file and certificate do not pair
|
# Client key file file and certificate do not pair
|
||||||
mock_ssl_context["context"]().load_cert_chain.side_effect = SSLError
|
mock_ssl_context["context"]().load_cert_chain.side_effect = SSLError
|
||||||
@ -2078,8 +2189,8 @@ async def test_setup_with_advanced_settings(
|
|||||||
CONF_USERNAME: "user",
|
CONF_USERNAME: "user",
|
||||||
CONF_PASSWORD: "secret",
|
CONF_PASSWORD: "secret",
|
||||||
mqtt.CONF_KEEPALIVE: 30,
|
mqtt.CONF_KEEPALIVE: 30,
|
||||||
mqtt.CONF_CLIENT_CERT: "## mock client certificate file ##",
|
mqtt.CONF_CLIENT_CERT: MOCK_CLIENT_CERT.decode(encoding="utf-8"),
|
||||||
mqtt.CONF_CLIENT_KEY: "## mock key file ##",
|
mqtt.CONF_CLIENT_KEY: MOCK_CLIENT_KEY.decode(encoding="utf-8"),
|
||||||
"tls_insecure": True,
|
"tls_insecure": True,
|
||||||
mqtt.CONF_TRANSPORT: "websockets",
|
mqtt.CONF_TRANSPORT: "websockets",
|
||||||
mqtt.CONF_WS_PATH: "/custom_path/",
|
mqtt.CONF_WS_PATH: "/custom_path/",
|
||||||
@ -2091,6 +2202,155 @@ async def test_setup_with_advanced_settings(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_ssl_context")
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("mock_ca_cert", "mock_client_cert", "mock_client_key", "client_key_password"),
|
||||||
|
[
|
||||||
|
(MOCK_GENERIC_CERT, MOCK_GENERIC_CERT, MOCK_CLIENT_KEY, ""),
|
||||||
|
(
|
||||||
|
MOCK_GENERIC_CERT,
|
||||||
|
MOCK_GENERIC_CERT,
|
||||||
|
MOCK_ENCRYPTED_CLIENT_KEY,
|
||||||
|
"very*secret",
|
||||||
|
),
|
||||||
|
(MOCK_CA_CERT_DER, MOCK_CLIENT_CERT_DER, MOCK_CLIENT_KEY_DER, ""),
|
||||||
|
(
|
||||||
|
MOCK_CA_CERT_DER,
|
||||||
|
MOCK_CLIENT_CERT_DER,
|
||||||
|
MOCK_ENCRYPTED_CLIENT_KEY_DER,
|
||||||
|
"very*secret",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"pem_certs_private_key_no_password",
|
||||||
|
"pem_certs_private_key_with_password",
|
||||||
|
"der_certs_private_key_no_password",
|
||||||
|
"der_certs_private_key_with_password",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_setup_with_certificates(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_try_connection: MagicMock,
|
||||||
|
mock_process_uploaded_file: MagicMock,
|
||||||
|
client_key_password: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test config flow setup with PEM and DER encoded certificates."""
|
||||||
|
file_id = mock_process_uploaded_file.file_id
|
||||||
|
|
||||||
|
config_entry = MockConfigEntry(
|
||||||
|
domain=mqtt.DOMAIN,
|
||||||
|
version=mqtt.CONFIG_ENTRY_VERSION,
|
||||||
|
minor_version=mqtt.CONFIG_ENTRY_MINOR_VERSION,
|
||||||
|
)
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
config_entry,
|
||||||
|
data={
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
CONF_PORT: 1234,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_try_connection.return_value = True
|
||||||
|
|
||||||
|
result = await config_entry.start_reconfigure_flow(hass, show_advanced_options=True)
|
||||||
|
assert result["type"] is FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "broker"
|
||||||
|
assert result["data_schema"].schema["advanced_options"]
|
||||||
|
|
||||||
|
# first iteration, basic settings
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
CONF_PORT: 2345,
|
||||||
|
CONF_USERNAME: "user",
|
||||||
|
CONF_PASSWORD: "secret",
|
||||||
|
"advanced_options": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert result["type"] is FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "broker"
|
||||||
|
assert "advanced_options" not in result["data_schema"].schema
|
||||||
|
assert result["data_schema"].schema[CONF_CLIENT_ID]
|
||||||
|
assert result["data_schema"].schema[mqtt.CONF_KEEPALIVE]
|
||||||
|
assert result["data_schema"].schema["set_client_cert"]
|
||||||
|
assert result["data_schema"].schema["set_ca_cert"]
|
||||||
|
assert result["data_schema"].schema[mqtt.CONF_TLS_INSECURE]
|
||||||
|
assert result["data_schema"].schema[CONF_PROTOCOL]
|
||||||
|
assert result["data_schema"].schema[mqtt.CONF_TRANSPORT]
|
||||||
|
assert mqtt.CONF_CLIENT_CERT not in result["data_schema"].schema
|
||||||
|
assert mqtt.CONF_CLIENT_KEY not in result["data_schema"].schema
|
||||||
|
|
||||||
|
# second iteration, advanced settings with request for client cert
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
CONF_PORT: 2345,
|
||||||
|
CONF_USERNAME: "user",
|
||||||
|
CONF_PASSWORD: "secret",
|
||||||
|
mqtt.CONF_KEEPALIVE: 30,
|
||||||
|
"set_ca_cert": "custom",
|
||||||
|
"set_client_cert": True,
|
||||||
|
mqtt.CONF_TLS_INSECURE: False,
|
||||||
|
CONF_PROTOCOL: "3.1.1",
|
||||||
|
mqtt.CONF_TRANSPORT: "tcp",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert result["type"] is FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "broker"
|
||||||
|
assert "advanced_options" not in result["data_schema"].schema
|
||||||
|
assert result["data_schema"].schema[CONF_CLIENT_ID]
|
||||||
|
assert result["data_schema"].schema[mqtt.CONF_KEEPALIVE]
|
||||||
|
assert result["data_schema"].schema["set_client_cert"]
|
||||||
|
assert result["data_schema"].schema["set_ca_cert"]
|
||||||
|
assert result["data_schema"].schema["client_key_password"]
|
||||||
|
assert result["data_schema"].schema[mqtt.CONF_TLS_INSECURE]
|
||||||
|
assert result["data_schema"].schema[CONF_PROTOCOL]
|
||||||
|
assert result["data_schema"].schema[mqtt.CONF_CERTIFICATE]
|
||||||
|
assert result["data_schema"].schema[mqtt.CONF_CLIENT_CERT]
|
||||||
|
assert result["data_schema"].schema[mqtt.CONF_CLIENT_KEY]
|
||||||
|
assert result["data_schema"].schema[mqtt.CONF_TRANSPORT]
|
||||||
|
|
||||||
|
# third iteration, advanced settings with client cert and key and CA certificate
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
CONF_PORT: 2345,
|
||||||
|
CONF_USERNAME: "user",
|
||||||
|
CONF_PASSWORD: "secret",
|
||||||
|
mqtt.CONF_KEEPALIVE: 30,
|
||||||
|
"set_ca_cert": "custom",
|
||||||
|
"set_client_cert": True,
|
||||||
|
"client_key_password": client_key_password,
|
||||||
|
mqtt.CONF_CERTIFICATE: file_id[mqtt.CONF_CERTIFICATE],
|
||||||
|
mqtt.CONF_CLIENT_CERT: file_id[mqtt.CONF_CLIENT_CERT],
|
||||||
|
mqtt.CONF_CLIENT_KEY: file_id[mqtt.CONF_CLIENT_KEY],
|
||||||
|
mqtt.CONF_TLS_INSECURE: False,
|
||||||
|
mqtt.CONF_TRANSPORT: "tcp",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "reconfigure_successful"
|
||||||
|
|
||||||
|
# Check config entry result
|
||||||
|
assert config_entry.data == {
|
||||||
|
mqtt.CONF_BROKER: "test-broker",
|
||||||
|
CONF_PORT: 2345,
|
||||||
|
CONF_USERNAME: "user",
|
||||||
|
CONF_PASSWORD: "secret",
|
||||||
|
mqtt.CONF_KEEPALIVE: 30,
|
||||||
|
mqtt.CONF_CLIENT_CERT: MOCK_GENERIC_CERT.decode(encoding="utf-8"),
|
||||||
|
mqtt.CONF_CLIENT_KEY: MOCK_CLIENT_KEY.decode(encoding="utf-8"),
|
||||||
|
"tls_insecure": False,
|
||||||
|
mqtt.CONF_TRANSPORT: "tcp",
|
||||||
|
mqtt.CONF_CERTIFICATE: MOCK_GENERIC_CERT.decode(encoding="utf-8"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("mock_ssl_context", "mock_process_uploaded_file")
|
@pytest.mark.usefixtures("mock_ssl_context", "mock_process_uploaded_file")
|
||||||
async def test_change_websockets_transport_to_tcp(
|
async def test_change_websockets_transport_to_tcp(
|
||||||
hass: HomeAssistant, mock_try_connection: MagicMock
|
hass: HomeAssistant, mock_try_connection: MagicMock
|
||||||
|
Loading…
x
Reference in New Issue
Block a user