From 69ecdda5f5cd04642128b2c3fbfa5dac3fbfc7b5 Mon Sep 17 00:00:00 2001 From: starkillerOG Date: Tue, 24 Sep 2024 21:31:52 +0200 Subject: [PATCH] Add SSL Cipher option to aiohttp async_get_clientsession (#126317) Co-authored-by: J. Nick Koston --- homeassistant/helpers/aiohttp_client.py | 30 +++--- homeassistant/util/ssl.py | 57 +++++++---- tests/helpers/test_aiohttp_client.py | 122 +++++++++++++++++++----- tests/util/test_ssl.py | 23 ++--- 4 files changed, 164 insertions(+), 68 deletions(-) diff --git a/homeassistant/helpers/aiohttp_client.py b/homeassistant/helpers/aiohttp_client.py index d61f889d4b5..2f4c1980468 100644 --- a/homeassistant/helpers/aiohttp_client.py +++ b/homeassistant/helpers/aiohttp_client.py @@ -32,11 +32,11 @@ if TYPE_CHECKING: from aiohttp.typedefs import JSONDecoder -DATA_CONNECTOR: HassKey[dict[tuple[bool, int], aiohttp.BaseConnector]] = HassKey( +DATA_CONNECTOR: HassKey[dict[tuple[bool, int, str], aiohttp.BaseConnector]] = HassKey( "aiohttp_connector" ) -DATA_CLIENTSESSION: HassKey[dict[tuple[bool, int], aiohttp.ClientSession]] = HassKey( - "aiohttp_clientsession" +DATA_CLIENTSESSION: HassKey[dict[tuple[bool, int, str], aiohttp.ClientSession]] = ( + HassKey("aiohttp_clientsession") ) SERVER_SOFTWARE = ( @@ -86,12 +86,13 @@ def async_get_clientsession( hass: HomeAssistant, verify_ssl: bool = True, family: socket.AddressFamily = socket.AF_UNSPEC, + ssl_cipher: ssl_util.SSLCipherList = ssl_util.SSLCipherList.PYTHON_DEFAULT, ) -> aiohttp.ClientSession: """Return default aiohttp ClientSession. This method must be run in the event loop. """ - session_key = _make_key(verify_ssl, family) + session_key = _make_key(verify_ssl, family, ssl_cipher) sessions = hass.data.setdefault(DATA_CLIENTSESSION, {}) if session_key not in sessions: @@ -100,6 +101,7 @@ def async_get_clientsession( verify_ssl, auto_cleanup_method=_async_register_default_clientsession_shutdown, family=family, + ssl_cipher=ssl_cipher, ) sessions[session_key] = session else: @@ -115,6 +117,7 @@ def async_create_clientsession( verify_ssl: bool = True, auto_cleanup: bool = True, family: socket.AddressFamily = socket.AF_UNSPEC, + ssl_cipher: ssl_util.SSLCipherList = ssl_util.SSLCipherList.PYTHON_DEFAULT, **kwargs: Any, ) -> aiohttp.ClientSession: """Create a new ClientSession with kwargs, i.e. for cookies. @@ -135,6 +138,7 @@ def async_create_clientsession( verify_ssl, auto_cleanup_method=auto_cleanup_method, family=family, + ssl_cipher=ssl_cipher, **kwargs, ) @@ -146,11 +150,12 @@ def _async_create_clientsession( auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None] | None = None, family: socket.AddressFamily = socket.AF_UNSPEC, + ssl_cipher: ssl_util.SSLCipherList = ssl_util.SSLCipherList.PYTHON_DEFAULT, **kwargs: Any, ) -> aiohttp.ClientSession: """Create a new ClientSession with kwargs, i.e. for cookies.""" clientsession = aiohttp.ClientSession( - connector=_async_get_connector(hass, verify_ssl, family), + connector=_async_get_connector(hass, verify_ssl, family, ssl_cipher), json_serialize=json_dumps, response_class=HassClientResponse, **kwargs, @@ -279,10 +284,12 @@ def _async_register_default_clientsession_shutdown( @callback def _make_key( - verify_ssl: bool = True, family: socket.AddressFamily = socket.AF_UNSPEC -) -> tuple[bool, socket.AddressFamily]: + verify_ssl: bool = True, + family: socket.AddressFamily = socket.AF_UNSPEC, + ssl_cipher: ssl_util.SSLCipherList = ssl_util.SSLCipherList.PYTHON_DEFAULT, +) -> tuple[bool, socket.AddressFamily, ssl_util.SSLCipherList]: """Make a key for connector or session pool.""" - return (verify_ssl, family) + return (verify_ssl, family, ssl_cipher) class HomeAssistantTCPConnector(aiohttp.TCPConnector): @@ -305,21 +312,22 @@ def _async_get_connector( hass: HomeAssistant, verify_ssl: bool = True, family: socket.AddressFamily = socket.AF_UNSPEC, + ssl_cipher: ssl_util.SSLCipherList = ssl_util.SSLCipherList.PYTHON_DEFAULT, ) -> aiohttp.BaseConnector: """Return the connector pool for aiohttp. This method must be run in the event loop. """ - connector_key = _make_key(verify_ssl, family) + connector_key = _make_key(verify_ssl, family, ssl_cipher) connectors = hass.data.setdefault(DATA_CONNECTOR, {}) if connector_key in connectors: return connectors[connector_key] if verify_ssl: - ssl_context: SSLContext = ssl_util.get_default_context() + ssl_context: SSLContext = ssl_util.client_context(ssl_cipher) else: - ssl_context = ssl_util.get_default_no_verify_context() + ssl_context = ssl_util.client_context_no_verify(ssl_cipher) connector = HomeAssistantTCPConnector( family=family, diff --git a/homeassistant/util/ssl.py b/homeassistant/util/ssl.py index 7c1e653ce75..a22fd0c8fb4 100644 --- a/homeassistant/util/ssl.py +++ b/homeassistant/util/ssl.py @@ -15,6 +15,7 @@ class SSLCipherList(StrEnum): PYTHON_DEFAULT = "python_default" INTERMEDIATE = "intermediate" MODERN = "modern" + INSECURE = "insecure" SSL_CIPHER_LISTS = { @@ -58,11 +59,12 @@ SSL_CIPHER_LISTS = { "ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:" "ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256" ), + SSLCipherList.INSECURE: "DEFAULT:@SECLEVEL=0", } @cache -def _create_no_verify_ssl_context(ssl_cipher_list: SSLCipherList) -> ssl.SSLContext: +def _client_context_no_verify(ssl_cipher_list: SSLCipherList) -> ssl.SSLContext: # This is a copy of aiohttp's create_default_context() function, with the # ssl verify turned off. # https://github.com/aio-libs/aiohttp/blob/33953f110e97eecc707e1402daa8d543f38a189b/aiohttp/connector.py#L911 @@ -80,16 +82,10 @@ def _create_no_verify_ssl_context(ssl_cipher_list: SSLCipherList) -> ssl.SSLCont return sslcontext -def create_no_verify_ssl_context( +@cache +def _client_context( ssl_cipher_list: SSLCipherList = SSLCipherList.PYTHON_DEFAULT, ) -> ssl.SSLContext: - """Return an SSL context that does not verify the server certificate.""" - - return _create_no_verify_ssl_context(ssl_cipher_list=ssl_cipher_list) - - -@cache -def _client_context(ssl_cipher_list: SSLCipherList) -> ssl.SSLContext: # Reuse environment variable definition from requests, since it's already a # requirement. If the environment variable has no value, fall back to using # certs from certifi package. @@ -104,17 +100,19 @@ def _client_context(ssl_cipher_list: SSLCipherList) -> ssl.SSLContext: return sslcontext -def client_context( - ssl_cipher_list: SSLCipherList = SSLCipherList.PYTHON_DEFAULT, -) -> ssl.SSLContext: - """Return an SSL context for making requests.""" - - return _client_context(ssl_cipher_list=ssl_cipher_list) - - # Create this only once and reuse it -_DEFAULT_SSL_CONTEXT = client_context() -_DEFAULT_NO_VERIFY_SSL_CONTEXT = create_no_verify_ssl_context() +_DEFAULT_SSL_CONTEXT = _client_context(SSLCipherList.PYTHON_DEFAULT) +_DEFAULT_NO_VERIFY_SSL_CONTEXT = _client_context_no_verify(SSLCipherList.PYTHON_DEFAULT) +_NO_VERIFY_SSL_CONTEXTS = { + SSLCipherList.INTERMEDIATE: _client_context_no_verify(SSLCipherList.INTERMEDIATE), + SSLCipherList.MODERN: _client_context_no_verify(SSLCipherList.MODERN), + SSLCipherList.INSECURE: _client_context_no_verify(SSLCipherList.INSECURE), +} +_SSL_CONTEXTS = { + SSLCipherList.INTERMEDIATE: _client_context(SSLCipherList.INTERMEDIATE), + SSLCipherList.MODERN: _client_context(SSLCipherList.MODERN), + SSLCipherList.INSECURE: _client_context(SSLCipherList.INSECURE), +} def get_default_context() -> ssl.SSLContext: @@ -127,6 +125,27 @@ def get_default_no_verify_context() -> ssl.SSLContext: return _DEFAULT_NO_VERIFY_SSL_CONTEXT +def client_context_no_verify( + ssl_cipher_list: SSLCipherList = SSLCipherList.PYTHON_DEFAULT, +) -> ssl.SSLContext: + """Return a SSL context with no verification with a specific ssl cipher.""" + return _NO_VERIFY_SSL_CONTEXTS.get(ssl_cipher_list, _DEFAULT_NO_VERIFY_SSL_CONTEXT) + + +def client_context( + ssl_cipher_list: SSLCipherList = SSLCipherList.PYTHON_DEFAULT, +) -> ssl.SSLContext: + """Return an SSL context for making requests.""" + return _SSL_CONTEXTS.get(ssl_cipher_list, _DEFAULT_SSL_CONTEXT) + + +def create_no_verify_ssl_context( + ssl_cipher_list: SSLCipherList = SSLCipherList.PYTHON_DEFAULT, +) -> ssl.SSLContext: + """Return an SSL context that does not verify the server certificate.""" + return _client_context_no_verify(ssl_cipher_list) + + def server_context_modern() -> ssl.SSLContext: """Return an SSL context following the Mozilla recommendations. diff --git a/tests/helpers/test_aiohttp_client.py b/tests/helpers/test_aiohttp_client.py index 4feb03493e9..126ed3f9287 100644 --- a/tests/helpers/test_aiohttp_client.py +++ b/tests/helpers/test_aiohttp_client.py @@ -23,6 +23,7 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant import homeassistant.helpers.aiohttp_client as client from homeassistant.util.color import RGBColor +from homeassistant.util.ssl import SSLCipherList from tests.common import ( MockConfigEntry, @@ -62,11 +63,14 @@ async def test_get_clientsession_with_ssl(hass: HomeAssistant) -> None: """Test init clientsession with ssl.""" client.async_get_clientsession(hass) verify_ssl = True + ssl_cipher = SSLCipherList.PYTHON_DEFAULT family = 0 - client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)] + client_session = hass.data[client.DATA_CLIENTSESSION][ + (verify_ssl, family, ssl_cipher) + ] assert isinstance(client_session, aiohttp.ClientSession) - connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)] + connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family, ssl_cipher)] assert isinstance(connector, aiohttp.TCPConnector) @@ -74,33 +78,63 @@ async def test_get_clientsession_without_ssl(hass: HomeAssistant) -> None: """Test init clientsession without ssl.""" client.async_get_clientsession(hass, verify_ssl=False) verify_ssl = False + ssl_cipher = SSLCipherList.PYTHON_DEFAULT family = 0 - client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)] + client_session = hass.data[client.DATA_CLIENTSESSION][ + (verify_ssl, family, ssl_cipher) + ] assert isinstance(client_session, aiohttp.ClientSession) - connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)] + connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family, ssl_cipher)] assert isinstance(connector, aiohttp.TCPConnector) @pytest.mark.parametrize( - ("verify_ssl", "expected_family"), + ("verify_ssl", "expected_family", "ssl_cipher"), [ - (True, socket.AF_UNSPEC), - (False, socket.AF_UNSPEC), - (True, socket.AF_INET), - (False, socket.AF_INET), - (True, socket.AF_INET6), - (False, socket.AF_INET6), + (True, socket.AF_UNSPEC, SSLCipherList.PYTHON_DEFAULT), + (True, socket.AF_INET, SSLCipherList.PYTHON_DEFAULT), + (True, socket.AF_INET6, SSLCipherList.PYTHON_DEFAULT), + (True, socket.AF_UNSPEC, SSLCipherList.INTERMEDIATE), + (True, socket.AF_INET, SSLCipherList.INTERMEDIATE), + (True, socket.AF_INET6, SSLCipherList.INTERMEDIATE), + (True, socket.AF_UNSPEC, SSLCipherList.MODERN), + (True, socket.AF_INET, SSLCipherList.MODERN), + (True, socket.AF_INET6, SSLCipherList.MODERN), + (True, socket.AF_UNSPEC, SSLCipherList.INSECURE), + (True, socket.AF_INET, SSLCipherList.INSECURE), + (True, socket.AF_INET6, SSLCipherList.INSECURE), + (False, socket.AF_UNSPEC, SSLCipherList.PYTHON_DEFAULT), + (False, socket.AF_INET, SSLCipherList.PYTHON_DEFAULT), + (False, socket.AF_INET6, SSLCipherList.PYTHON_DEFAULT), + (False, socket.AF_UNSPEC, SSLCipherList.INTERMEDIATE), + (False, socket.AF_INET, SSLCipherList.INTERMEDIATE), + (False, socket.AF_INET6, SSLCipherList.INTERMEDIATE), + (False, socket.AF_UNSPEC, SSLCipherList.MODERN), + (False, socket.AF_INET, SSLCipherList.MODERN), + (False, socket.AF_INET6, SSLCipherList.MODERN), + (False, socket.AF_UNSPEC, SSLCipherList.INSECURE), + (False, socket.AF_INET, SSLCipherList.INSECURE), + (False, socket.AF_INET6, SSLCipherList.INSECURE), ], ) async def test_get_clientsession( - hass: HomeAssistant, verify_ssl: bool, expected_family: int + hass: HomeAssistant, + verify_ssl: bool, + expected_family: int, + ssl_cipher: SSLCipherList, ) -> None: """Test init clientsession combinations.""" - client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family) - client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)] + client.async_get_clientsession( + hass, verify_ssl=verify_ssl, family=expected_family, ssl_cipher=ssl_cipher + ) + client_session = hass.data[client.DATA_CLIENTSESSION][ + (verify_ssl, expected_family, ssl_cipher) + ] assert isinstance(client_session, aiohttp.ClientSession) - connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)] + connector = hass.data[client.DATA_CONNECTOR][ + (verify_ssl, expected_family, ssl_cipher) + ] assert isinstance(connector, aiohttp.TCPConnector) @@ -110,10 +144,11 @@ async def test_create_clientsession_with_ssl_and_cookies(hass: HomeAssistant) -> assert isinstance(session, aiohttp.ClientSession) verify_ssl = True + ssl_cipher = SSLCipherList.PYTHON_DEFAULT family = 0 assert client.DATA_CLIENTSESSION not in hass.data - connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)] + connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family, ssl_cipher)] assert isinstance(connector, aiohttp.TCPConnector) @@ -125,26 +160,61 @@ async def test_create_clientsession_without_ssl_and_cookies( assert isinstance(session, aiohttp.ClientSession) verify_ssl = False + ssl_cipher = SSLCipherList.PYTHON_DEFAULT family = 0 assert client.DATA_CLIENTSESSION not in hass.data - connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)] + connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family, ssl_cipher)] assert isinstance(connector, aiohttp.TCPConnector) @pytest.mark.parametrize( - ("verify_ssl", "expected_family"), - [(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)], + ("verify_ssl", "expected_family", "ssl_cipher"), + [ + (True, 0, SSLCipherList.PYTHON_DEFAULT), + (True, 4, SSLCipherList.PYTHON_DEFAULT), + (True, 6, SSLCipherList.PYTHON_DEFAULT), + (True, 0, SSLCipherList.INTERMEDIATE), + (True, 4, SSLCipherList.INTERMEDIATE), + (True, 6, SSLCipherList.INTERMEDIATE), + (True, 0, SSLCipherList.MODERN), + (True, 4, SSLCipherList.MODERN), + (True, 6, SSLCipherList.MODERN), + (True, 0, SSLCipherList.INSECURE), + (True, 4, SSLCipherList.INSECURE), + (True, 6, SSLCipherList.INSECURE), + (False, 0, SSLCipherList.PYTHON_DEFAULT), + (False, 4, SSLCipherList.PYTHON_DEFAULT), + (False, 6, SSLCipherList.PYTHON_DEFAULT), + (False, 0, SSLCipherList.INTERMEDIATE), + (False, 4, SSLCipherList.INTERMEDIATE), + (False, 6, SSLCipherList.INTERMEDIATE), + (False, 0, SSLCipherList.MODERN), + (False, 4, SSLCipherList.MODERN), + (False, 6, SSLCipherList.MODERN), + (False, 0, SSLCipherList.INSECURE), + (False, 4, SSLCipherList.INSECURE), + (False, 6, SSLCipherList.INSECURE), + ], ) async def test_get_clientsession_cleanup( - hass: HomeAssistant, verify_ssl: bool, expected_family: int + hass: HomeAssistant, + verify_ssl: bool, + expected_family: int, + ssl_cipher: SSLCipherList, ) -> None: """Test init clientsession cleanup.""" - client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family) + client.async_get_clientsession( + hass, verify_ssl=verify_ssl, family=expected_family, ssl_cipher=ssl_cipher + ) - client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)] + client_session = hass.data[client.DATA_CLIENTSESSION][ + (verify_ssl, expected_family, ssl_cipher) + ] assert isinstance(client_session, aiohttp.ClientSession) - connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)] + connector = hass.data[client.DATA_CONNECTOR][ + (verify_ssl, expected_family, ssl_cipher) + ] assert isinstance(connector, aiohttp.TCPConnector) hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE) @@ -158,17 +228,19 @@ async def test_get_clientsession_patched_close(hass: HomeAssistant) -> None: """Test closing clientsession does not work.""" verify_ssl = True + ssl_cipher = SSLCipherList.PYTHON_DEFAULT family = 0 with patch("aiohttp.ClientSession.close") as mock_close: session = client.async_get_clientsession(hass) assert isinstance( - hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)], + hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family, ssl_cipher)], aiohttp.ClientSession, ) assert isinstance( - hass.data[client.DATA_CONNECTOR][(verify_ssl, family)], aiohttp.TCPConnector + hass.data[client.DATA_CONNECTOR][(verify_ssl, family, ssl_cipher)], + aiohttp.TCPConnector, ) with pytest.raises(RuntimeError): diff --git a/tests/util/test_ssl.py b/tests/util/test_ssl.py index d0c7ce3bfb6..c0cd2fdba10 100644 --- a/tests/util/test_ssl.py +++ b/tests/util/test_ssl.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock, Mock, patch import pytest from homeassistant.util.ssl import ( - SSL_CIPHER_LISTS, SSLCipherList, client_context, create_no_verify_ssl_context, @@ -25,14 +24,13 @@ def test_client_context(mock_sslcontext) -> None: mock_sslcontext.set_ciphers.assert_not_called() client_context(SSLCipherList.MODERN) - mock_sslcontext.set_ciphers.assert_called_with( - SSL_CIPHER_LISTS[SSLCipherList.MODERN] - ) + mock_sslcontext.set_ciphers.assert_not_called() client_context(SSLCipherList.INTERMEDIATE) - mock_sslcontext.set_ciphers.assert_called_with( - SSL_CIPHER_LISTS[SSLCipherList.INTERMEDIATE] - ) + mock_sslcontext.set_ciphers.assert_not_called() + + client_context(SSLCipherList.INSECURE) + mock_sslcontext.set_ciphers.assert_not_called() def test_no_verify_ssl_context(mock_sslcontext) -> None: @@ -42,14 +40,13 @@ def test_no_verify_ssl_context(mock_sslcontext) -> None: mock_sslcontext.set_ciphers.assert_not_called() create_no_verify_ssl_context(SSLCipherList.MODERN) - mock_sslcontext.set_ciphers.assert_called_with( - SSL_CIPHER_LISTS[SSLCipherList.MODERN] - ) + mock_sslcontext.set_ciphers.assert_not_called() create_no_verify_ssl_context(SSLCipherList.INTERMEDIATE) - mock_sslcontext.set_ciphers.assert_called_with( - SSL_CIPHER_LISTS[SSLCipherList.INTERMEDIATE] - ) + mock_sslcontext.set_ciphers.assert_not_called() + + create_no_verify_ssl_context(SSLCipherList.INSECURE) + mock_sslcontext.set_ciphers.assert_not_called() def test_ssl_context_caching() -> None: