Add test to assert different private key types are accepted and stored correctly in MQTT config flow (#142703)

This commit is contained in:
Jan Bouwhuis 2025-04-11 00:58:48 +02:00 committed by GitHub
parent c6994731b1
commit 32da8c52f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -77,6 +77,16 @@ MOCK_CLIENT_KEY = (
b"## mock client key file ##"
b"\n-----END PRIVATE KEY-----"
)
MOCK_EC_CLIENT_KEY = (
b"-----BEGIN EC PRIVATE KEY-----\n"
b"## mock client key file ##"
b"\n-----END EC PRIVATE KEY-----"
)
MOCK_RSA_CLIENT_KEY = (
b"-----BEGIN RSA PRIVATE KEY-----\n"
b"## mock client key file ##"
b"\n-----END RSA PRIVATE KEY-----"
)
MOCK_ENCRYPTED_CLIENT_KEY = (
b"-----BEGIN ENCRYPTED PRIVATE KEY-----\n"
b"## mock client key file ##\n"
@ -139,7 +149,13 @@ def mock_client_key_check_fail() -> Generator[MagicMock]:
@pytest.fixture
def mock_ssl_context() -> Generator[dict[str, MagicMock]]:
def mock_context_client_key() -> bytes:
"""Mock the client key in the moched ssl context."""
return MOCK_CLIENT_KEY
@pytest.fixture
def mock_ssl_context(mock_context_client_key: bytes) -> Generator[dict[str, MagicMock]]:
"""Mock the SSL context used to load the cert chain and to load verify locations."""
with (
patch("homeassistant.components.mqtt.config_flow.SSLContext") as mock_context,
@ -156,9 +172,9 @@ def mock_ssl_context() -> Generator[dict[str, MagicMock]]:
"homeassistant.components.mqtt.config_flow.load_der_x509_certificate"
) as mock_der_cert_check,
):
mock_pem_key_check().private_bytes.return_value = MOCK_CLIENT_KEY
mock_pem_key_check().private_bytes.return_value = mock_context_client_key
mock_pem_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT
mock_der_key_check().private_bytes.return_value = MOCK_CLIENT_KEY
mock_der_key_check().private_bytes.return_value = mock_context_client_key
mock_der_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT
yield {
"context": mock_context,
@ -1952,9 +1968,15 @@ async def test_options_bad_will_message_fails(
}
@pytest.mark.parametrize(
"mock_context_client_key",
[MOCK_CLIENT_KEY, MOCK_EC_CLIENT_KEY, MOCK_RSA_CLIENT_KEY],
)
@pytest.mark.usefixtures("mock_ssl_context", "mock_process_uploaded_file")
async def test_try_connection_with_advanced_parameters(
hass: HomeAssistant, mock_try_connection_success: MqttMockPahoClient
hass: HomeAssistant,
mock_try_connection_success: MqttMockPahoClient,
mock_context_client_key: bytes,
) -> None:
"""Test config flow with advanced parameters from config."""
config_entry = MockConfigEntry(
@ -1974,7 +1996,7 @@ async def test_try_connection_with_advanced_parameters(
mqtt.CONF_CERTIFICATE: "auto",
mqtt.CONF_TLS_INSECURE: True,
mqtt.CONF_CLIENT_CERT: MOCK_CLIENT_CERT.decode(encoding="utf-8)"),
mqtt.CONF_CLIENT_KEY: MOCK_CLIENT_KEY.decode(encoding="utf-8"),
mqtt.CONF_CLIENT_KEY: mock_context_client_key.decode(encoding="utf-8"),
mqtt.CONF_WS_PATH: "/path/",
mqtt.CONF_WS_HEADERS: {"h1": "v1", "h2": "v2"},
mqtt.CONF_KEEPALIVE: 30,
@ -2047,13 +2069,34 @@ async def test_try_connection_with_advanced_parameters(
# check if tls_insecure_set is called
assert mock_try_connection_success.tls_insecure_set.mock_calls[0][1] == (True,)
# check if the ca certificate settings were not set during connection test
assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[
"certfile"
] == mqtt.util.get_file_path(mqtt.CONF_CLIENT_CERT)
assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[
"keyfile"
] == mqtt.util.get_file_path(mqtt.CONF_CLIENT_KEY)
def read_file(path: Path) -> bytes:
with open(path, mode="rb") as file:
return file.read()
# check if the client certificate settings saved
client_cert_path = await hass.async_add_executor_job(
mqtt.util.get_file_path, mqtt.CONF_CLIENT_CERT
)
assert (
mock_try_connection_success.tls_set.mock_calls[0].kwargs["certfile"]
== client_cert_path
)
assert (
await hass.async_add_executor_job(read_file, client_cert_path)
== MOCK_CLIENT_CERT
)
client_key_path = await hass.async_add_executor_job(
mqtt.util.get_file_path, mqtt.CONF_CLIENT_KEY
)
assert (
mock_try_connection_success.tls_set.mock_calls[0].kwargs["keyfile"]
== client_key_path
)
assert (
await hass.async_add_executor_job(read_file, client_key_path)
== mock_context_client_key
)
# check if websockets options are set
assert mock_try_connection_success.ws_set_options.mock_calls[0][1] == (