core/homeassistant/components/mcp/config_flow.py
Allen Porter 0c4cb27fe9
Add OAuth support for Model Context Protocol (mcp) integration (#141874)
* Add authentication support for Model Context Protocol (mcp) integration

* Update homeassistant/components/mcp/application_credentials.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* Handle MCP servers with ports

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
2025-03-30 23:14:52 -04:00

307 lines
12 KiB
Python

"""Config flow for the Model Context Protocol integration."""
from __future__ import annotations
from collections.abc import Mapping
import logging
from typing import Any, cast
import httpx
import voluptuous as vol
from yarl import URL
from homeassistant.components.application_credentials import AuthorizationServer
from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult
from homeassistant.const import CONF_TOKEN, CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.config_entry_oauth2_flow import (
AbstractOAuth2FlowHandler,
async_get_implementations,
)
from . import async_get_config_entry_implementation
from .application_credentials import authorization_server_context
from .const import CONF_ACCESS_TOKEN, CONF_AUTHORIZATION_URL, CONF_TOKEN_URL, DOMAIN
from .coordinator import TokenManager, mcp_client
_LOGGER = logging.getLogger(__name__)

# Form for the initial user step: only the MCP server URL is collected.
# The value is accepted as a plain string here; its URL format is checked
# separately in validate_input() because that validator cannot be added to
# the schema directly.
STEP_USER_DATA_SCHEMA = vol.Schema({vol.Required(CONF_URL): str})
# Path of the OAuth 2.0 Authorization Server Metadata document (RFC 8414).
# It is applied to the MCP server URL to build the discovery request.
OAUTH_DISCOVERY_ENDPOINT = ".well-known/oauth-authorization-server"

# Extra HTTP headers sent with the metadata discovery request.
MCP_DISCOVERY_HEADERS = {"MCP-Protocol-Version": "2025-03-26"}
async def async_discover_oauth_config(
    hass: HomeAssistant, mcp_server_url: str
) -> AuthorizationServer:
    """Discover the OAuth configuration for the MCP server.

    This implements the functionality in the MCP spec for discovery. If the MCP
    server URL is https://api.example.com/v1/mcp, then:
    - The authorization base URL is https://api.example.com
    - The metadata endpoint MUST be at
      https://api.example.com/.well-known/oauth-authorization-server
    - For servers that do not implement OAuth 2.0 Authorization Server
      Metadata, the client uses default paths relative to the authorization
      base URL.

    A 404 from the metadata endpoint selects the default ``/authorize`` and
    ``/token`` paths on the MCP server's host. Otherwise the
    ``authorization_endpoint`` and ``token_endpoint`` values of the returned
    document are used; a value starting with ``/`` is resolved against the MCP
    server URL.

    ``hass`` is accepted for interface compatibility and is not used here.

    Raises:
        TimeoutConnectError: The discovery request timed out.
        CannotConnect: The request failed, the server answered with an error
            status other than 404, or the metadata document is malformed
            (not JSON, not an object, or missing a usable endpoint).
    """
    parsed_url = URL(mcp_server_url)
    discovery_endpoint = str(parsed_url.with_path(OAUTH_DISCOVERY_ENDPOINT))
    try:
        async with httpx.AsyncClient(headers=MCP_DISCOVERY_HEADERS) as client:
            response = await client.get(discovery_endpoint)
            response.raise_for_status()
    except httpx.TimeoutException as error:
        _LOGGER.info("Timeout connecting to MCP server: %s", error)
        raise TimeoutConnectError from error
    except httpx.HTTPStatusError as error:
        if error.response.status_code == 404:
            _LOGGER.info("Authorization Server Metadata not found, using default paths")
            return AuthorizationServer(
                authorize_url=str(parsed_url.with_path("/authorize")),
                token_url=str(parsed_url.with_path("/token")),
            )
        # Log like the other failure paths so the cause is not lost.
        _LOGGER.info("Cannot discover OAuth configuration: %s", error)
        raise CannotConnect from error
    except httpx.HTTPError as error:
        _LOGGER.info("Cannot discover OAuth configuration: %s", error)
        raise CannotConnect from error

    # The document is supplied by a remote server, so validate it instead of
    # letting ValueError/KeyError/TypeError/AttributeError escape. These are
    # logged at warning level because they point at a misbehaving server
    # rather than a transient connection problem.
    try:
        metadata = response.json()
    except ValueError as error:  # json.JSONDecodeError is a ValueError
        _LOGGER.warning("Authorization Server Metadata is not valid JSON: %s", error)
        raise CannotConnect from error
    if not isinstance(metadata, dict):
        _LOGGER.warning("Authorization Server Metadata is not a JSON object")
        raise CannotConnect("Authorization Server Metadata is not a JSON object")

    endpoints: dict[str, str] = {}
    for key in ("authorization_endpoint", "token_endpoint"):
        endpoint = metadata.get(key)
        if not isinstance(endpoint, str) or not endpoint:
            _LOGGER.warning(
                "Authorization Server Metadata has no usable '%s'", key
            )
            raise CannotConnect(
                f"Authorization Server Metadata has no usable '{key}'"
            )
        # A value starting with "/" is a path on the MCP server's own host.
        if endpoint.startswith("/"):
            endpoint = str(parsed_url.with_path(endpoint))
        endpoints[key] = endpoint

    return AuthorizationServer(
        authorize_url=endpoints["authorization_endpoint"],
        token_url=endpoints["token_endpoint"],
    )
async def validate_input(
    hass: HomeAssistant, data: dict[str, Any], token_manager: TokenManager | None = None
) -> dict[str, Any]:
    """Check the configured URL and perform an MCP initialize handshake.

    The URL is read from ``data[CONF_URL]`` and validated first. A client
    session is then opened against that URL (passing ``token_manager``
    through) and initialized. ``hass`` is not used by this function.

    Returns a dict whose ``"title"`` is the server name reported in the
    initialize response.

    Raises:
        InvalidUrl: The URL failed validation.
        TimeoutConnectError: The connection attempt timed out.
        InvalidAuth: The server answered with HTTP 401.
        CannotConnect: Any other HTTP status error or transport error.
        MissingCapabilities: The server reported no ``tools`` capability.
    """
    server_url = data[CONF_URL]

    try:
        cv.url(server_url)  # Cannot be added to schema directly
    except vol.Invalid as err:
        raise InvalidUrl from err

    try:
        async with mcp_client(server_url, token_manager=token_manager) as session:
            init_result = await session.initialize()
    except httpx.TimeoutException as err:
        _LOGGER.info("Timeout connecting to MCP server: %s", err)
        raise TimeoutConnectError from err
    except httpx.HTTPStatusError as err:
        _LOGGER.info("Cannot connect to MCP server: %s", err)
        # 401 is reported separately so the caller can start the OAuth flow.
        if err.response.status_code != 401:
            raise CannotConnect from err
        raise InvalidAuth from err
    except httpx.HTTPError as err:
        _LOGGER.info("Cannot connect to MCP server: %s", err)
        raise CannotConnect from err

    if init_result.capabilities.tools:
        return {"title": init_result.serverInfo.name}

    raise MissingCapabilities(
        f"MCP Server {server_url} does not support 'Tools' capability"
    )
class ModelContextProtocolConfigFlow(AbstractOAuth2FlowHandler, domain=DOMAIN):
    """Config flow for the Model Context Protocol integration.

    The user step first tries to reach the MCP server without credentials.
    When the server reports invalid auth, the flow discovers the OAuth
    endpoints and continues through the OAuth2 steps of the parent handler.
    """

    VERSION = 1
    DOMAIN = DOMAIN
    logger = _LOGGER

    def __init__(self) -> None:
        """Set up the flow with an empty data store."""
        super().__init__()
        # Values collected across steps (server URL, discovered OAuth URLs).
        self.data: dict[str, Any] = {}

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Ask for the MCP server URL and try to connect to it.

        Without input the form is shown. With input the server is validated:
        recoverable problems re-show the form with an error, an auth failure
        moves on to OAuth discovery, missing capabilities abort the flow, and
        success creates the entry unless one with the same URL exists.
        """
        if user_input is None:
            return self.async_show_form(
                step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors={}
            )

        form_errors: dict[str, str] = {}
        try:
            server_info = await validate_input(self.hass, user_input)
        except InvalidUrl:
            form_errors[CONF_URL] = "invalid_url"
        except TimeoutConnectError:
            form_errors["base"] = "timeout_connect"
        except CannotConnect:
            form_errors["base"] = "cannot_connect"
        except InvalidAuth:
            # The server wants credentials: remember the URL and go discover
            # its OAuth endpoints.
            self.data[CONF_URL] = user_input[CONF_URL]
            return await self.async_step_auth_discovery()
        except MissingCapabilities:
            return self.async_abort(reason="missing_capabilities")
        except Exception:
            _LOGGER.exception("Unexpected exception")
            form_errors["base"] = "unknown"
        else:
            # Kept in the else branch so an abort raised here is not swallowed
            # by the generic handler above.
            self._async_abort_entries_match({CONF_URL: user_input[CONF_URL]})
            return self.async_create_entry(
                title=server_info["title"], data=user_input
            )

        return self.async_show_form(
            step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=form_errors
        )

    async def async_step_auth_discovery(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Handle the OAuth server discovery step.

        Since this MCP server requires authentication, this step attempts to
        find the OAuth metadata, stores the discovered URLs in the flow data
        and continues with the credentials choice step. Any discovery failure
        aborts the flow.
        """
        try:
            discovered = await async_discover_oauth_config(
                self.hass, self.data[CONF_URL]
            )
        except TimeoutConnectError:
            return self.async_abort(reason="timeout_connect")
        except CannotConnect:
            return self.async_abort(reason="cannot_connect")
        except Exception:
            _LOGGER.exception("Unexpected exception")
            return self.async_abort(reason="unknown")

        _LOGGER.info("OAuth configuration: %s", discovered)
        self.data[CONF_AUTHORIZATION_URL] = discovered.authorize_url
        self.data[CONF_TOKEN_URL] = discovered.token_url
        return await self.async_step_credentials_choice()

    def authorization_server(self) -> AuthorizationServer:
        """Build the authorization server from the URLs stored in the flow data."""
        authorize_url = self.data[CONF_AUTHORIZATION_URL]
        token_url = self.data[CONF_TOKEN_URL]
        return AuthorizationServer(authorize_url, token_url)

    async def async_step_credentials_choice(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Ask the user whether to reuse or add application credentials.

        This is needed since we can't automatically assume existing
        credentials should be used, given they may be for another existing
        server. When no implementations are available at all, the flow goes
        straight to the new credentials step.
        """
        with authorization_server_context(self.authorization_server()):
            implementations = await async_get_implementations(
                self.hass, self.DOMAIN
            )
            if implementations:
                return self.async_show_menu(
                    step_id="credentials_choice",
                    menu_options=["pick_implementation", "new_credentials"],
                )
            return await self.async_step_new_credentials()

    async def async_step_new_credentials(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Abort the flow with the ``missing_credentials`` reason.

        NOTE(review): presumably the frontend reacts to this abort reason by
        prompting the user to enter new application credentials - confirm.
        """
        return self.async_abort(reason="missing_credentials")

    async def async_step_pick_implementation(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Handle the pick implementation step.

        This override exists to dynamically set the application credentials
        authorization server based on the values from the OAuth discovery
        step while the parent implementation runs.
        """
        with authorization_server_context(self.authorization_server()):
            return await super().async_step_pick_implementation(user_input)

    async def async_oauth_create_entry(self, data: dict) -> ConfigFlowResult:
        """Verify the new token against the MCP server and store the result.

        The flow data and the OAuth ``data`` are merged (OAuth values win),
        the server is validated using the freshly issued access token, and
        then either the entry being re-authenticated is updated or a new
        entry is created.
        """
        config_entry_data = dict(self.data)
        config_entry_data.update(data)

        async def _issued_access_token() -> str:
            # Always hands back the access token from the OAuth result.
            return cast(str, data[CONF_TOKEN][CONF_ACCESS_TOKEN])

        try:
            server_info = await validate_input(
                self.hass, config_entry_data, _issued_access_token
            )
        except TimeoutConnectError:
            return self.async_abort(reason="timeout_connect")
        except CannotConnect:
            return self.async_abort(reason="cannot_connect")
        except MissingCapabilities:
            return self.async_abort(reason="missing_capabilities")
        except Exception:
            # Also reached for InvalidAuth and InvalidUrl, which have no
            # dedicated handler in this step.
            _LOGGER.exception("Unexpected exception")
            return self.async_abort(reason="unknown")

        if self.source == SOURCE_REAUTH:
            return self.async_update_reload_and_abort(
                self._get_reauth_entry(), data=config_entry_data
            )

        # The unique id is the "auth_implementation" value; per the original
        # author's note this is based on the application credentials OAuth
        # Client ID.
        await self.async_set_unique_id(config_entry_data["auth_implementation"])
        return self.async_create_entry(
            title=server_info["title"],
            data=config_entry_data,
        )

    async def async_step_reauth(
        self, entry_data: Mapping[str, Any]
    ) -> ConfigFlowResult:
        """Start re-authentication after an API authentication error.

        ``entry_data`` is not used; the flow continues with the confirm step.
        """
        return await self.async_step_reauth_confirm()

    async def async_step_reauth_confirm(
        self, user_input: Mapping[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Confirm re-authentication, then restart the OAuth auth step.

        Once confirmed, the existing entry's data is copied into the flow data
        and the entry's OAuth implementation is reused for the auth step.
        """
        if user_input is not None:
            reauth_entry = self._get_reauth_entry()
            self.data = dict(reauth_entry.data)
            self.flow_impl = await async_get_config_entry_implementation(  # type: ignore[assignment]
                self.hass, reauth_entry
            )
            return await self.async_step_auth()

        return self.async_show_form(step_id="reauth_confirm")
class InvalidUrl(HomeAssistantError):
    """Raised when the configured MCP server URL fails URL validation."""
class CannotConnect(HomeAssistantError):
    """Raised when a request to the MCP server fails or returns an error status."""
class TimeoutConnectError(HomeAssistantError):
    """Raised when a request to the MCP server times out."""
class InvalidAuth(HomeAssistantError):
    """Raised when the MCP server rejects the request as unauthorized (HTTP 401)."""
class MissingCapabilities(HomeAssistantError):
    """Raised when the MCP server does not report the required 'Tools' capability."""