Add OAuth support for Model Context Protocol (mcp) integration (#141874)

* Add authentication support for Model Context Protocol (mcp) integration

* Update homeassistant/components/mcp/application_credentials.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* Handle MCP servers with ports

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Allen Porter 2025-03-30 20:14:52 -07:00 committed by GitHub
parent 1639163c2e
commit 0c4cb27fe9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 904 additions and 76 deletions

View File

@ -3,12 +3,15 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import cast
from homeassistant.components.application_credentials import AuthorizationServer
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers import config_entry_oauth2_flow, llm
from .const import DOMAIN
from .coordinator import ModelContextProtocolCoordinator
from .application_credentials import authorization_server_context
from .const import CONF_ACCESS_TOKEN, CONF_AUTHORIZATION_URL, CONF_TOKEN_URL, DOMAIN
from .coordinator import ModelContextProtocolCoordinator, TokenManager
from .types import ModelContextProtocolConfigEntry
__all__ = [
@ -20,11 +23,45 @@ __all__ = [
API_PROMPT = "The following tools are available from a remote server named {name}."
async def async_get_config_entry_implementation(
hass: HomeAssistant, entry: ModelContextProtocolConfigEntry
) -> config_entry_oauth2_flow.AbstractOAuth2Implementation | None:
"""OAuth implementation for the config entry."""
if "auth_implementation" not in entry.data:
return None
with authorization_server_context(
AuthorizationServer(
authorize_url=entry.data[CONF_AUTHORIZATION_URL],
token_url=entry.data[CONF_TOKEN_URL],
)
):
return await config_entry_oauth2_flow.async_get_config_entry_implementation(
hass, entry
)
async def _create_token_manager(
hass: HomeAssistant, entry: ModelContextProtocolConfigEntry
) -> TokenManager | None:
"""Create a OAuth token manager for the config entry if the server requires authentication."""
if not (implementation := await async_get_config_entry_implementation(hass, entry)):
return None
session = config_entry_oauth2_flow.OAuth2Session(hass, entry, implementation)
async def token_manager() -> str:
await session.async_ensure_token_valid()
return cast(str, session.token[CONF_ACCESS_TOKEN])
return token_manager
async def async_setup_entry(
hass: HomeAssistant, entry: ModelContextProtocolConfigEntry
) -> bool:
"""Set up Model Context Protocol from a config entry."""
coordinator = ModelContextProtocolCoordinator(hass, entry)
token_manager = await _create_token_manager(hass, entry)
coordinator = ModelContextProtocolCoordinator(hass, entry, token_manager)
await coordinator.async_config_entry_first_refresh()
unsub = llm.async_register_api(

View File

@ -0,0 +1,35 @@
"""Application credentials platform for Model Context Protocol."""
from __future__ import annotations
from collections.abc import Generator
from contextlib import contextmanager
import contextvars
from homeassistant.components.application_credentials import AuthorizationServer
from homeassistant.core import HomeAssistant
CONF_ACTIVE_AUTHORIZATION_SERVER = "active_authorization_server"
_mcp_context: contextvars.ContextVar[AuthorizationServer] = contextvars.ContextVar(
"mcp_authorization_server_context"
)
@contextmanager
def authorization_server_context(
authorization_server: AuthorizationServer,
) -> Generator[None]:
"""Context manager for setting the active authorization server."""
token = _mcp_context.set(authorization_server)
try:
yield
finally:
_mcp_context.reset(token)
async def async_get_authorization_server(hass: HomeAssistant) -> AuthorizationServer:
"""Return authorization server, for the default auth implementation."""
if _mcp_context.get() is None:
raise RuntimeError("No MCP authorization server set in context")
return _mcp_context.get()

View File

@ -2,20 +2,29 @@
from __future__ import annotations
from collections.abc import Mapping
import logging
from typing import Any
from typing import Any, cast
import httpx
import voluptuous as vol
from yarl import URL
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
from homeassistant.const import CONF_URL
from homeassistant.components.application_credentials import AuthorizationServer
from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult
from homeassistant.const import CONF_TOKEN, CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.config_entry_oauth2_flow import (
AbstractOAuth2FlowHandler,
async_get_implementations,
)
from .const import DOMAIN
from .coordinator import mcp_client
from . import async_get_config_entry_implementation
from .application_credentials import authorization_server_context
from .const import CONF_ACCESS_TOKEN, CONF_AUTHORIZATION_URL, CONF_TOKEN_URL, DOMAIN
from .coordinator import TokenManager, mcp_client
_LOGGER = logging.getLogger(__name__)
@ -25,8 +34,62 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
}
)
# OAuth server discovery endpoint for rfc8414
OAUTH_DISCOVERY_ENDPOINT = ".well-known/oauth-authorization-server"
MCP_DISCOVERY_HEADERS = {
"MCP-Protocol-Version": "2025-03-26",
}
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str, Any]:
async def async_discover_oauth_config(
hass: HomeAssistant, mcp_server_url: str
) -> AuthorizationServer:
"""Discover the OAuth configuration for the MCP server.
This implements the functionality in the MCP spec for discovery. If the MCP server URL
is https://api.example.com/v1/mcp, then:
- The authorization base URL is https://api.example.com
- The metadata endpoint MUST be at https://api.example.com/.well-known/oauth-authorization-server
- For servers that do not implement OAuth 2.0 Authorization Server Metadata, the client uses
default paths relative to the authorization base URL.
"""
parsed_url = URL(mcp_server_url)
discovery_endpoint = str(parsed_url.with_path(OAUTH_DISCOVERY_ENDPOINT))
try:
async with httpx.AsyncClient(headers=MCP_DISCOVERY_HEADERS) as client:
response = await client.get(discovery_endpoint)
response.raise_for_status()
except httpx.TimeoutException as error:
_LOGGER.info("Timeout connecting to MCP server: %s", error)
raise TimeoutConnectError from error
except httpx.HTTPStatusError as error:
if error.response.status_code == 404:
_LOGGER.info("Authorization Server Metadata not found, using default paths")
return AuthorizationServer(
authorize_url=str(parsed_url.with_path("/authorize")),
token_url=str(parsed_url.with_path("/token")),
)
raise CannotConnect from error
except httpx.HTTPError as error:
_LOGGER.info("Cannot discover OAuth configuration: %s", error)
raise CannotConnect from error
data = response.json()
authorize_url = data["authorization_endpoint"]
token_url = data["token_endpoint"]
if authorize_url.startswith("/"):
authorize_url = str(parsed_url.with_path(authorize_url))
if token_url.startswith("/"):
token_url = str(parsed_url.with_path(token_url))
return AuthorizationServer(
authorize_url=authorize_url,
token_url=token_url,
)
async def validate_input(
hass: HomeAssistant, data: dict[str, Any], token_manager: TokenManager | None = None
) -> dict[str, Any]:
"""Validate the user input and connect to the MCP server."""
url = data[CONF_URL]
try:
@ -34,7 +97,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str,
except vol.Invalid as error:
raise InvalidUrl from error
try:
async with mcp_client(url) as session:
async with mcp_client(url, token_manager=token_manager) as session:
response = await session.initialize()
except httpx.TimeoutException as error:
_LOGGER.info("Timeout connecting to MCP server: %s", error)
@ -56,10 +119,17 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str,
return {"title": response.serverInfo.name}
class ModelContextProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
class ModelContextProtocolConfigFlow(AbstractOAuth2FlowHandler, domain=DOMAIN):
"""Handle a config flow for Model Context Protocol."""
VERSION = 1
DOMAIN = DOMAIN
logger = _LOGGER
def __init__(self) -> None:
"""Initialize the config flow."""
super().__init__()
self.data: dict[str, Any] = {}
async def async_step_user(
self, user_input: dict[str, Any] | None = None
@ -76,7 +146,8 @@ class ModelContextProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
except CannotConnect:
errors["base"] = "cannot_connect"
except InvalidAuth:
return self.async_abort(reason="invalid_auth")
self.data[CONF_URL] = user_input[CONF_URL]
return await self.async_step_auth_discovery()
except MissingCapabilities:
return self.async_abort(reason="missing_capabilities")
except Exception:
@ -90,6 +161,130 @@ class ModelContextProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
)
async def async_step_auth_discovery(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the OAuth server discovery step.
Since this OAuth server requires authentication, this step will attempt
to find the OAuth medata then run the OAuth authentication flow.
"""
try:
authorization_server = await async_discover_oauth_config(
self.hass, self.data[CONF_URL]
)
except TimeoutConnectError:
return self.async_abort(reason="timeout_connect")
except CannotConnect:
return self.async_abort(reason="cannot_connect")
except Exception:
_LOGGER.exception("Unexpected exception")
return self.async_abort(reason="unknown")
else:
_LOGGER.info("OAuth configuration: %s", authorization_server)
self.data.update(
{
CONF_AUTHORIZATION_URL: authorization_server.authorize_url,
CONF_TOKEN_URL: authorization_server.token_url,
}
)
return await self.async_step_credentials_choice()
def authorization_server(self) -> AuthorizationServer:
"""Return the authorization server provided by the MCP server."""
return AuthorizationServer(
self.data[CONF_AUTHORIZATION_URL],
self.data[CONF_TOKEN_URL],
)
async def async_step_credentials_choice(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Step to ask they user if they would like to add credentials.
This is needed since we can't automatically assume existing credentials
should be used given they may be for another existing server.
"""
with authorization_server_context(self.authorization_server()):
if not await async_get_implementations(self.hass, self.DOMAIN):
return await self.async_step_new_credentials()
return self.async_show_menu(
step_id="credentials_choice",
menu_options=["pick_implementation", "new_credentials"],
)
async def async_step_new_credentials(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Step to take the frontend flow to enter new credentials."""
return self.async_abort(reason="missing_credentials")
async def async_step_pick_implementation(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the pick implementation step.
This exists to dynamically set application credentials Authorization Server
based on the values form the OAuth discovery step.
"""
with authorization_server_context(self.authorization_server()):
return await super().async_step_pick_implementation(user_input)
async def async_oauth_create_entry(self, data: dict) -> ConfigFlowResult:
"""Create an entry for the flow.
Ok to override if you want to fetch extra info or even add another step.
"""
config_entry_data = {
**self.data,
**data,
}
async def token_manager() -> str:
return cast(str, data[CONF_TOKEN][CONF_ACCESS_TOKEN])
try:
info = await validate_input(self.hass, config_entry_data, token_manager)
except TimeoutConnectError:
return self.async_abort(reason="timeout_connect")
except CannotConnect:
return self.async_abort(reason="cannot_connect")
except MissingCapabilities:
return self.async_abort(reason="missing_capabilities")
except Exception:
_LOGGER.exception("Unexpected exception")
return self.async_abort(reason="unknown")
# Unique id based on the application credentials OAuth Client ID
if self.source == SOURCE_REAUTH:
return self.async_update_reload_and_abort(
self._get_reauth_entry(), data=config_entry_data
)
await self.async_set_unique_id(config_entry_data["auth_implementation"])
return self.async_create_entry(
title=info["title"],
data=config_entry_data,
)
async def async_step_reauth(
self, entry_data: Mapping[str, Any]
) -> ConfigFlowResult:
"""Perform reauth upon an API authentication error."""
return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(
self, user_input: Mapping[str, Any] | None = None
) -> ConfigFlowResult:
"""Confirm reauth dialog."""
if user_input is None:
return self.async_show_form(step_id="reauth_confirm")
config_entry = self._get_reauth_entry()
self.data = {**config_entry.data}
self.flow_impl = await async_get_config_entry_implementation( # type: ignore[assignment]
self.hass, config_entry
)
return await self.async_step_auth()
class InvalidUrl(HomeAssistantError):
"""Error to indicate the URL format is invalid."""

View File

@ -1,3 +1,7 @@
"""Constants for the Model Context Protocol integration."""
DOMAIN = "mcp"
CONF_ACCESS_TOKEN = "access_token"
CONF_AUTHORIZATION_URL = "authorization_url"
CONF_TOKEN_URL = "token_url"

View File

@ -1,7 +1,7 @@
"""Types for the Model Context Protocol integration."""
import asyncio
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
import datetime
import logging
@ -15,7 +15,7 @@ from voluptuous_openapi import convert_to_voluptuous
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from homeassistant.util.json import JsonObjectType
@ -27,16 +27,28 @@ _LOGGER = logging.getLogger(__name__)
UPDATE_INTERVAL = datetime.timedelta(minutes=30)
TIMEOUT = 10
TokenManager = Callable[[], Awaitable[str]]
@asynccontextmanager
async def mcp_client(url: str) -> AsyncGenerator[ClientSession]:
async def mcp_client(
url: str,
token_manager: TokenManager | None = None,
) -> AsyncGenerator[ClientSession]:
"""Create a server-sent event MCP client.
This is an asynccontext manager that exists to wrap other async context managers
so that the coordinator has a single object to manage.
"""
headers: dict[str, str] = {}
if token_manager is not None:
token = await token_manager()
headers["Authorization"] = f"Bearer {token}"
try:
async with sse_client(url=url) as streams, ClientSession(*streams) as session:
async with (
sse_client(url=url, headers=headers) as streams,
ClientSession(*streams) as session,
):
await session.initialize()
yield session
except ExceptionGroup as err:
@ -53,12 +65,14 @@ class ModelContextProtocolTool(llm.Tool):
description: str | None,
parameters: vol.Schema,
server_url: str,
token_manager: TokenManager | None = None,
) -> None:
"""Initialize the tool."""
self.name = name
self.description = description
self.parameters = parameters
self.server_url = server_url
self.token_manager = token_manager
async def async_call(
self,
@ -69,7 +83,7 @@ class ModelContextProtocolTool(llm.Tool):
"""Call the tool."""
try:
async with asyncio.timeout(TIMEOUT):
async with mcp_client(self.server_url) as session:
async with mcp_client(self.server_url, self.token_manager) as session:
result = await session.call_tool(
tool_input.tool_name, tool_input.tool_args
)
@ -87,7 +101,12 @@ class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
config_entry: ConfigEntry
def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None:
def __init__(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
token_manager: TokenManager | None = None,
) -> None:
"""Initialize ModelContextProtocolCoordinator."""
super().__init__(
hass,
@ -96,6 +115,7 @@ class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
config_entry=config_entry,
update_interval=UPDATE_INTERVAL,
)
self.token_manager = token_manager
async def _async_update_data(self) -> list[llm.Tool]:
"""Fetch data from API endpoint.
@ -105,11 +125,20 @@ class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
"""
try:
async with asyncio.timeout(TIMEOUT):
async with mcp_client(self.config_entry.data[CONF_URL]) as session:
async with mcp_client(
self.config_entry.data[CONF_URL], self.token_manager
) as session:
result = await session.list_tools()
except TimeoutError as error:
_LOGGER.debug("Timeout when listing tools: %s", error)
raise UpdateFailed(f"Timeout when listing tools: {error}") from error
except httpx.HTTPStatusError as error:
_LOGGER.debug("Error communicating with API: %s", error)
if error.response.status_code == 401 and self.token_manager is not None:
raise ConfigEntryAuthFailed(
"The MCP server requires authentication"
) from error
raise UpdateFailed(f"Error communicating with API: {error}") from error
except httpx.HTTPError as err:
_LOGGER.debug("Error communicating with API: %s", err)
raise UpdateFailed(f"Error communicating with API: {err}") from err
@ -129,6 +158,7 @@ class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
tool.description,
parameters,
self.config_entry.data[CONF_URL],
self.token_manager,
)
)
return tools

View File

@ -3,6 +3,7 @@
"name": "Model Context Protocol",
"codeowners": ["@allenporter"],
"config_flow": true,
"dependencies": ["application_credentials"],
"documentation": "https://www.home-assistant.io/integrations/mcp",
"iot_class": "local_polling",
"quality_scale": "silver",

View File

@ -44,9 +44,7 @@ rules:
parallel-updates:
status: exempt
comment: Integration does not have platforms.
reauthentication-flow:
status: exempt
comment: Integration does not support authentication.
reauthentication-flow: done
test-coverage: done
# Gold

View File

@ -8,6 +8,15 @@
"data_description": {
"url": "The remote MCP server URL for the SSE endpoint, for example http://example/sse"
}
},
"pick_implementation": {
"title": "[%key:common::config_flow::title::oauth2_pick_implementation%]",
"data": {
"implementation": "Credentials"
},
"data_description": {
"implementation": "The credentials to use for the OAuth2 flow"
}
}
},
"error": {
@ -17,9 +26,15 @@
"invalid_url": "Must be a valid MCP server URL e.g. https://example.com/sse"
},
"abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"missing_capabilities": "The MCP server does not support a required capability (Tools)",
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]"
"missing_credentials": "[%key:common::config_flow::abort::oauth2_missing_credentials%]",
"reauth_account_mismatch": "The authenticated user does not match the MCP Server user that needed re-authentication.",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]",
"timeout_connect": "[%key:common::config_flow::error::timeout_connect%]",
"unknown": "[%key:common::config_flow::error::unknown%]"
}
}
}

View File

@ -19,6 +19,7 @@ APPLICATION_CREDENTIALS = [
"iotty",
"lametric",
"lyric",
"mcp",
"microbees",
"monzo",
"myuplink",

View File

@ -1,17 +1,34 @@
"""Common fixtures for the Model Context Protocol tests."""
from collections.abc import Generator
import datetime
from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.components.mcp.const import DOMAIN
from homeassistant.const import CONF_URL
from homeassistant.components.application_credentials import (
ClientCredential,
async_import_client_credential,
)
from homeassistant.components.mcp.const import (
CONF_ACCESS_TOKEN,
CONF_AUTHORIZATION_URL,
CONF_TOKEN_URL,
DOMAIN,
)
from homeassistant.const import CONF_TOKEN, CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
TEST_API_NAME = "Memory Server"
MCP_SERVER_URL = "http://1.1.1.1:8080/sse"
CLIENT_ID = "test-client-id"
CLIENT_SECRET = "test-client-secret"
AUTH_DOMAIN = "some-auth-domain"
OAUTH_AUTHORIZE_URL = "https://example-auth-server.com/authorize-path"
OAUTH_TOKEN_URL = "https://example-auth-server.com/token-path"
@pytest.fixture
@ -29,6 +46,7 @@ def mock_mcp_client() -> Generator[AsyncMock]:
with (
patch("homeassistant.components.mcp.coordinator.sse_client"),
patch("homeassistant.components.mcp.coordinator.ClientSession") as mock_session,
patch("homeassistant.components.mcp.coordinator.TIMEOUT", 1),
):
yield mock_session.return_value.__aenter__
@ -43,3 +61,47 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
)
config_entry.add_to_hass(hass)
return config_entry
@pytest.fixture(name="credential")
async def mock_credential(hass: HomeAssistant) -> None:
"""Fixture that provides the ClientCredential for the test."""
assert await async_setup_component(hass, "application_credentials", {})
await async_import_client_credential(
hass,
DOMAIN,
ClientCredential(CLIENT_ID, CLIENT_SECRET),
AUTH_DOMAIN,
)
@pytest.fixture(name="config_entry_token_expiration")
def mock_config_entry_token_expiration() -> datetime.datetime:
"""Fixture to mock the token expiration."""
return datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1)
@pytest.fixture(name="config_entry_with_auth")
def mock_config_entry_with_auth(
hass: HomeAssistant,
config_entry_token_expiration: datetime.datetime,
) -> MockConfigEntry:
"""Fixture to load the integration with authentication."""
config_entry = MockConfigEntry(
domain=DOMAIN,
unique_id=AUTH_DOMAIN,
data={
"auth_implementation": AUTH_DOMAIN,
CONF_URL: MCP_SERVER_URL,
CONF_AUTHORIZATION_URL: OAUTH_AUTHORIZE_URL,
CONF_TOKEN_URL: OAUTH_TOKEN_URL,
CONF_TOKEN: {
CONF_ACCESS_TOKEN: "test-access-token",
"refresh_token": "test-refresh-token",
"expires_at": config_entry_token_expiration.timestamp(),
},
},
title=TEST_API_NAME,
)
config_entry.add_to_hass(hass)
return config_entry

View File

@ -1,20 +1,70 @@
"""Test the Model Context Protocol config flow."""
import json
from typing import Any
from unittest.mock import AsyncMock, Mock
import httpx
import pytest
import respx
from homeassistant import config_entries
from homeassistant.components.mcp.const import DOMAIN
from homeassistant.const import CONF_URL
from homeassistant.components.mcp.const import (
CONF_AUTHORIZATION_URL,
CONF_TOKEN_URL,
DOMAIN,
)
from homeassistant.const import CONF_TOKEN, CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import config_entry_oauth2_flow
from .conftest import TEST_API_NAME
from .conftest import (
AUTH_DOMAIN,
CLIENT_ID,
MCP_SERVER_URL,
OAUTH_AUTHORIZE_URL,
OAUTH_TOKEN_URL,
TEST_API_NAME,
)
from tests.common import MockConfigEntry
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
MCP_SERVER_BASE_URL = "http://1.1.1.1:8080"
OAUTH_DISCOVERY_ENDPOINT = (
f"{MCP_SERVER_BASE_URL}/.well-known/oauth-authorization-server"
)
OAUTH_SERVER_METADATA_RESPONSE = httpx.Response(
status_code=200,
text=json.dumps(
{
"authorization_endpoint": OAUTH_AUTHORIZE_URL,
"token_endpoint": OAUTH_TOKEN_URL,
}
),
)
CALLBACK_PATH = "/auth/external/callback"
OAUTH_CALLBACK_URL = f"https://example.com{CALLBACK_PATH}"
OAUTH_CODE = "abcd"
OAUTH_TOKEN_PAYLOAD = {
"refresh_token": "mock-refresh-token",
"access_token": "mock-access-token",
"type": "Bearer",
"expires_in": 60,
}
def encode_state(hass: HomeAssistant, flow_id: str) -> str:
"""Encode the OAuth JWT."""
return config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": flow_id,
"redirect_uri": OAUTH_CALLBACK_URL,
},
)
async def test_form(
@ -34,15 +84,19 @@ async def test_form(
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
CONF_URL: MCP_SERVER_URL,
},
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == TEST_API_NAME
assert result["data"] == {
CONF_URL: "http://1.1.1.1/sse",
CONF_URL: MCP_SERVER_URL,
}
# Config entry does not have a unique id
assert result["result"]
assert result["result"].unique_id is None
assert len(mock_setup_entry.mock_calls) == 1
@ -73,7 +127,7 @@ async def test_form_mcp_client_error(
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
CONF_URL: MCP_SERVER_URL,
},
)
@ -89,50 +143,18 @@ async def test_form_mcp_client_error(
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
CONF_URL: MCP_SERVER_URL,
},
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == TEST_API_NAME
assert result["data"] == {
CONF_URL: "http://1.1.1.1/sse",
CONF_URL: MCP_SERVER_URL,
}
assert len(mock_setup_entry.mock_calls) == 1
@pytest.mark.parametrize(
("side_effect", "expected_error"),
[
(
httpx.HTTPStatusError("", request=None, response=httpx.Response(401)),
"invalid_auth",
),
],
)
async def test_form_mcp_client_error_abort(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_mcp_client: Mock,
side_effect: Exception,
expected_error: str,
) -> None:
"""Test we handle different client library errors that end with an abort."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
mock_mcp_client.side_effect = side_effect
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == expected_error
@pytest.mark.parametrize(
"user_input",
[
@ -165,14 +187,14 @@ async def test_input_form_validation_error(
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
CONF_URL: MCP_SERVER_URL,
},
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == TEST_API_NAME
assert result["data"] == {
CONF_URL: "http://1.1.1.1/sse",
CONF_URL: MCP_SERVER_URL,
}
assert len(mock_setup_entry.mock_calls) == 1
@ -183,7 +205,7 @@ async def test_unique_url(
"""Test that the same url cannot be configured twice."""
config_entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_URL: "http://1.1.1.1/sse"},
data={CONF_URL: MCP_SERVER_URL},
title=TEST_API_NAME,
)
config_entry.add_to_hass(hass)
@ -201,7 +223,7 @@ async def test_unique_url(
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
CONF_URL: MCP_SERVER_URL,
},
)
@ -226,9 +248,409 @@ async def test_server_missing_capbilities(
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
CONF_URL: MCP_SERVER_URL,
},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "missing_capabilities"
@respx.mock
async def test_oauth_discovery_flow_without_credentials(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_mcp_client: Mock,
) -> None:
"""Test for an OAuth discoveryflow for an MCP server where the user has not yet entered credentials."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
# MCP Server returns 401 indicating the client needs to authenticate
mock_mcp_client.side_effect = httpx.HTTPStatusError(
"Authentication required", request=None, response=httpx.Response(401)
)
# Prepare the OAuth Server metadata
respx.get(OAUTH_DISCOVERY_ENDPOINT).mock(
return_value=OAUTH_SERVER_METADATA_RESPONSE
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: MCP_SERVER_URL,
},
)
# The config flow will abort and the user will be taken to the application credentials UI
# to enter their credentials.
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "missing_credentials"
async def perform_oauth_flow(
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
hass_client_no_auth: ClientSessionGenerator,
result: config_entries.ConfigFlowResult,
authorize_url: str = OAUTH_AUTHORIZE_URL,
token_url: str = OAUTH_TOKEN_URL,
) -> config_entries.ConfigFlowResult:
"""Perform the common steps of the OAuth flow.
Expects to be called from the step where the user selects credentials.
"""
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": OAUTH_CALLBACK_URL,
},
)
assert result["url"] == (
f"{authorize_url}?response_type=code&client_id={CLIENT_ID}"
f"&redirect_uri={OAUTH_CALLBACK_URL}"
f"&state={state}"
)
client = await hass_client_no_auth()
resp = await client.get(f"{CALLBACK_PATH}?code={OAUTH_CODE}&state={state}")
assert resp.status == 200
assert resp.headers["content-type"] == "text/html; charset=utf-8"
aioclient_mock.post(
token_url,
json=OAUTH_TOKEN_PAYLOAD,
)
return result
@pytest.mark.parametrize(
("oauth_server_metadata_response", "expected_authorize_url", "expected_token_url"),
[
(OAUTH_SERVER_METADATA_RESPONSE, OAUTH_AUTHORIZE_URL, OAUTH_TOKEN_URL),
(
httpx.Response(
status_code=200,
text=json.dumps(
{
"authorization_endpoint": "/authorize-path",
"token_endpoint": "/token-path",
}
),
),
f"{MCP_SERVER_BASE_URL}/authorize-path",
f"{MCP_SERVER_BASE_URL}/token-path",
),
(
httpx.Response(status_code=404),
f"{MCP_SERVER_BASE_URL}/authorize",
f"{MCP_SERVER_BASE_URL}/token",
),
],
ids=(
"discovery",
"relative_paths",
"no_discovery_metadata",
),
)
@pytest.mark.usefixtures("current_request_with_host")
@respx.mock
async def test_authentication_flow(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_mcp_client: Mock,
credential: None,
aioclient_mock: AiohttpClientMocker,
hass_client_no_auth: ClientSessionGenerator,
oauth_server_metadata_response: httpx.Response,
expected_authorize_url: str,
expected_token_url: str,
) -> None:
"""Test for an OAuth authentication flow for an MCP server."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
# MCP Server returns 401 indicating the client needs to authenticate
mock_mcp_client.side_effect = httpx.HTTPStatusError(
"Authentication required", request=None, response=httpx.Response(401)
)
# Prepare the OAuth Server metadata
respx.get(OAUTH_DISCOVERY_ENDPOINT).mock(
return_value=oauth_server_metadata_response
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: MCP_SERVER_URL,
},
)
assert result["type"] is FlowResultType.MENU
assert result["step_id"] == "credentials_choice"
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
"next_step_id": "pick_implementation",
},
)
assert result["type"] is FlowResultType.EXTERNAL_STEP
result = await perform_oauth_flow(
hass,
aioclient_mock,
hass_client_no_auth,
result,
authorize_url=expected_authorize_url,
token_url=expected_token_url,
)
# Client now accepts credentials
mock_mcp_client.side_effect = None
response = Mock()
response.serverInfo.name = TEST_API_NAME
mock_mcp_client.return_value.initialize.return_value = response
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == TEST_API_NAME
data = result["data"]
token = data.pop(CONF_TOKEN)
assert data == {
"auth_implementation": AUTH_DOMAIN,
CONF_URL: MCP_SERVER_URL,
CONF_AUTHORIZATION_URL: expected_authorize_url,
CONF_TOKEN_URL: expected_token_url,
}
assert token
token.pop("expires_at")
assert token == OAUTH_TOKEN_PAYLOAD
assert len(mock_setup_entry.mock_calls) == 1
@pytest.mark.parametrize(
("side_effect", "expected_error"),
[
(httpx.TimeoutException("Some timeout"), "timeout_connect"),
(
httpx.HTTPStatusError("", request=None, response=httpx.Response(500)),
"cannot_connect",
),
(httpx.HTTPError("Some HTTP error"), "cannot_connect"),
(Exception, "unknown"),
],
)
@pytest.mark.usefixtures("current_request_with_host")
@respx.mock
async def test_oauth_discovery_failure(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_mcp_client: Mock,
credential: None,
aioclient_mock: AiohttpClientMocker,
hass_client_no_auth: ClientSessionGenerator,
side_effect: Exception,
expected_error: str,
) -> None:
"""Test for an OAuth authentication flow for an MCP server."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
# MCP Server returns 401 indicating the client needs to authenticate
mock_mcp_client.side_effect = httpx.HTTPStatusError(
"Authentication required", request=None, response=httpx.Response(401)
)
# Prepare the OAuth Server metadata
respx.get(OAUTH_DISCOVERY_ENDPOINT).mock(side_effect=side_effect)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: MCP_SERVER_URL,
},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == expected_error
@pytest.mark.parametrize(
("side_effect", "expected_error"),
[
(httpx.TimeoutException("Some timeout"), "timeout_connect"),
(
httpx.HTTPStatusError("", request=None, response=httpx.Response(500)),
"cannot_connect",
),
(httpx.HTTPError("Some HTTP error"), "cannot_connect"),
(Exception, "unknown"),
],
)
@pytest.mark.usefixtures("current_request_with_host")
@respx.mock
async def test_authentication_flow_server_failure_abort(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_mcp_client: Mock,
credential: None,
aioclient_mock: AiohttpClientMocker,
hass_client_no_auth: ClientSessionGenerator,
side_effect: Exception,
expected_error: str,
) -> None:
"""Test for an OAuth authentication flow for an MCP server."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
# MCP Server returns 401 indicating the client needs to authenticate
mock_mcp_client.side_effect = httpx.HTTPStatusError(
"Authentication required", request=None, response=httpx.Response(401)
)
# Prepare the OAuth Server metadata
respx.get(OAUTH_DISCOVERY_ENDPOINT).mock(
return_value=OAUTH_SERVER_METADATA_RESPONSE
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: MCP_SERVER_URL,
},
)
assert result["type"] is FlowResultType.MENU
assert result["step_id"] == "credentials_choice"
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
"next_step_id": "pick_implementation",
},
)
assert result["type"] is FlowResultType.EXTERNAL_STEP
result = await perform_oauth_flow(
hass,
aioclient_mock,
hass_client_no_auth,
result,
)
# Client fails with an error
mock_mcp_client.side_effect = side_effect
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == expected_error
@pytest.mark.usefixtures("current_request_with_host")
@respx.mock
async def test_authentication_flow_server_missing_tool_capabilities(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_mcp_client: Mock,
credential: None,
aioclient_mock: AiohttpClientMocker,
hass_client_no_auth: ClientSessionGenerator,
) -> None:
"""Test for an OAuth authentication flow for an MCP server."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
# MCP Server returns 401 indicating the client needs to authenticate
mock_mcp_client.side_effect = httpx.HTTPStatusError(
"Authentication required", request=None, response=httpx.Response(401)
)
# Prepare the OAuth Server metadata
respx.get(OAUTH_DISCOVERY_ENDPOINT).mock(
return_value=OAUTH_SERVER_METADATA_RESPONSE
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: MCP_SERVER_URL,
},
)
assert result["type"] is FlowResultType.MENU
assert result["step_id"] == "credentials_choice"
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
"next_step_id": "pick_implementation",
},
)
assert result["type"] is FlowResultType.EXTERNAL_STEP
result = await perform_oauth_flow(
hass,
aioclient_mock,
hass_client_no_auth,
result,
)
# Client can now authenticate
mock_mcp_client.side_effect = None
response = Mock()
response.serverInfo.name = TEST_API_NAME
response.capabilities.tools = None
mock_mcp_client.return_value.initialize.return_value = response
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "missing_capabilities"
@pytest.mark.usefixtures("current_request_with_host")
@respx.mock
async def test_reauth_flow(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_mcp_client: Mock,
credential: None,
config_entry_with_auth: MockConfigEntry,
aioclient_mock: AiohttpClientMocker,
hass_client_no_auth: ClientSessionGenerator,
) -> None:
"""Test for an OAuth authentication flow for an MCP server."""
config_entry_with_auth.async_start_reauth(hass)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
result = flows[0]
assert result["step_id"] == "reauth_confirm"
result = await hass.config_entries.flow.async_configure(result["flow_id"], {})
result = await perform_oauth_flow(hass, aioclient_mock, hass_client_no_auth, result)
# Verify we can connect to the server
response = Mock()
response.serverInfo.name = TEST_API_NAME
mock_mcp_client.return_value.initialize.return_value = response
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "reauth_successful"
assert config_entry_with_auth.unique_id == AUTH_DOMAIN
assert config_entry_with_auth.title == TEST_API_NAME
data = {**config_entry_with_auth.data}
token = data.pop(CONF_TOKEN)
assert data == {
"auth_implementation": AUTH_DOMAIN,
CONF_URL: MCP_SERVER_URL,
CONF_AUTHORIZATION_URL: OAUTH_AUTHORIZE_URL,
CONF_TOKEN_URL: OAUTH_TOKEN_URL,
}
assert token
token.pop("expires_at")
assert token == OAUTH_TOKEN_PAYLOAD
assert len(mock_setup_entry.mock_calls) == 1

View File

@ -76,17 +76,45 @@ async def test_init(
assert config_entry.state is ConfigEntryState.NOT_LOADED
@pytest.mark.parametrize(
("side_effect"),
[
(httpx.TimeoutException("Some timeout")),
(httpx.HTTPStatusError("", request=None, response=httpx.Response(500))),
(httpx.HTTPStatusError("", request=None, response=httpx.Response(401))),
(httpx.HTTPError("Some HTTP error")),
],
)
async def test_mcp_server_failure(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
hass: HomeAssistant,
config_entry: MockConfigEntry,
mock_mcp_client: Mock,
side_effect: Exception,
) -> None:
"""Test the integration fails to setup if the server fails initialization."""
mock_mcp_client.side_effect = side_effect
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY
async def test_mcp_server_authentication_failure(
hass: HomeAssistant,
credential: None,
config_entry_with_auth: MockConfigEntry,
mock_mcp_client: Mock,
) -> None:
"""Test the integration fails to setup if the server fails authentication."""
mock_mcp_client.side_effect = httpx.HTTPStatusError(
"", request=None, response=httpx.Response(500)
"Authentication required", request=None, response=httpx.Response(401)
)
with patch("homeassistant.components.mcp.coordinator.TIMEOUT", 1):
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY
await hass.config_entries.async_setup(config_entry_with_auth.entry_id)
assert config_entry_with_auth.state is ConfigEntryState.SETUP_ERROR
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
assert flows[0]["step_id"] == "reauth_confirm"
async def test_list_tools_failure(