core/tests/components/mcp/test_config_flow.py
Allen Porter 0c4cb27fe9
Add OAuth support for Model Context Protocol (mcp) integration (#141874)
* Add authentication support for Model Context Protocol (mcp) integration

* Update homeassistant/components/mcp/application_credentials.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* Handle MCP servers with ports

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
2025-03-30 23:14:52 -04:00

657 lines
20 KiB
Python

"""Test the Model Context Protocol config flow."""
import json
from typing import Any
from unittest.mock import AsyncMock, Mock
import httpx
import pytest
import respx
from homeassistant import config_entries
from homeassistant.components.mcp.const import (
CONF_AUTHORIZATION_URL,
CONF_TOKEN_URL,
DOMAIN,
)
from homeassistant.const import CONF_TOKEN, CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import config_entry_oauth2_flow
from .conftest import (
AUTH_DOMAIN,
CLIENT_ID,
MCP_SERVER_URL,
OAUTH_AUTHORIZE_URL,
OAUTH_TOKEN_URL,
TEST_API_NAME,
)
from tests.common import MockConfigEntry
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
# Base URL used to build the discovery endpoint and the expected OAuth URLs
# asserted by the tests in this module.
MCP_SERVER_BASE_URL = "http://1.1.1.1:8080"

# Well-known OAuth authorization server metadata location on the base URL.
OAUTH_DISCOVERY_ENDPOINT = f"{MCP_SERVER_BASE_URL}/.well-known/oauth-authorization-server"

# Canned discovery response advertising the authorize and token endpoints.
OAUTH_SERVER_METADATA_RESPONSE = httpx.Response(
    status_code=200,
    text=json.dumps(
        {
            "authorization_endpoint": OAUTH_AUTHORIZE_URL,
            "token_endpoint": OAUTH_TOKEN_URL,
        }
    ),
)

# Redirect target used for the OAuth callback in these tests.
CALLBACK_PATH = "/auth/external/callback"
OAUTH_CALLBACK_URL = f"https://example.com{CALLBACK_PATH}"

# Authorization code passed to the callback endpoint.
OAUTH_CODE = "abcd"

# Token endpoint payload; tests compare the stored token against this
# after removing the "expires_at" key.
OAUTH_TOKEN_PAYLOAD = {
    "refresh_token": "mock-refresh-token",
    "access_token": "mock-access-token",
    "type": "Bearer",
    "expires_in": 60,
}
def encode_state(hass: HomeAssistant, flow_id: str) -> str:
    """Encode the OAuth state JWT for the config flow identified by flow_id."""
    payload = {
        "flow_id": flow_id,
        "redirect_uri": OAUTH_CALLBACK_URL,
    }
    return config_entry_oauth2_flow._encode_jwt(hass, payload)
async def test_form(
    hass: HomeAssistant, mock_setup_entry: AsyncMock, mock_mcp_client: Mock
) -> None:
    """Test the complete configuration flow for a server that needs no auth."""
    flow_manager = hass.config_entries.flow
    form = await flow_manager.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert form["type"] is FlowResultType.FORM
    assert form["errors"] == {}

    # The mocked client's initialize() result supplies the server name.
    init_response = Mock()
    init_response.serverInfo.name = TEST_API_NAME
    mock_mcp_client.return_value.initialize.return_value = init_response

    result = await flow_manager.async_configure(
        form["flow_id"],
        {CONF_URL: MCP_SERVER_URL},
    )
    assert result["type"] is FlowResultType.CREATE_ENTRY
    assert result["title"] == TEST_API_NAME
    assert result["data"] == {CONF_URL: MCP_SERVER_URL}

    # Config entry does not have a unique id
    created_entry = result["result"]
    assert created_entry
    assert created_entry.unique_id is None

    assert len(mock_setup_entry.mock_calls) == 1
@pytest.mark.parametrize(
    ("side_effect", "expected_error"),
    [
        (httpx.TimeoutException("Some timeout"), "timeout_connect"),
        (
            httpx.HTTPStatusError("", request=None, response=httpx.Response(500)),
            "cannot_connect",
        ),
        (httpx.HTTPError("Some HTTP error"), "cannot_connect"),
        (Exception, "unknown"),
    ],
)
async def test_form_mcp_client_error(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_mcp_client: Mock,
    side_effect: Exception,
    expected_error: str,
) -> None:
    """Test that client library errors show a form error and the flow can recover."""
    flow_manager = hass.config_entries.flow
    form = await flow_manager.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    # Make the mocked MCP client raise the parametrized error when invoked.
    mock_mcp_client.side_effect = side_effect
    error_form = await flow_manager.async_configure(
        form["flow_id"],
        {CONF_URL: MCP_SERVER_URL},
    )
    assert error_form["type"] is FlowResultType.FORM
    assert error_form["errors"] == {"base": expected_error}

    # Reset the error and make sure the config flow can resume successfully.
    mock_mcp_client.side_effect = None
    init_response = Mock()
    init_response.serverInfo.name = TEST_API_NAME
    mock_mcp_client.return_value.initialize.return_value = init_response

    result = await flow_manager.async_configure(
        error_form["flow_id"],
        {CONF_URL: MCP_SERVER_URL},
    )
    assert result["type"] is FlowResultType.CREATE_ENTRY
    assert result["title"] == TEST_API_NAME
    assert result["data"] == {CONF_URL: MCP_SERVER_URL}
    assert len(mock_setup_entry.mock_calls) == 1
@pytest.mark.parametrize(
    "user_input",
    [
        ({CONF_URL: "not a url"}),
        ({CONF_URL: "rtsp://1.1.1.1"}),
    ],
)
async def test_input_form_validation_error(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_mcp_client: Mock,
    user_input: dict[str, Any],
) -> None:
    """Test that an invalid URL is rejected with a field error and can be corrected."""
    flow_manager = hass.config_entries.flow
    form = await flow_manager.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    # Submitting the malformed URL keeps the user on the form.
    error_form = await flow_manager.async_configure(form["flow_id"], user_input)
    assert error_form["type"] is FlowResultType.FORM
    assert error_form["errors"] == {CONF_URL: "invalid_url"}

    # Reset the error and make sure the config flow can resume successfully.
    init_response = Mock()
    init_response.serverInfo.name = TEST_API_NAME
    mock_mcp_client.return_value.initialize.return_value = init_response

    result = await flow_manager.async_configure(
        error_form["flow_id"],
        {CONF_URL: MCP_SERVER_URL},
    )
    assert result["type"] is FlowResultType.CREATE_ENTRY
    assert result["title"] == TEST_API_NAME
    assert result["data"] == {CONF_URL: MCP_SERVER_URL}
    assert len(mock_setup_entry.mock_calls) == 1
async def test_unique_url(
    hass: HomeAssistant, mock_setup_entry: AsyncMock, mock_mcp_client: Mock
) -> None:
    """Test that the same url cannot be configured twice."""
    # Pre-register an entry that already uses the URL submitted below.
    existing_entry = MockConfigEntry(
        domain=DOMAIN,
        data={CONF_URL: MCP_SERVER_URL},
        title=TEST_API_NAME,
    )
    existing_entry.add_to_hass(hass)

    flow_manager = hass.config_entries.flow
    form = await flow_manager.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert form["type"] is FlowResultType.FORM
    assert form["errors"] == {}

    init_response = Mock()
    init_response.serverInfo.name = TEST_API_NAME
    mock_mcp_client.return_value.initialize.return_value = init_response

    result = await flow_manager.async_configure(
        form["flow_id"],
        {CONF_URL: MCP_SERVER_URL},
    )
    assert result["type"] is FlowResultType.ABORT
    assert result["reason"] == "already_configured"
async def test_server_missing_capbilities(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_mcp_client: Mock,
) -> None:
    """Test that the flow aborts when the server reports no tools capability."""
    flow_manager = hass.config_entries.flow
    form = await flow_manager.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    # The server initializes fine but advertises no tools capability.
    init_response = Mock()
    init_response.serverInfo.name = TEST_API_NAME
    init_response.capabilities.tools = None
    mock_mcp_client.return_value.initialize.return_value = init_response

    result = await flow_manager.async_configure(
        form["flow_id"],
        {CONF_URL: MCP_SERVER_URL},
    )
    assert result["type"] is FlowResultType.ABORT
    assert result["reason"] == "missing_capabilities"
@respx.mock
async def test_oauth_discovery_flow_without_credentials(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_mcp_client: Mock,
) -> None:
    """Test the OAuth discovery flow when no credentials have been entered yet."""
    flow_manager = hass.config_entries.flow
    form = await flow_manager.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    # MCP Server returns 401 indicating the client needs to authenticate
    auth_required = httpx.HTTPStatusError(
        "Authentication required", request=None, response=httpx.Response(401)
    )
    mock_mcp_client.side_effect = auth_required

    # Prepare the OAuth Server metadata
    discovery_route = respx.get(OAUTH_DISCOVERY_ENDPOINT)
    discovery_route.mock(return_value=OAUTH_SERVER_METADATA_RESPONSE)

    result = await flow_manager.async_configure(
        form["flow_id"],
        {CONF_URL: MCP_SERVER_URL},
    )

    # The config flow will abort and the user will be taken to the application
    # credentials UI to enter their credentials.
    assert result["type"] is FlowResultType.ABORT
    assert result["reason"] == "missing_credentials"
async def perform_oauth_flow(
    hass: HomeAssistant,
    aioclient_mock: AiohttpClientMocker,
    hass_client_no_auth: ClientSessionGenerator,
    result: config_entries.ConfigFlowResult,
    authorize_url: str = OAUTH_AUTHORIZE_URL,
    token_url: str = OAUTH_TOKEN_URL,
) -> config_entries.ConfigFlowResult:
    """Perform the common steps of the OAuth flow.

    Expects to be called from the step where the user selects credentials.

    Checks that the external step URL in `result` points at `authorize_url`
    with the expected query string, calls the OAuth callback endpoint with
    the authorization code, and registers a mocked response for `token_url`.
    Returns the `result` that was passed in.
    """
    # Reuse the shared helper so the state payload is defined in one place.
    state = encode_state(hass, result["flow_id"])
    expected_url = (
        f"{authorize_url}?response_type=code&client_id={CLIENT_ID}"
        f"&redirect_uri={OAUTH_CALLBACK_URL}"
        f"&state={state}"
    )
    assert result["url"] == expected_url

    # Simulate the authorization server redirecting back with a code.
    client = await hass_client_no_auth()
    resp = await client.get(f"{CALLBACK_PATH}?code={OAUTH_CODE}&state={state}")
    assert resp.status == 200
    assert resp.headers["content-type"] == "text/html; charset=utf-8"

    # Register the mocked token endpoint response for POSTs to token_url.
    aioclient_mock.post(
        token_url,
        json=OAUTH_TOKEN_PAYLOAD,
    )
    return result
@pytest.mark.parametrize(
    ("oauth_server_metadata_response", "expected_authorize_url", "expected_token_url"),
    [
        (OAUTH_SERVER_METADATA_RESPONSE, OAUTH_AUTHORIZE_URL, OAUTH_TOKEN_URL),
        (
            httpx.Response(
                status_code=200,
                text=json.dumps(
                    {
                        "authorization_endpoint": "/authorize-path",
                        "token_endpoint": "/token-path",
                    }
                ),
            ),
            f"{MCP_SERVER_BASE_URL}/authorize-path",
            f"{MCP_SERVER_BASE_URL}/token-path",
        ),
        (
            httpx.Response(status_code=404),
            f"{MCP_SERVER_BASE_URL}/authorize",
            f"{MCP_SERVER_BASE_URL}/token",
        ),
    ],
    ids=(
        "discovery",
        "relative_paths",
        "no_discovery_metadata",
    ),
)
@pytest.mark.usefixtures("current_request_with_host")
@respx.mock
async def test_authentication_flow(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_mcp_client: Mock,
    credential: None,
    aioclient_mock: AiohttpClientMocker,
    hass_client_no_auth: ClientSessionGenerator,
    oauth_server_metadata_response: httpx.Response,
    expected_authorize_url: str,
    expected_token_url: str,
) -> None:
    """Test a full OAuth flow for each way the OAuth endpoints can be resolved."""
    flow_manager = hass.config_entries.flow
    form = await flow_manager.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    # MCP Server returns 401 indicating the client needs to authenticate
    auth_required = httpx.HTTPStatusError(
        "Authentication required", request=None, response=httpx.Response(401)
    )
    mock_mcp_client.side_effect = auth_required

    # Prepare the OAuth Server metadata
    discovery_route = respx.get(OAUTH_DISCOVERY_ENDPOINT)
    discovery_route.mock(return_value=oauth_server_metadata_response)

    menu = await flow_manager.async_configure(
        form["flow_id"],
        {CONF_URL: MCP_SERVER_URL},
    )
    assert menu["type"] is FlowResultType.MENU
    assert menu["step_id"] == "credentials_choice"

    external_step = await flow_manager.async_configure(
        menu["flow_id"],
        {"next_step_id": "pick_implementation"},
    )
    assert external_step["type"] is FlowResultType.EXTERNAL_STEP

    oauth_result = await perform_oauth_flow(
        hass,
        aioclient_mock,
        hass_client_no_auth,
        external_step,
        authorize_url=expected_authorize_url,
        token_url=expected_token_url,
    )

    # Client now accepts credentials
    mock_mcp_client.side_effect = None
    init_response = Mock()
    init_response.serverInfo.name = TEST_API_NAME
    mock_mcp_client.return_value.initialize.return_value = init_response

    result = await flow_manager.async_configure(oauth_result["flow_id"])
    assert result["type"] is FlowResultType.CREATE_ENTRY
    assert result["title"] == TEST_API_NAME

    # Split the token out so the remaining entry data can be compared exactly.
    entry_data = result["data"]
    token = entry_data.pop(CONF_TOKEN)
    assert entry_data == {
        "auth_implementation": AUTH_DOMAIN,
        CONF_URL: MCP_SERVER_URL,
        CONF_AUTHORIZATION_URL: expected_authorize_url,
        CONF_TOKEN_URL: expected_token_url,
    }

    # "expires_at" is not part of the mocked payload, so drop it before comparing.
    assert token
    token.pop("expires_at")
    assert token == OAUTH_TOKEN_PAYLOAD

    assert len(mock_setup_entry.mock_calls) == 1
@pytest.mark.parametrize(
    ("side_effect", "expected_error"),
    [
        (httpx.TimeoutException("Some timeout"), "timeout_connect"),
        (
            httpx.HTTPStatusError("", request=None, response=httpx.Response(500)),
            "cannot_connect",
        ),
        (httpx.HTTPError("Some HTTP error"), "cannot_connect"),
        (Exception, "unknown"),
    ],
)
@pytest.mark.usefixtures("current_request_with_host")
@respx.mock
async def test_oauth_discovery_failure(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_mcp_client: Mock,
    credential: None,
    aioclient_mock: AiohttpClientMocker,
    hass_client_no_auth: ClientSessionGenerator,
    side_effect: Exception,
    expected_error: str,
) -> None:
    """Test that a failing OAuth metadata request aborts with the matching reason."""
    flow_manager = hass.config_entries.flow
    form = await flow_manager.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    # MCP Server returns 401 indicating the client needs to authenticate
    auth_required = httpx.HTTPStatusError(
        "Authentication required", request=None, response=httpx.Response(401)
    )
    mock_mcp_client.side_effect = auth_required

    # The OAuth Server metadata request raises the parametrized error.
    discovery_route = respx.get(OAUTH_DISCOVERY_ENDPOINT)
    discovery_route.mock(side_effect=side_effect)

    result = await flow_manager.async_configure(
        form["flow_id"],
        {CONF_URL: MCP_SERVER_URL},
    )
    assert result["type"] is FlowResultType.ABORT
    assert result["reason"] == expected_error
@pytest.mark.parametrize(
    ("side_effect", "expected_error"),
    [
        (httpx.TimeoutException("Some timeout"), "timeout_connect"),
        (
            httpx.HTTPStatusError("", request=None, response=httpx.Response(500)),
            "cannot_connect",
        ),
        (httpx.HTTPError("Some HTTP error"), "cannot_connect"),
        (Exception, "unknown"),
    ],
)
@pytest.mark.usefixtures("current_request_with_host")
@respx.mock
async def test_authentication_flow_server_failure_abort(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_mcp_client: Mock,
    credential: None,
    aioclient_mock: AiohttpClientMocker,
    hass_client_no_auth: ClientSessionGenerator,
    side_effect: Exception,
    expected_error: str,
) -> None:
    """Test that the flow aborts when the MCP client fails after the OAuth steps."""
    flow_manager = hass.config_entries.flow
    form = await flow_manager.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    # MCP Server returns 401 indicating the client needs to authenticate
    auth_required = httpx.HTTPStatusError(
        "Authentication required", request=None, response=httpx.Response(401)
    )
    mock_mcp_client.side_effect = auth_required

    # Prepare the OAuth Server metadata
    discovery_route = respx.get(OAUTH_DISCOVERY_ENDPOINT)
    discovery_route.mock(return_value=OAUTH_SERVER_METADATA_RESPONSE)

    menu = await flow_manager.async_configure(
        form["flow_id"],
        {CONF_URL: MCP_SERVER_URL},
    )
    assert menu["type"] is FlowResultType.MENU
    assert menu["step_id"] == "credentials_choice"

    external_step = await flow_manager.async_configure(
        menu["flow_id"],
        {"next_step_id": "pick_implementation"},
    )
    assert external_step["type"] is FlowResultType.EXTERNAL_STEP

    oauth_result = await perform_oauth_flow(
        hass,
        aioclient_mock,
        hass_client_no_auth,
        external_step,
    )

    # Client fails with an error
    mock_mcp_client.side_effect = side_effect

    result = await flow_manager.async_configure(oauth_result["flow_id"])
    assert result["type"] is FlowResultType.ABORT
    assert result["reason"] == expected_error
@pytest.mark.usefixtures("current_request_with_host")
@respx.mock
async def test_authentication_flow_server_missing_tool_capabilities(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_mcp_client: Mock,
    credential: None,
    aioclient_mock: AiohttpClientMocker,
    hass_client_no_auth: ClientSessionGenerator,
) -> None:
    """Test that an authenticated flow aborts when the server has no tools capability."""
    flow_manager = hass.config_entries.flow
    form = await flow_manager.async_init(
        DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    # MCP Server returns 401 indicating the client needs to authenticate
    auth_required = httpx.HTTPStatusError(
        "Authentication required", request=None, response=httpx.Response(401)
    )
    mock_mcp_client.side_effect = auth_required

    # Prepare the OAuth Server metadata
    discovery_route = respx.get(OAUTH_DISCOVERY_ENDPOINT)
    discovery_route.mock(return_value=OAUTH_SERVER_METADATA_RESPONSE)

    menu = await flow_manager.async_configure(
        form["flow_id"],
        {CONF_URL: MCP_SERVER_URL},
    )
    assert menu["type"] is FlowResultType.MENU
    assert menu["step_id"] == "credentials_choice"

    external_step = await flow_manager.async_configure(
        menu["flow_id"],
        {"next_step_id": "pick_implementation"},
    )
    assert external_step["type"] is FlowResultType.EXTERNAL_STEP

    oauth_result = await perform_oauth_flow(
        hass,
        aioclient_mock,
        hass_client_no_auth,
        external_step,
    )

    # Client can now authenticate, but the server advertises no tools capability.
    mock_mcp_client.side_effect = None
    init_response = Mock()
    init_response.serverInfo.name = TEST_API_NAME
    init_response.capabilities.tools = None
    mock_mcp_client.return_value.initialize.return_value = init_response

    result = await flow_manager.async_configure(oauth_result["flow_id"])
    assert result["type"] is FlowResultType.ABORT
    assert result["reason"] == "missing_capabilities"
@pytest.mark.usefixtures("current_request_with_host")
@respx.mock
async def test_reauth_flow(
    hass: HomeAssistant,
    mock_setup_entry: AsyncMock,
    mock_mcp_client: Mock,
    credential: None,
    config_entry_with_auth: MockConfigEntry,
    aioclient_mock: AiohttpClientMocker,
    hass_client_no_auth: ClientSessionGenerator,
) -> None:
    """Test that reauthentication of an existing entry completes via OAuth."""
    config_entry_with_auth.async_start_reauth(hass)
    await hass.async_block_till_done()

    # Exactly one flow should be pending, waiting on the reauth confirmation.
    flow_manager = hass.config_entries.flow
    pending_flows = flow_manager.async_progress()
    assert len(pending_flows) == 1
    reauth_flow = pending_flows[0]
    assert reauth_flow["step_id"] == "reauth_confirm"

    confirmed = await flow_manager.async_configure(reauth_flow["flow_id"], {})
    oauth_result = await perform_oauth_flow(
        hass, aioclient_mock, hass_client_no_auth, confirmed
    )

    # Verify we can connect to the server
    init_response = Mock()
    init_response.serverInfo.name = TEST_API_NAME
    mock_mcp_client.return_value.initialize.return_value = init_response

    result = await flow_manager.async_configure(oauth_result["flow_id"])
    assert result["type"] is FlowResultType.ABORT
    assert result["reason"] == "reauth_successful"

    assert config_entry_with_auth.unique_id == AUTH_DOMAIN
    assert config_entry_with_auth.title == TEST_API_NAME

    # Work on a copy so the entry's own data mapping is not modified.
    entry_data = dict(config_entry_with_auth.data)
    token = entry_data.pop(CONF_TOKEN)
    assert entry_data == {
        "auth_implementation": AUTH_DOMAIN,
        CONF_URL: MCP_SERVER_URL,
        CONF_AUTHORIZATION_URL: OAUTH_AUTHORIZE_URL,
        CONF_TOKEN_URL: OAUTH_TOKEN_URL,
    }

    # "expires_at" is not part of the mocked payload, so drop it before comparing.
    assert token
    token.pop("expires_at")
    assert token == OAUTH_TOKEN_PAYLOAD

    assert len(mock_setup_entry.mock_calls) == 1