mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 09:17:53 +00:00
Add OAuth support for Model Context Protocol (mcp) integration (#141874)
* Add authentication support for Model Context Protocol (mcp) integration * Update homeassistant/components/mcp/application_credentials.py Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> * Handle MCP servers with ports --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
1639163c2e
commit
0c4cb27fe9
@ -3,12 +3,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from homeassistant.components.application_credentials import AuthorizationServer
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers import config_entry_oauth2_flow, llm
|
||||
|
||||
from .const import DOMAIN
|
||||
from .coordinator import ModelContextProtocolCoordinator
|
||||
from .application_credentials import authorization_server_context
|
||||
from .const import CONF_ACCESS_TOKEN, CONF_AUTHORIZATION_URL, CONF_TOKEN_URL, DOMAIN
|
||||
from .coordinator import ModelContextProtocolCoordinator, TokenManager
|
||||
from .types import ModelContextProtocolConfigEntry
|
||||
|
||||
__all__ = [
|
||||
@ -20,11 +23,45 @@ __all__ = [
|
||||
API_PROMPT = "The following tools are available from a remote server named {name}."
|
||||
|
||||
|
||||
async def async_get_config_entry_implementation(
|
||||
hass: HomeAssistant, entry: ModelContextProtocolConfigEntry
|
||||
) -> config_entry_oauth2_flow.AbstractOAuth2Implementation | None:
|
||||
"""OAuth implementation for the config entry."""
|
||||
if "auth_implementation" not in entry.data:
|
||||
return None
|
||||
with authorization_server_context(
|
||||
AuthorizationServer(
|
||||
authorize_url=entry.data[CONF_AUTHORIZATION_URL],
|
||||
token_url=entry.data[CONF_TOKEN_URL],
|
||||
)
|
||||
):
|
||||
return await config_entry_oauth2_flow.async_get_config_entry_implementation(
|
||||
hass, entry
|
||||
)
|
||||
|
||||
|
||||
async def _create_token_manager(
|
||||
hass: HomeAssistant, entry: ModelContextProtocolConfigEntry
|
||||
) -> TokenManager | None:
|
||||
"""Create a OAuth token manager for the config entry if the server requires authentication."""
|
||||
if not (implementation := await async_get_config_entry_implementation(hass, entry)):
|
||||
return None
|
||||
|
||||
session = config_entry_oauth2_flow.OAuth2Session(hass, entry, implementation)
|
||||
|
||||
async def token_manager() -> str:
|
||||
await session.async_ensure_token_valid()
|
||||
return cast(str, session.token[CONF_ACCESS_TOKEN])
|
||||
|
||||
return token_manager
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant, entry: ModelContextProtocolConfigEntry
|
||||
) -> bool:
|
||||
"""Set up Model Context Protocol from a config entry."""
|
||||
coordinator = ModelContextProtocolCoordinator(hass, entry)
|
||||
token_manager = await _create_token_manager(hass, entry)
|
||||
coordinator = ModelContextProtocolCoordinator(hass, entry, token_manager)
|
||||
await coordinator.async_config_entry_first_refresh()
|
||||
|
||||
unsub = llm.async_register_api(
|
||||
|
35
homeassistant/components/mcp/application_credentials.py
Normal file
35
homeassistant/components/mcp/application_credentials.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""Application credentials platform for Model Context Protocol."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
import contextvars
|
||||
|
||||
from homeassistant.components.application_credentials import AuthorizationServer
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
CONF_ACTIVE_AUTHORIZATION_SERVER = "active_authorization_server"
|
||||
|
||||
_mcp_context: contextvars.ContextVar[AuthorizationServer] = contextvars.ContextVar(
|
||||
"mcp_authorization_server_context"
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def authorization_server_context(
|
||||
authorization_server: AuthorizationServer,
|
||||
) -> Generator[None]:
|
||||
"""Context manager for setting the active authorization server."""
|
||||
token = _mcp_context.set(authorization_server)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_mcp_context.reset(token)
|
||||
|
||||
|
||||
async def async_get_authorization_server(hass: HomeAssistant) -> AuthorizationServer:
|
||||
"""Return authorization server, for the default auth implementation."""
|
||||
if _mcp_context.get() is None:
|
||||
raise RuntimeError("No MCP authorization server set in context")
|
||||
return _mcp_context.get()
|
@ -2,20 +2,29 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
import voluptuous as vol
|
||||
from yarl import URL
|
||||
|
||||
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
|
||||
from homeassistant.const import CONF_URL
|
||||
from homeassistant.components.application_credentials import AuthorizationServer
|
||||
from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult
|
||||
from homeassistant.const import CONF_TOKEN, CONF_URL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.config_entry_oauth2_flow import (
|
||||
AbstractOAuth2FlowHandler,
|
||||
async_get_implementations,
|
||||
)
|
||||
|
||||
from .const import DOMAIN
|
||||
from .coordinator import mcp_client
|
||||
from . import async_get_config_entry_implementation
|
||||
from .application_credentials import authorization_server_context
|
||||
from .const import CONF_ACCESS_TOKEN, CONF_AUTHORIZATION_URL, CONF_TOKEN_URL, DOMAIN
|
||||
from .coordinator import TokenManager, mcp_client
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -25,8 +34,62 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
}
|
||||
)
|
||||
|
||||
# OAuth server discovery endpoint for rfc8414
|
||||
OAUTH_DISCOVERY_ENDPOINT = ".well-known/oauth-authorization-server"
|
||||
MCP_DISCOVERY_HEADERS = {
|
||||
"MCP-Protocol-Version": "2025-03-26",
|
||||
}
|
||||
|
||||
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
async def async_discover_oauth_config(
|
||||
hass: HomeAssistant, mcp_server_url: str
|
||||
) -> AuthorizationServer:
|
||||
"""Discover the OAuth configuration for the MCP server.
|
||||
|
||||
This implements the functionality in the MCP spec for discovery. If the MCP server URL
|
||||
is https://api.example.com/v1/mcp, then:
|
||||
- The authorization base URL is https://api.example.com
|
||||
- The metadata endpoint MUST be at https://api.example.com/.well-known/oauth-authorization-server
|
||||
- For servers that do not implement OAuth 2.0 Authorization Server Metadata, the client uses
|
||||
default paths relative to the authorization base URL.
|
||||
"""
|
||||
parsed_url = URL(mcp_server_url)
|
||||
discovery_endpoint = str(parsed_url.with_path(OAUTH_DISCOVERY_ENDPOINT))
|
||||
try:
|
||||
async with httpx.AsyncClient(headers=MCP_DISCOVERY_HEADERS) as client:
|
||||
response = await client.get(discovery_endpoint)
|
||||
response.raise_for_status()
|
||||
except httpx.TimeoutException as error:
|
||||
_LOGGER.info("Timeout connecting to MCP server: %s", error)
|
||||
raise TimeoutConnectError from error
|
||||
except httpx.HTTPStatusError as error:
|
||||
if error.response.status_code == 404:
|
||||
_LOGGER.info("Authorization Server Metadata not found, using default paths")
|
||||
return AuthorizationServer(
|
||||
authorize_url=str(parsed_url.with_path("/authorize")),
|
||||
token_url=str(parsed_url.with_path("/token")),
|
||||
)
|
||||
raise CannotConnect from error
|
||||
except httpx.HTTPError as error:
|
||||
_LOGGER.info("Cannot discover OAuth configuration: %s", error)
|
||||
raise CannotConnect from error
|
||||
|
||||
data = response.json()
|
||||
authorize_url = data["authorization_endpoint"]
|
||||
token_url = data["token_endpoint"]
|
||||
if authorize_url.startswith("/"):
|
||||
authorize_url = str(parsed_url.with_path(authorize_url))
|
||||
if token_url.startswith("/"):
|
||||
token_url = str(parsed_url.with_path(token_url))
|
||||
return AuthorizationServer(
|
||||
authorize_url=authorize_url,
|
||||
token_url=token_url,
|
||||
)
|
||||
|
||||
|
||||
async def validate_input(
|
||||
hass: HomeAssistant, data: dict[str, Any], token_manager: TokenManager | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate the user input and connect to the MCP server."""
|
||||
url = data[CONF_URL]
|
||||
try:
|
||||
@ -34,7 +97,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str,
|
||||
except vol.Invalid as error:
|
||||
raise InvalidUrl from error
|
||||
try:
|
||||
async with mcp_client(url) as session:
|
||||
async with mcp_client(url, token_manager=token_manager) as session:
|
||||
response = await session.initialize()
|
||||
except httpx.TimeoutException as error:
|
||||
_LOGGER.info("Timeout connecting to MCP server: %s", error)
|
||||
@ -56,10 +119,17 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str,
|
||||
return {"title": response.serverInfo.name}
|
||||
|
||||
|
||||
class ModelContextProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
class ModelContextProtocolConfigFlow(AbstractOAuth2FlowHandler, domain=DOMAIN):
|
||||
"""Handle a config flow for Model Context Protocol."""
|
||||
|
||||
VERSION = 1
|
||||
DOMAIN = DOMAIN
|
||||
logger = _LOGGER
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the config flow."""
|
||||
super().__init__()
|
||||
self.data: dict[str, Any] = {}
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
@ -76,7 +146,8 @@ class ModelContextProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
except CannotConnect:
|
||||
errors["base"] = "cannot_connect"
|
||||
except InvalidAuth:
|
||||
return self.async_abort(reason="invalid_auth")
|
||||
self.data[CONF_URL] = user_input[CONF_URL]
|
||||
return await self.async_step_auth_discovery()
|
||||
except MissingCapabilities:
|
||||
return self.async_abort(reason="missing_capabilities")
|
||||
except Exception:
|
||||
@ -90,6 +161,130 @@ class ModelContextProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
|
||||
)
|
||||
|
||||
async def async_step_auth_discovery(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle the OAuth server discovery step.
|
||||
|
||||
Since this OAuth server requires authentication, this step will attempt
|
||||
to find the OAuth medata then run the OAuth authentication flow.
|
||||
"""
|
||||
try:
|
||||
authorization_server = await async_discover_oauth_config(
|
||||
self.hass, self.data[CONF_URL]
|
||||
)
|
||||
except TimeoutConnectError:
|
||||
return self.async_abort(reason="timeout_connect")
|
||||
except CannotConnect:
|
||||
return self.async_abort(reason="cannot_connect")
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
return self.async_abort(reason="unknown")
|
||||
else:
|
||||
_LOGGER.info("OAuth configuration: %s", authorization_server)
|
||||
self.data.update(
|
||||
{
|
||||
CONF_AUTHORIZATION_URL: authorization_server.authorize_url,
|
||||
CONF_TOKEN_URL: authorization_server.token_url,
|
||||
}
|
||||
)
|
||||
return await self.async_step_credentials_choice()
|
||||
|
||||
def authorization_server(self) -> AuthorizationServer:
|
||||
"""Return the authorization server provided by the MCP server."""
|
||||
return AuthorizationServer(
|
||||
self.data[CONF_AUTHORIZATION_URL],
|
||||
self.data[CONF_TOKEN_URL],
|
||||
)
|
||||
|
||||
async def async_step_credentials_choice(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Step to ask they user if they would like to add credentials.
|
||||
|
||||
This is needed since we can't automatically assume existing credentials
|
||||
should be used given they may be for another existing server.
|
||||
"""
|
||||
with authorization_server_context(self.authorization_server()):
|
||||
if not await async_get_implementations(self.hass, self.DOMAIN):
|
||||
return await self.async_step_new_credentials()
|
||||
return self.async_show_menu(
|
||||
step_id="credentials_choice",
|
||||
menu_options=["pick_implementation", "new_credentials"],
|
||||
)
|
||||
|
||||
async def async_step_new_credentials(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Step to take the frontend flow to enter new credentials."""
|
||||
return self.async_abort(reason="missing_credentials")
|
||||
|
||||
async def async_step_pick_implementation(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle the pick implementation step.
|
||||
|
||||
This exists to dynamically set application credentials Authorization Server
|
||||
based on the values form the OAuth discovery step.
|
||||
"""
|
||||
with authorization_server_context(self.authorization_server()):
|
||||
return await super().async_step_pick_implementation(user_input)
|
||||
|
||||
async def async_oauth_create_entry(self, data: dict) -> ConfigFlowResult:
|
||||
"""Create an entry for the flow.
|
||||
|
||||
Ok to override if you want to fetch extra info or even add another step.
|
||||
"""
|
||||
config_entry_data = {
|
||||
**self.data,
|
||||
**data,
|
||||
}
|
||||
|
||||
async def token_manager() -> str:
|
||||
return cast(str, data[CONF_TOKEN][CONF_ACCESS_TOKEN])
|
||||
|
||||
try:
|
||||
info = await validate_input(self.hass, config_entry_data, token_manager)
|
||||
except TimeoutConnectError:
|
||||
return self.async_abort(reason="timeout_connect")
|
||||
except CannotConnect:
|
||||
return self.async_abort(reason="cannot_connect")
|
||||
except MissingCapabilities:
|
||||
return self.async_abort(reason="missing_capabilities")
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
return self.async_abort(reason="unknown")
|
||||
|
||||
# Unique id based on the application credentials OAuth Client ID
|
||||
if self.source == SOURCE_REAUTH:
|
||||
return self.async_update_reload_and_abort(
|
||||
self._get_reauth_entry(), data=config_entry_data
|
||||
)
|
||||
await self.async_set_unique_id(config_entry_data["auth_implementation"])
|
||||
return self.async_create_entry(
|
||||
title=info["title"],
|
||||
data=config_entry_data,
|
||||
)
|
||||
|
||||
async def async_step_reauth(
|
||||
self, entry_data: Mapping[str, Any]
|
||||
) -> ConfigFlowResult:
|
||||
"""Perform reauth upon an API authentication error."""
|
||||
return await self.async_step_reauth_confirm()
|
||||
|
||||
async def async_step_reauth_confirm(
|
||||
self, user_input: Mapping[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Confirm reauth dialog."""
|
||||
if user_input is None:
|
||||
return self.async_show_form(step_id="reauth_confirm")
|
||||
config_entry = self._get_reauth_entry()
|
||||
self.data = {**config_entry.data}
|
||||
self.flow_impl = await async_get_config_entry_implementation( # type: ignore[assignment]
|
||||
self.hass, config_entry
|
||||
)
|
||||
return await self.async_step_auth()
|
||||
|
||||
|
||||
class InvalidUrl(HomeAssistantError):
|
||||
"""Error to indicate the URL format is invalid."""
|
||||
|
@ -1,3 +1,7 @@
|
||||
"""Constants for the Model Context Protocol integration."""
|
||||
|
||||
DOMAIN = "mcp"
|
||||
|
||||
CONF_ACCESS_TOKEN = "access_token"
|
||||
CONF_AUTHORIZATION_URL = "authorization_url"
|
||||
CONF_TOKEN_URL = "token_url"
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Types for the Model Context Protocol integration."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
import datetime
|
||||
import logging
|
||||
@ -15,7 +15,7 @@ from voluptuous_openapi import convert_to_voluptuous
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_URL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
@ -27,16 +27,28 @@ _LOGGER = logging.getLogger(__name__)
|
||||
UPDATE_INTERVAL = datetime.timedelta(minutes=30)
|
||||
TIMEOUT = 10
|
||||
|
||||
TokenManager = Callable[[], Awaitable[str]]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def mcp_client(url: str) -> AsyncGenerator[ClientSession]:
|
||||
async def mcp_client(
|
||||
url: str,
|
||||
token_manager: TokenManager | None = None,
|
||||
) -> AsyncGenerator[ClientSession]:
|
||||
"""Create a server-sent event MCP client.
|
||||
|
||||
This is an asynccontext manager that exists to wrap other async context managers
|
||||
so that the coordinator has a single object to manage.
|
||||
"""
|
||||
headers: dict[str, str] = {}
|
||||
if token_manager is not None:
|
||||
token = await token_manager()
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
try:
|
||||
async with sse_client(url=url) as streams, ClientSession(*streams) as session:
|
||||
async with (
|
||||
sse_client(url=url, headers=headers) as streams,
|
||||
ClientSession(*streams) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
yield session
|
||||
except ExceptionGroup as err:
|
||||
@ -53,12 +65,14 @@ class ModelContextProtocolTool(llm.Tool):
|
||||
description: str | None,
|
||||
parameters: vol.Schema,
|
||||
server_url: str,
|
||||
token_manager: TokenManager | None = None,
|
||||
) -> None:
|
||||
"""Initialize the tool."""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.parameters = parameters
|
||||
self.server_url = server_url
|
||||
self.token_manager = token_manager
|
||||
|
||||
async def async_call(
|
||||
self,
|
||||
@ -69,7 +83,7 @@ class ModelContextProtocolTool(llm.Tool):
|
||||
"""Call the tool."""
|
||||
try:
|
||||
async with asyncio.timeout(TIMEOUT):
|
||||
async with mcp_client(self.server_url) as session:
|
||||
async with mcp_client(self.server_url, self.token_manager) as session:
|
||||
result = await session.call_tool(
|
||||
tool_input.tool_name, tool_input.tool_args
|
||||
)
|
||||
@ -87,7 +101,12 @@ class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
|
||||
|
||||
config_entry: ConfigEntry
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
token_manager: TokenManager | None = None,
|
||||
) -> None:
|
||||
"""Initialize ModelContextProtocolCoordinator."""
|
||||
super().__init__(
|
||||
hass,
|
||||
@ -96,6 +115,7 @@ class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
|
||||
config_entry=config_entry,
|
||||
update_interval=UPDATE_INTERVAL,
|
||||
)
|
||||
self.token_manager = token_manager
|
||||
|
||||
async def _async_update_data(self) -> list[llm.Tool]:
|
||||
"""Fetch data from API endpoint.
|
||||
@ -105,11 +125,20 @@ class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
|
||||
"""
|
||||
try:
|
||||
async with asyncio.timeout(TIMEOUT):
|
||||
async with mcp_client(self.config_entry.data[CONF_URL]) as session:
|
||||
async with mcp_client(
|
||||
self.config_entry.data[CONF_URL], self.token_manager
|
||||
) as session:
|
||||
result = await session.list_tools()
|
||||
except TimeoutError as error:
|
||||
_LOGGER.debug("Timeout when listing tools: %s", error)
|
||||
raise UpdateFailed(f"Timeout when listing tools: {error}") from error
|
||||
except httpx.HTTPStatusError as error:
|
||||
_LOGGER.debug("Error communicating with API: %s", error)
|
||||
if error.response.status_code == 401 and self.token_manager is not None:
|
||||
raise ConfigEntryAuthFailed(
|
||||
"The MCP server requires authentication"
|
||||
) from error
|
||||
raise UpdateFailed(f"Error communicating with API: {error}") from error
|
||||
except httpx.HTTPError as err:
|
||||
_LOGGER.debug("Error communicating with API: %s", err)
|
||||
raise UpdateFailed(f"Error communicating with API: {err}") from err
|
||||
@ -129,6 +158,7 @@ class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
|
||||
tool.description,
|
||||
parameters,
|
||||
self.config_entry.data[CONF_URL],
|
||||
self.token_manager,
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
@ -3,6 +3,7 @@
|
||||
"name": "Model Context Protocol",
|
||||
"codeowners": ["@allenporter"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["application_credentials"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/mcp",
|
||||
"iot_class": "local_polling",
|
||||
"quality_scale": "silver",
|
||||
|
@ -44,9 +44,7 @@ rules:
|
||||
parallel-updates:
|
||||
status: exempt
|
||||
comment: Integration does not have platforms.
|
||||
reauthentication-flow:
|
||||
status: exempt
|
||||
comment: Integration does not support authentication.
|
||||
reauthentication-flow: done
|
||||
test-coverage: done
|
||||
|
||||
# Gold
|
||||
|
@ -8,6 +8,15 @@
|
||||
"data_description": {
|
||||
"url": "The remote MCP server URL for the SSE endpoint, for example http://example/sse"
|
||||
}
|
||||
},
|
||||
"pick_implementation": {
|
||||
"title": "[%key:common::config_flow::title::oauth2_pick_implementation%]",
|
||||
"data": {
|
||||
"implementation": "Credentials"
|
||||
},
|
||||
"data_description": {
|
||||
"implementation": "The credentials to use for the OAuth2 flow"
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
@ -17,9 +26,15 @@
|
||||
"invalid_url": "Must be a valid MCP server URL e.g. https://example.com/sse"
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
|
||||
"missing_capabilities": "The MCP server does not support a required capability (Tools)",
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]"
|
||||
"missing_credentials": "[%key:common::config_flow::abort::oauth2_missing_credentials%]",
|
||||
"reauth_account_mismatch": "The authenticated user does not match the MCP Server user that needed re-authentication.",
|
||||
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]",
|
||||
"timeout_connect": "[%key:common::config_flow::error::timeout_connect%]",
|
||||
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ APPLICATION_CREDENTIALS = [
|
||||
"iotty",
|
||||
"lametric",
|
||||
"lyric",
|
||||
"mcp",
|
||||
"microbees",
|
||||
"monzo",
|
||||
"myuplink",
|
||||
|
@ -1,17 +1,34 @@
|
||||
"""Common fixtures for the Model Context Protocol tests."""
|
||||
|
||||
from collections.abc import Generator
|
||||
import datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.mcp.const import DOMAIN
|
||||
from homeassistant.const import CONF_URL
|
||||
from homeassistant.components.application_credentials import (
|
||||
ClientCredential,
|
||||
async_import_client_credential,
|
||||
)
|
||||
from homeassistant.components.mcp.const import (
|
||||
CONF_ACCESS_TOKEN,
|
||||
CONF_AUTHORIZATION_URL,
|
||||
CONF_TOKEN_URL,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.const import CONF_TOKEN, CONF_URL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
TEST_API_NAME = "Memory Server"
|
||||
MCP_SERVER_URL = "http://1.1.1.1:8080/sse"
|
||||
CLIENT_ID = "test-client-id"
|
||||
CLIENT_SECRET = "test-client-secret"
|
||||
AUTH_DOMAIN = "some-auth-domain"
|
||||
OAUTH_AUTHORIZE_URL = "https://example-auth-server.com/authorize-path"
|
||||
OAUTH_TOKEN_URL = "https://example-auth-server.com/token-path"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -29,6 +46,7 @@ def mock_mcp_client() -> Generator[AsyncMock]:
|
||||
with (
|
||||
patch("homeassistant.components.mcp.coordinator.sse_client"),
|
||||
patch("homeassistant.components.mcp.coordinator.ClientSession") as mock_session,
|
||||
patch("homeassistant.components.mcp.coordinator.TIMEOUT", 1),
|
||||
):
|
||||
yield mock_session.return_value.__aenter__
|
||||
|
||||
@ -43,3 +61,47 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
return config_entry
|
||||
|
||||
|
||||
@pytest.fixture(name="credential")
|
||||
async def mock_credential(hass: HomeAssistant) -> None:
|
||||
"""Fixture that provides the ClientCredential for the test."""
|
||||
assert await async_setup_component(hass, "application_credentials", {})
|
||||
await async_import_client_credential(
|
||||
hass,
|
||||
DOMAIN,
|
||||
ClientCredential(CLIENT_ID, CLIENT_SECRET),
|
||||
AUTH_DOMAIN,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="config_entry_token_expiration")
|
||||
def mock_config_entry_token_expiration() -> datetime.datetime:
|
||||
"""Fixture to mock the token expiration."""
|
||||
return datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1)
|
||||
|
||||
|
||||
@pytest.fixture(name="config_entry_with_auth")
|
||||
def mock_config_entry_with_auth(
|
||||
hass: HomeAssistant,
|
||||
config_entry_token_expiration: datetime.datetime,
|
||||
) -> MockConfigEntry:
|
||||
"""Fixture to load the integration with authentication."""
|
||||
config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
unique_id=AUTH_DOMAIN,
|
||||
data={
|
||||
"auth_implementation": AUTH_DOMAIN,
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
CONF_AUTHORIZATION_URL: OAUTH_AUTHORIZE_URL,
|
||||
CONF_TOKEN_URL: OAUTH_TOKEN_URL,
|
||||
CONF_TOKEN: {
|
||||
CONF_ACCESS_TOKEN: "test-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"expires_at": config_entry_token_expiration.timestamp(),
|
||||
},
|
||||
},
|
||||
title=TEST_API_NAME,
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
return config_entry
|
||||
|
@ -1,20 +1,70 @@
|
||||
"""Test the Model Context Protocol config flow."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.mcp.const import DOMAIN
|
||||
from homeassistant.const import CONF_URL
|
||||
from homeassistant.components.mcp.const import (
|
||||
CONF_AUTHORIZATION_URL,
|
||||
CONF_TOKEN_URL,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.const import CONF_TOKEN, CONF_URL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
|
||||
from .conftest import TEST_API_NAME
|
||||
from .conftest import (
|
||||
AUTH_DOMAIN,
|
||||
CLIENT_ID,
|
||||
MCP_SERVER_URL,
|
||||
OAUTH_AUTHORIZE_URL,
|
||||
OAUTH_TOKEN_URL,
|
||||
TEST_API_NAME,
|
||||
)
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
MCP_SERVER_BASE_URL = "http://1.1.1.1:8080"
|
||||
OAUTH_DISCOVERY_ENDPOINT = (
|
||||
f"{MCP_SERVER_BASE_URL}/.well-known/oauth-authorization-server"
|
||||
)
|
||||
OAUTH_SERVER_METADATA_RESPONSE = httpx.Response(
|
||||
status_code=200,
|
||||
text=json.dumps(
|
||||
{
|
||||
"authorization_endpoint": OAUTH_AUTHORIZE_URL,
|
||||
"token_endpoint": OAUTH_TOKEN_URL,
|
||||
}
|
||||
),
|
||||
)
|
||||
CALLBACK_PATH = "/auth/external/callback"
|
||||
OAUTH_CALLBACK_URL = f"https://example.com{CALLBACK_PATH}"
|
||||
OAUTH_CODE = "abcd"
|
||||
OAUTH_TOKEN_PAYLOAD = {
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": "mock-access-token",
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
}
|
||||
|
||||
|
||||
def encode_state(hass: HomeAssistant, flow_id: str) -> str:
|
||||
"""Encode the OAuth JWT."""
|
||||
return config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": flow_id,
|
||||
"redirect_uri": OAUTH_CALLBACK_URL,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def test_form(
|
||||
@ -34,15 +84,19 @@ async def test_form(
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: "http://1.1.1.1/sse",
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == TEST_API_NAME
|
||||
assert result["data"] == {
|
||||
CONF_URL: "http://1.1.1.1/sse",
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
}
|
||||
# Config entry does not have a unique id
|
||||
assert result["result"]
|
||||
assert result["result"].unique_id is None
|
||||
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
@ -73,7 +127,7 @@ async def test_form_mcp_client_error(
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: "http://1.1.1.1/sse",
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
},
|
||||
)
|
||||
|
||||
@ -89,50 +143,18 @@ async def test_form_mcp_client_error(
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: "http://1.1.1.1/sse",
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == TEST_API_NAME
|
||||
assert result["data"] == {
|
||||
CONF_URL: "http://1.1.1.1/sse",
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "expected_error"),
|
||||
[
|
||||
(
|
||||
httpx.HTTPStatusError("", request=None, response=httpx.Response(401)),
|
||||
"invalid_auth",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_form_mcp_client_error_abort(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_mcp_client: Mock,
|
||||
side_effect: Exception,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test we handle different client library errors that end with an abort."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
mock_mcp_client.side_effect = side_effect
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: "http://1.1.1.1/sse",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == expected_error
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user_input",
|
||||
[
|
||||
@ -165,14 +187,14 @@ async def test_input_form_validation_error(
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: "http://1.1.1.1/sse",
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == TEST_API_NAME
|
||||
assert result["data"] == {
|
||||
CONF_URL: "http://1.1.1.1/sse",
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
@ -183,7 +205,7 @@ async def test_unique_url(
|
||||
"""Test that the same url cannot be configured twice."""
|
||||
config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={CONF_URL: "http://1.1.1.1/sse"},
|
||||
data={CONF_URL: MCP_SERVER_URL},
|
||||
title=TEST_API_NAME,
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
@ -201,7 +223,7 @@ async def test_unique_url(
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: "http://1.1.1.1/sse",
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
},
|
||||
)
|
||||
|
||||
@ -226,9 +248,409 @@ async def test_server_missing_capbilities(
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: "http://1.1.1.1/sse",
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "missing_capabilities"
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_oauth_discovery_flow_without_credentials(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_mcp_client: Mock,
|
||||
) -> None:
|
||||
"""Test for an OAuth discoveryflow for an MCP server where the user has not yet entered credentials."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
# MCP Server returns 401 indicating the client needs to authenticate
|
||||
mock_mcp_client.side_effect = httpx.HTTPStatusError(
|
||||
"Authentication required", request=None, response=httpx.Response(401)
|
||||
)
|
||||
# Prepare the OAuth Server metadata
|
||||
respx.get(OAUTH_DISCOVERY_ENDPOINT).mock(
|
||||
return_value=OAUTH_SERVER_METADATA_RESPONSE
|
||||
)
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
},
|
||||
)
|
||||
|
||||
# The config flow will abort and the user will be taken to the application credentials UI
|
||||
# to enter their credentials.
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "missing_credentials"
|
||||
|
||||
|
||||
async def perform_oauth_flow(
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
result: config_entries.ConfigFlowResult,
|
||||
authorize_url: str = OAUTH_AUTHORIZE_URL,
|
||||
token_url: str = OAUTH_TOKEN_URL,
|
||||
) -> config_entries.ConfigFlowResult:
|
||||
"""Perform the common steps of the OAuth flow.
|
||||
|
||||
Expects to be called from the step where the user selects credentials.
|
||||
"""
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": OAUTH_CALLBACK_URL,
|
||||
},
|
||||
)
|
||||
assert result["url"] == (
|
||||
f"{authorize_url}?response_type=code&client_id={CLIENT_ID}"
|
||||
f"&redirect_uri={OAUTH_CALLBACK_URL}"
|
||||
f"&state={state}"
|
||||
)
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
resp = await client.get(f"{CALLBACK_PATH}?code={OAUTH_CODE}&state={state}")
|
||||
assert resp.status == 200
|
||||
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||
|
||||
aioclient_mock.post(
|
||||
token_url,
|
||||
json=OAUTH_TOKEN_PAYLOAD,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("oauth_server_metadata_response", "expected_authorize_url", "expected_token_url"),
|
||||
[
|
||||
(OAUTH_SERVER_METADATA_RESPONSE, OAUTH_AUTHORIZE_URL, OAUTH_TOKEN_URL),
|
||||
(
|
||||
httpx.Response(
|
||||
status_code=200,
|
||||
text=json.dumps(
|
||||
{
|
||||
"authorization_endpoint": "/authorize-path",
|
||||
"token_endpoint": "/token-path",
|
||||
}
|
||||
),
|
||||
),
|
||||
f"{MCP_SERVER_BASE_URL}/authorize-path",
|
||||
f"{MCP_SERVER_BASE_URL}/token-path",
|
||||
),
|
||||
(
|
||||
httpx.Response(status_code=404),
|
||||
f"{MCP_SERVER_BASE_URL}/authorize",
|
||||
f"{MCP_SERVER_BASE_URL}/token",
|
||||
),
|
||||
],
|
||||
ids=(
|
||||
"discovery",
|
||||
"relative_paths",
|
||||
"no_discovery_metadata",
|
||||
),
|
||||
)
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
@respx.mock
|
||||
async def test_authentication_flow(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_mcp_client: Mock,
|
||||
credential: None,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
oauth_server_metadata_response: httpx.Response,
|
||||
expected_authorize_url: str,
|
||||
expected_token_url: str,
|
||||
) -> None:
|
||||
"""Test for an OAuth authentication flow for an MCP server."""
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
# MCP Server returns 401 indicating the client needs to authenticate
|
||||
mock_mcp_client.side_effect = httpx.HTTPStatusError(
|
||||
"Authentication required", request=None, response=httpx.Response(401)
|
||||
)
|
||||
# Prepare the OAuth Server metadata
|
||||
respx.get(OAUTH_DISCOVERY_ENDPOINT).mock(
|
||||
return_value=oauth_server_metadata_response
|
||||
)
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
},
|
||||
)
|
||||
assert result["type"] is FlowResultType.MENU
|
||||
assert result["step_id"] == "credentials_choice"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
"next_step_id": "pick_implementation",
|
||||
},
|
||||
)
|
||||
assert result["type"] is FlowResultType.EXTERNAL_STEP
|
||||
result = await perform_oauth_flow(
|
||||
hass,
|
||||
aioclient_mock,
|
||||
hass_client_no_auth,
|
||||
result,
|
||||
authorize_url=expected_authorize_url,
|
||||
token_url=expected_token_url,
|
||||
)
|
||||
|
||||
# Client now accepts credentials
|
||||
mock_mcp_client.side_effect = None
|
||||
response = Mock()
|
||||
response.serverInfo.name = TEST_API_NAME
|
||||
mock_mcp_client.return_value.initialize.return_value = response
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == TEST_API_NAME
|
||||
data = result["data"]
|
||||
token = data.pop(CONF_TOKEN)
|
||||
assert data == {
|
||||
"auth_implementation": AUTH_DOMAIN,
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
CONF_AUTHORIZATION_URL: expected_authorize_url,
|
||||
CONF_TOKEN_URL: expected_token_url,
|
||||
}
|
||||
assert token
|
||||
token.pop("expires_at")
|
||||
assert token == OAUTH_TOKEN_PAYLOAD
|
||||
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "expected_error"),
|
||||
[
|
||||
(httpx.TimeoutException("Some timeout"), "timeout_connect"),
|
||||
(
|
||||
httpx.HTTPStatusError("", request=None, response=httpx.Response(500)),
|
||||
"cannot_connect",
|
||||
),
|
||||
(httpx.HTTPError("Some HTTP error"), "cannot_connect"),
|
||||
(Exception, "unknown"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
@respx.mock
|
||||
async def test_oauth_discovery_failure(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_mcp_client: Mock,
|
||||
credential: None,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
side_effect: Exception,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test for an OAuth authentication flow for an MCP server."""
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
# MCP Server returns 401 indicating the client needs to authenticate
|
||||
mock_mcp_client.side_effect = httpx.HTTPStatusError(
|
||||
"Authentication required", request=None, response=httpx.Response(401)
|
||||
)
|
||||
# Prepare the OAuth Server metadata
|
||||
respx.get(OAUTH_DISCOVERY_ENDPOINT).mock(side_effect=side_effect)
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
},
|
||||
)
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == expected_error
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "expected_error"),
|
||||
[
|
||||
(httpx.TimeoutException("Some timeout"), "timeout_connect"),
|
||||
(
|
||||
httpx.HTTPStatusError("", request=None, response=httpx.Response(500)),
|
||||
"cannot_connect",
|
||||
),
|
||||
(httpx.HTTPError("Some HTTP error"), "cannot_connect"),
|
||||
(Exception, "unknown"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
@respx.mock
|
||||
async def test_authentication_flow_server_failure_abort(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_mcp_client: Mock,
|
||||
credential: None,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
side_effect: Exception,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test for an OAuth authentication flow for an MCP server."""
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
# MCP Server returns 401 indicating the client needs to authenticate
|
||||
mock_mcp_client.side_effect = httpx.HTTPStatusError(
|
||||
"Authentication required", request=None, response=httpx.Response(401)
|
||||
)
|
||||
# Prepare the OAuth Server metadata
|
||||
respx.get(OAUTH_DISCOVERY_ENDPOINT).mock(
|
||||
return_value=OAUTH_SERVER_METADATA_RESPONSE
|
||||
)
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
},
|
||||
)
|
||||
assert result["type"] is FlowResultType.MENU
|
||||
assert result["step_id"] == "credentials_choice"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
"next_step_id": "pick_implementation",
|
||||
},
|
||||
)
|
||||
assert result["type"] is FlowResultType.EXTERNAL_STEP
|
||||
result = await perform_oauth_flow(
|
||||
hass,
|
||||
aioclient_mock,
|
||||
hass_client_no_auth,
|
||||
result,
|
||||
)
|
||||
|
||||
# Client fails with an error
|
||||
mock_mcp_client.side_effect = side_effect
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == expected_error
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
@respx.mock
|
||||
async def test_authentication_flow_server_missing_tool_capabilities(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_mcp_client: Mock,
|
||||
credential: None,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test for an OAuth authentication flow for an MCP server."""
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
# MCP Server returns 401 indicating the client needs to authenticate
|
||||
mock_mcp_client.side_effect = httpx.HTTPStatusError(
|
||||
"Authentication required", request=None, response=httpx.Response(401)
|
||||
)
|
||||
# Prepare the OAuth Server metadata
|
||||
respx.get(OAUTH_DISCOVERY_ENDPOINT).mock(
|
||||
return_value=OAUTH_SERVER_METADATA_RESPONSE
|
||||
)
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
},
|
||||
)
|
||||
assert result["type"] is FlowResultType.MENU
|
||||
assert result["step_id"] == "credentials_choice"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
"next_step_id": "pick_implementation",
|
||||
},
|
||||
)
|
||||
assert result["type"] is FlowResultType.EXTERNAL_STEP
|
||||
result = await perform_oauth_flow(
|
||||
hass,
|
||||
aioclient_mock,
|
||||
hass_client_no_auth,
|
||||
result,
|
||||
)
|
||||
|
||||
# Client can now authenticate
|
||||
mock_mcp_client.side_effect = None
|
||||
|
||||
response = Mock()
|
||||
response.serverInfo.name = TEST_API_NAME
|
||||
response.capabilities.tools = None
|
||||
mock_mcp_client.return_value.initialize.return_value = response
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "missing_capabilities"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
@respx.mock
|
||||
async def test_reauth_flow(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_mcp_client: Mock,
|
||||
credential: None,
|
||||
config_entry_with_auth: MockConfigEntry,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test for an OAuth authentication flow for an MCP server."""
|
||||
config_entry_with_auth.async_start_reauth(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
result = flows[0]
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
|
||||
result = await perform_oauth_flow(hass, aioclient_mock, hass_client_no_auth, result)
|
||||
|
||||
# Verify we can connect to the server
|
||||
response = Mock()
|
||||
response.serverInfo.name = TEST_API_NAME
|
||||
mock_mcp_client.return_value.initialize.return_value = response
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "reauth_successful"
|
||||
|
||||
assert config_entry_with_auth.unique_id == AUTH_DOMAIN
|
||||
assert config_entry_with_auth.title == TEST_API_NAME
|
||||
data = {**config_entry_with_auth.data}
|
||||
token = data.pop(CONF_TOKEN)
|
||||
assert data == {
|
||||
"auth_implementation": AUTH_DOMAIN,
|
||||
CONF_URL: MCP_SERVER_URL,
|
||||
CONF_AUTHORIZATION_URL: OAUTH_AUTHORIZE_URL,
|
||||
CONF_TOKEN_URL: OAUTH_TOKEN_URL,
|
||||
}
|
||||
assert token
|
||||
token.pop("expires_at")
|
||||
assert token == OAUTH_TOKEN_PAYLOAD
|
||||
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
@ -76,17 +76,45 @@ async def test_init(
|
||||
assert config_entry.state is ConfigEntryState.NOT_LOADED
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect"),
|
||||
[
|
||||
(httpx.TimeoutException("Some timeout")),
|
||||
(httpx.HTTPStatusError("", request=None, response=httpx.Response(500))),
|
||||
(httpx.HTTPStatusError("", request=None, response=httpx.Response(401))),
|
||||
(httpx.HTTPError("Some HTTP error")),
|
||||
],
|
||||
)
|
||||
async def test_mcp_server_failure(
|
||||
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
|
||||
hass: HomeAssistant,
|
||||
config_entry: MockConfigEntry,
|
||||
mock_mcp_client: Mock,
|
||||
side_effect: Exception,
|
||||
) -> None:
|
||||
"""Test the integration fails to setup if the server fails initialization."""
|
||||
mock_mcp_client.side_effect = side_effect
|
||||
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.SETUP_RETRY
|
||||
|
||||
|
||||
async def test_mcp_server_authentication_failure(
|
||||
hass: HomeAssistant,
|
||||
credential: None,
|
||||
config_entry_with_auth: MockConfigEntry,
|
||||
mock_mcp_client: Mock,
|
||||
) -> None:
|
||||
"""Test the integration fails to setup if the server fails authentication."""
|
||||
mock_mcp_client.side_effect = httpx.HTTPStatusError(
|
||||
"", request=None, response=httpx.Response(500)
|
||||
"Authentication required", request=None, response=httpx.Response(401)
|
||||
)
|
||||
|
||||
with patch("homeassistant.components.mcp.coordinator.TIMEOUT", 1):
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.SETUP_RETRY
|
||||
await hass.config_entries.async_setup(config_entry_with_auth.entry_id)
|
||||
assert config_entry_with_auth.state is ConfigEntryState.SETUP_ERROR
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
assert flows[0]["step_id"] == "reauth_confirm"
|
||||
|
||||
|
||||
async def test_list_tools_failure(
|
||||
|
Loading…
x
Reference in New Issue
Block a user