Add test to assert different private key types are accepted and stored correctly in MQTT config flow (#142703)

This commit is contained in:
Jan Bouwhuis 2025-04-11 00:58:48 +02:00 committed by GitHub
parent c6994731b1
commit 32da8c52f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -77,6 +77,16 @@ MOCK_CLIENT_KEY = (
b"## mock client key file ##" b"## mock client key file ##"
b"\n-----END PRIVATE KEY-----" b"\n-----END PRIVATE KEY-----"
) )
MOCK_EC_CLIENT_KEY = (
b"-----BEGIN EC PRIVATE KEY-----\n"
b"## mock client key file ##"
b"\n-----END EC PRIVATE KEY-----"
)
MOCK_RSA_CLIENT_KEY = (
b"-----BEGIN RSA PRIVATE KEY-----\n"
b"## mock client key file ##"
b"\n-----END RSA PRIVATE KEY-----"
)
MOCK_ENCRYPTED_CLIENT_KEY = ( MOCK_ENCRYPTED_CLIENT_KEY = (
b"-----BEGIN ENCRYPTED PRIVATE KEY-----\n" b"-----BEGIN ENCRYPTED PRIVATE KEY-----\n"
b"## mock client key file ##\n" b"## mock client key file ##\n"
@ -139,7 +149,13 @@ def mock_client_key_check_fail() -> Generator[MagicMock]:
@pytest.fixture @pytest.fixture
def mock_ssl_context() -> Generator[dict[str, MagicMock]]: def mock_context_client_key() -> bytes:
"""Mock the client key in the moched ssl context."""
return MOCK_CLIENT_KEY
@pytest.fixture
def mock_ssl_context(mock_context_client_key: bytes) -> Generator[dict[str, MagicMock]]:
"""Mock the SSL context used to load the cert chain and to load verify locations.""" """Mock the SSL context used to load the cert chain and to load verify locations."""
with ( with (
patch("homeassistant.components.mqtt.config_flow.SSLContext") as mock_context, patch("homeassistant.components.mqtt.config_flow.SSLContext") as mock_context,
@ -156,9 +172,9 @@ def mock_ssl_context() -> Generator[dict[str, MagicMock]]:
"homeassistant.components.mqtt.config_flow.load_der_x509_certificate" "homeassistant.components.mqtt.config_flow.load_der_x509_certificate"
) as mock_der_cert_check, ) as mock_der_cert_check,
): ):
mock_pem_key_check().private_bytes.return_value = MOCK_CLIENT_KEY mock_pem_key_check().private_bytes.return_value = mock_context_client_key
mock_pem_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT mock_pem_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT
mock_der_key_check().private_bytes.return_value = MOCK_CLIENT_KEY mock_der_key_check().private_bytes.return_value = mock_context_client_key
mock_der_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT mock_der_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT
yield { yield {
"context": mock_context, "context": mock_context,
@ -1952,9 +1968,15 @@ async def test_options_bad_will_message_fails(
} }
@pytest.mark.parametrize(
"mock_context_client_key",
[MOCK_CLIENT_KEY, MOCK_EC_CLIENT_KEY, MOCK_RSA_CLIENT_KEY],
)
@pytest.mark.usefixtures("mock_ssl_context", "mock_process_uploaded_file") @pytest.mark.usefixtures("mock_ssl_context", "mock_process_uploaded_file")
async def test_try_connection_with_advanced_parameters( async def test_try_connection_with_advanced_parameters(
hass: HomeAssistant, mock_try_connection_success: MqttMockPahoClient hass: HomeAssistant,
mock_try_connection_success: MqttMockPahoClient,
mock_context_client_key: bytes,
) -> None: ) -> None:
"""Test config flow with advanced parameters from config.""" """Test config flow with advanced parameters from config."""
config_entry = MockConfigEntry( config_entry = MockConfigEntry(
@ -1974,7 +1996,7 @@ async def test_try_connection_with_advanced_parameters(
mqtt.CONF_CERTIFICATE: "auto", mqtt.CONF_CERTIFICATE: "auto",
mqtt.CONF_TLS_INSECURE: True, mqtt.CONF_TLS_INSECURE: True,
mqtt.CONF_CLIENT_CERT: MOCK_CLIENT_CERT.decode(encoding="utf-8)"), mqtt.CONF_CLIENT_CERT: MOCK_CLIENT_CERT.decode(encoding="utf-8)"),
mqtt.CONF_CLIENT_KEY: MOCK_CLIENT_KEY.decode(encoding="utf-8"), mqtt.CONF_CLIENT_KEY: mock_context_client_key.decode(encoding="utf-8"),
mqtt.CONF_WS_PATH: "/path/", mqtt.CONF_WS_PATH: "/path/",
mqtt.CONF_WS_HEADERS: {"h1": "v1", "h2": "v2"}, mqtt.CONF_WS_HEADERS: {"h1": "v1", "h2": "v2"},
mqtt.CONF_KEEPALIVE: 30, mqtt.CONF_KEEPALIVE: 30,
@ -2047,13 +2069,34 @@ async def test_try_connection_with_advanced_parameters(
# check if tls_insecure_set is called # check if tls_insecure_set is called
assert mock_try_connection_success.tls_insecure_set.mock_calls[0][1] == (True,) assert mock_try_connection_success.tls_insecure_set.mock_calls[0][1] == (True,)
# check if the ca certificate settings were not set during connection test def read_file(path: Path) -> bytes:
assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[ with open(path, mode="rb") as file:
"certfile" return file.read()
] == mqtt.util.get_file_path(mqtt.CONF_CLIENT_CERT)
assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[ # check if the client certificate settings saved
"keyfile" client_cert_path = await hass.async_add_executor_job(
] == mqtt.util.get_file_path(mqtt.CONF_CLIENT_KEY) mqtt.util.get_file_path, mqtt.CONF_CLIENT_CERT
)
assert (
mock_try_connection_success.tls_set.mock_calls[0].kwargs["certfile"]
== client_cert_path
)
assert (
await hass.async_add_executor_job(read_file, client_cert_path)
== MOCK_CLIENT_CERT
)
client_key_path = await hass.async_add_executor_job(
mqtt.util.get_file_path, mqtt.CONF_CLIENT_KEY
)
assert (
mock_try_connection_success.tls_set.mock_calls[0].kwargs["keyfile"]
== client_key_path
)
assert (
await hass.async_add_executor_job(read_file, client_key_path)
== mock_context_client_key
)
# check if websockets options are set # check if websockets options are set
assert mock_try_connection_success.ws_set_options.mock_calls[0][1] == ( assert mock_try_connection_success.ws_set_options.mock_calls[0][1] == (