Add cipher list option to IMAP config flow (#91896)

* Add cipher list option to IMAP config flow

* Use client_context to get the ssl_context

* Formatting

* Add ssl error no make error handling more specific

* Make ssl_ciper_list an advanced option
This commit is contained in:
Jan Bouwhuis 2023-04-24 15:37:21 +02:00 committed by GitHub
parent c3262ebdb3
commit 3f6541a6db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 114 additions and 12 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Mapping from collections.abc import Mapping
import ssl
from typing import Any from typing import Any
from aioimaplib import AioImapException from aioimaplib import AioImapException
@ -13,18 +14,33 @@ from homeassistant.const import CONF_NAME, CONF_PASSWORD, CONF_PORT, CONF_USERNA
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.data_entry_flow import AbortFlow, FlowResult from homeassistant.data_entry_flow import AbortFlow, FlowResult
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.selector import (
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
)
from homeassistant.util.ssl import SSLCipherList
from .const import ( from .const import (
CONF_CHARSET, CONF_CHARSET,
CONF_FOLDER, CONF_FOLDER,
CONF_SEARCH, CONF_SEARCH,
CONF_SERVER, CONF_SERVER,
CONF_SSL_CIPHER_LIST,
DEFAULT_PORT, DEFAULT_PORT,
DOMAIN, DOMAIN,
) )
from .coordinator import connect_to_server from .coordinator import connect_to_server
from .errors import InvalidAuth, InvalidFolder from .errors import InvalidAuth, InvalidFolder
CIPHER_SELECTOR = SelectSelector(
SelectSelectorConfig(
options=list(SSLCipherList),
mode=SelectSelectorMode.DROPDOWN,
translation_key=CONF_SSL_CIPHER_LIST,
)
)
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_USERNAME): str, vol.Required(CONF_USERNAME): str,
@ -36,6 +52,11 @@ CONFIG_SCHEMA = vol.Schema(
vol.Optional(CONF_SEARCH, default="UnSeen UnDeleted"): str, vol.Optional(CONF_SEARCH, default="UnSeen UnDeleted"): str,
} }
) )
CONFIG_SCHEMA_ADVANCED = {
vol.Optional(
CONF_SSL_CIPHER_LIST, default=SSLCipherList.PYTHON_DEFAULT
): CIPHER_SELECTOR
}
OPTIONS_SCHEMA = vol.Schema( OPTIONS_SCHEMA = vol.Schema(
{ {
@ -60,6 +81,11 @@ async def validate_input(user_input: dict[str, Any]) -> dict[str, str]:
errors[CONF_USERNAME] = errors[CONF_PASSWORD] = "invalid_auth" errors[CONF_USERNAME] = errors[CONF_PASSWORD] = "invalid_auth"
except InvalidFolder: except InvalidFolder:
errors[CONF_FOLDER] = "invalid_folder" errors[CONF_FOLDER] = "invalid_folder"
except ssl.SSLError:
# The aioimaplib library 1.0.1 does not raise an ssl.SSLError correctly, but is logged
# See https://github.com/bamthomas/aioimaplib/issues/91
# This handler is added to be able to supply a better error message
errors["base"] = "ssl_error"
except (asyncio.TimeoutError, AioImapException, ConnectionRefusedError): except (asyncio.TimeoutError, AioImapException, ConnectionRefusedError):
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
else: else:
@ -103,8 +129,13 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Handle the initial step.""" """Handle the initial step."""
schema = CONFIG_SCHEMA
if self.show_advanced_options:
schema = schema.extend(CONFIG_SCHEMA_ADVANCED)
if user_input is None: if user_input is None:
return self.async_show_form(step_id="user", data_schema=CONFIG_SCHEMA) return self.async_show_form(step_id="user", data_schema=schema)
self._async_abort_entries_match( self._async_abort_entries_match(
{ {

View File

@ -8,5 +8,6 @@ CONF_SERVER: Final = "server"
CONF_FOLDER: Final = "folder" CONF_FOLDER: Final = "folder"
CONF_SEARCH: Final = "search" CONF_SEARCH: Final = "search"
CONF_CHARSET: Final = "charset" CONF_CHARSET: Final = "charset"
CONF_SSL_CIPHER_LIST: Final = "ssl_cipher_list"
DEFAULT_PORT: Final = 993 DEFAULT_PORT: Final = 993

View File

@ -21,8 +21,16 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryError from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryError
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from homeassistant.util.ssl import SSLCipherList, client_context
from .const import CONF_CHARSET, CONF_FOLDER, CONF_SEARCH, CONF_SERVER, DOMAIN from .const import (
CONF_CHARSET,
CONF_FOLDER,
CONF_SEARCH,
CONF_SERVER,
CONF_SSL_CIPHER_LIST,
DOMAIN,
)
from .errors import InvalidAuth, InvalidFolder from .errors import InvalidAuth, InvalidFolder
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -34,8 +42,13 @@ EVENT_IMAP = "imap_content"
async def connect_to_server(data: Mapping[str, Any]) -> IMAP4_SSL: async def connect_to_server(data: Mapping[str, Any]) -> IMAP4_SSL:
"""Connect to imap server and return client.""" """Connect to imap server and return client."""
client = IMAP4_SSL(data[CONF_SERVER], data[CONF_PORT]) ssl_context = client_context(
ssl_cipher_list=data.get(CONF_SSL_CIPHER_LIST, SSLCipherList.PYTHON_DEFAULT)
)
client = IMAP4_SSL(data[CONF_SERVER], data[CONF_PORT], ssl_context=ssl_context)
await client.wait_hello_from_server() await client.wait_hello_from_server()
if client.protocol.state == NONAUTH: if client.protocol.state == NONAUTH:
await client.login(data[CONF_USERNAME], data[CONF_PASSWORD]) await client.login(data[CONF_USERNAME], data[CONF_PASSWORD])
if client.protocol.state not in {AUTH, SELECTED}: if client.protocol.state not in {AUTH, SELECTED}:

View File

@ -9,7 +9,8 @@
"port": "[%key:common::config_flow::data::port%]", "port": "[%key:common::config_flow::data::port%]",
"charset": "Character set", "charset": "Character set",
"folder": "Folder", "folder": "Folder",
"search": "IMAP search" "search": "IMAP search",
"ssl_cipher_list": "SSL cipher list (Advanced)"
} }
}, },
"reauth_confirm": { "reauth_confirm": {
@ -25,7 +26,8 @@
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"invalid_charset": "The specified charset is not supported", "invalid_charset": "The specified charset is not supported",
"invalid_folder": "The selected folder is invalid", "invalid_folder": "The selected folder is invalid",
"invalid_search": "The selected search is invalid" "invalid_search": "The selected search is invalid",
"ssl_error": "An SSL error occurred. Change SSL cipher list and try again"
}, },
"abort": { "abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]", "already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
@ -49,5 +51,14 @@
"invalid_folder": "[%key:component::imap::config::error::invalid_folder%]", "invalid_folder": "[%key:component::imap::config::error::invalid_folder%]",
"invalid_search": "[%key:component::imap::config::error::invalid_search%]" "invalid_search": "[%key:component::imap::config::error::invalid_search%]"
} }
},
"selector": {
"ssl_cipher_list": {
"options": {
"python_default": "Default settings",
"modern": "Modern ciphers",
"intermediate": "Intermediate ciphers"
}
}
} }
} }

View File

@ -1,5 +1,6 @@
"""Test the imap config flow.""" """Test the imap config flow."""
import asyncio import asyncio
import ssl
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from aioimaplib import AioImapException from aioimaplib import AioImapException
@ -113,10 +114,16 @@ async def test_form_invalid_auth(hass: HomeAssistant) -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"exc", ("exc", "error"),
[asyncio.TimeoutError, AioImapException("")], [
(asyncio.TimeoutError, "cannot_connect"),
(AioImapException(""), "cannot_connect"),
(ssl.SSLError, "ssl_error"),
],
) )
async def test_form_cannot_connect(hass: HomeAssistant, exc: Exception) -> None: async def test_form_cannot_connect(
hass: HomeAssistant, exc: Exception, error: str
) -> None:
"""Test we handle cannot connect error.""" """Test we handle cannot connect error."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER} DOMAIN, context={"source": config_entries.SOURCE_USER}
@ -131,7 +138,7 @@ async def test_form_cannot_connect(hass: HomeAssistant, exc: Exception) -> None:
) )
assert result2["type"] == FlowResultType.FORM assert result2["type"] == FlowResultType.FORM
assert result2["errors"] == {"base": "cannot_connect"} assert result2["errors"] == {"base": error}
# make sure we do not lose the user input if somethings gets wrong # make sure we do not lose the user input if somethings gets wrong
assert { assert {
@ -455,3 +462,35 @@ async def test_import_flow_connection_error(hass: HomeAssistant) -> None:
assert result["type"] == FlowResultType.ABORT assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "cannot_connect" assert result["reason"] == "cannot_connect"
@pytest.mark.parametrize("cipher_list", ["python_default", "modern", "intermediate"])
async def test_config_flow_with_cipherlist(
hass: HomeAssistant, mock_setup_entry: AsyncMock, cipher_list: str
) -> None:
"""Test with alternate cipherlist."""
config = MOCK_CONFIG.copy()
config["ssl_cipher_list"] = cipher_list
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_USER, "show_advanced_options": True},
)
assert result["type"] == FlowResultType.FORM
assert result["errors"] is None
with patch(
"homeassistant.components.imap.config_flow.connect_to_server"
) as mock_client:
mock_client.return_value.search.return_value = (
"OK",
[b""],
)
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], config
)
await hass.async_block_till_done()
assert result2["type"] == FlowResultType.CREATE_ENTRY
assert result2["title"] == "email@email.com"
assert result2["data"] == config
assert len(mock_setup_entry.mock_calls) == 1

View File

@ -30,12 +30,19 @@ from .test_config_flow import MOCK_CONFIG
from tests.common import MockConfigEntry, async_capture_events, async_fire_time_changed from tests.common import MockConfigEntry, async_capture_events, async_fire_time_changed
@pytest.mark.parametrize(
"cipher_list", [None, "python_default", "modern", "intermediate"]
)
@pytest.mark.parametrize("imap_has_capability", [True, False], ids=["push", "poll"]) @pytest.mark.parametrize("imap_has_capability", [True, False], ids=["push", "poll"])
async def test_entry_startup_and_unload( async def test_entry_startup_and_unload(
hass: HomeAssistant, mock_imap_protocol: MagicMock hass: HomeAssistant, mock_imap_protocol: MagicMock, cipher_list: str
) -> None: ) -> None:
"""Test imap entry startup and unload with push and polling coordinator.""" """Test imap entry startup and unload with push and polling coordinator and alternate ciphers."""
config_entry = MockConfigEntry(domain=DOMAIN, data=MOCK_CONFIG) config = MOCK_CONFIG.copy()
if cipher_list:
config["ssl_cipher_list"] = cipher_list
config_entry = MockConfigEntry(domain=DOMAIN, data=config)
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id) assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()