From 32da8c52f7dd60003a7e71bb114274a557e0dfb3 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Fri, 11 Apr 2025 00:58:48 +0200 Subject: [PATCH] Add test to assert different private key types are accepted and stored correctly in MQTT config flow (#142703) --- tests/components/mqtt/test_config_flow.py | 67 +++++++++++++++++++---- 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/tests/components/mqtt/test_config_flow.py b/tests/components/mqtt/test_config_flow.py index c94d692b374..cfc9e0bede0 100644 --- a/tests/components/mqtt/test_config_flow.py +++ b/tests/components/mqtt/test_config_flow.py @@ -77,6 +77,16 @@ MOCK_CLIENT_KEY = ( b"## mock client key file ##" b"\n-----END PRIVATE KEY-----" ) +MOCK_EC_CLIENT_KEY = ( + b"-----BEGIN EC PRIVATE KEY-----\n" + b"## mock client key file ##" + b"\n-----END EC PRIVATE KEY-----" +) +MOCK_RSA_CLIENT_KEY = ( + b"-----BEGIN RSA PRIVATE KEY-----\n" + b"## mock client key file ##" + b"\n-----END RSA PRIVATE KEY-----" +) MOCK_ENCRYPTED_CLIENT_KEY = ( b"-----BEGIN ENCRYPTED PRIVATE KEY-----\n" b"## mock client key file ##\n" @@ -139,7 +149,13 @@ def mock_client_key_check_fail() -> Generator[MagicMock]: @pytest.fixture -def mock_ssl_context() -> Generator[dict[str, MagicMock]]: +def mock_context_client_key() -> bytes: + """Mock the client key in the moched ssl context.""" + return MOCK_CLIENT_KEY + + +@pytest.fixture +def mock_ssl_context(mock_context_client_key: bytes) -> Generator[dict[str, MagicMock]]: """Mock the SSL context used to load the cert chain and to load verify locations.""" with ( patch("homeassistant.components.mqtt.config_flow.SSLContext") as mock_context, @@ -156,9 +172,9 @@ def mock_ssl_context() -> Generator[dict[str, MagicMock]]: "homeassistant.components.mqtt.config_flow.load_der_x509_certificate" ) as mock_der_cert_check, ): - mock_pem_key_check().private_bytes.return_value = MOCK_CLIENT_KEY + mock_pem_key_check().private_bytes.return_value = mock_context_client_key mock_pem_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT - mock_der_key_check().private_bytes.return_value = MOCK_CLIENT_KEY + mock_der_key_check().private_bytes.return_value = mock_context_client_key mock_der_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT yield { "context": mock_context, @@ -1952,9 +1968,15 @@ async def test_options_bad_will_message_fails( } +@pytest.mark.parametrize( + "mock_context_client_key", + [MOCK_CLIENT_KEY, MOCK_EC_CLIENT_KEY, MOCK_RSA_CLIENT_KEY], +) @pytest.mark.usefixtures("mock_ssl_context", "mock_process_uploaded_file") async def test_try_connection_with_advanced_parameters( - hass: HomeAssistant, mock_try_connection_success: MqttMockPahoClient + hass: HomeAssistant, + mock_try_connection_success: MqttMockPahoClient, + mock_context_client_key: bytes, ) -> None: """Test config flow with advanced parameters from config.""" config_entry = MockConfigEntry( @@ -1974,7 +1996,7 @@ async def test_try_connection_with_advanced_parameters( mqtt.CONF_CERTIFICATE: "auto", mqtt.CONF_TLS_INSECURE: True, mqtt.CONF_CLIENT_CERT: MOCK_CLIENT_CERT.decode(encoding="utf-8)"), - mqtt.CONF_CLIENT_KEY: MOCK_CLIENT_KEY.decode(encoding="utf-8"), + mqtt.CONF_CLIENT_KEY: mock_context_client_key.decode(encoding="utf-8"), mqtt.CONF_WS_PATH: "/path/", mqtt.CONF_WS_HEADERS: {"h1": "v1", "h2": "v2"}, mqtt.CONF_KEEPALIVE: 30, @@ -2047,13 +2069,34 @@ async def test_try_connection_with_advanced_parameters( # check if tls_insecure_set is called assert mock_try_connection_success.tls_insecure_set.mock_calls[0][1] == (True,) - # check if the ca certificate settings were not set during connection test - assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[ - "certfile" - ] == mqtt.util.get_file_path(mqtt.CONF_CLIENT_CERT) - assert mock_try_connection_success.tls_set.mock_calls[0].kwargs[ - "keyfile" - ] == mqtt.util.get_file_path(mqtt.CONF_CLIENT_KEY) + def read_file(path: Path) -> bytes: + with open(path, mode="rb") as file: + return file.read() + + # check if the client certificate settings saved + client_cert_path = await hass.async_add_executor_job( + mqtt.util.get_file_path, mqtt.CONF_CLIENT_CERT + ) + assert ( + mock_try_connection_success.tls_set.mock_calls[0].kwargs["certfile"] + == client_cert_path + ) + assert ( + await hass.async_add_executor_job(read_file, client_cert_path) + == MOCK_CLIENT_CERT + ) + + client_key_path = await hass.async_add_executor_job( + mqtt.util.get_file_path, mqtt.CONF_CLIENT_KEY + ) + assert ( + mock_try_connection_success.tls_set.mock_calls[0].kwargs["keyfile"] + == client_key_path + ) + assert ( + await hass.async_add_executor_job(read_file, client_key_path) + == mock_context_client_key + ) # check if websockets options are set assert mock_try_connection_success.ws_set_options.mock_calls[0][1] == (