Files
core/homeassistant/components/mcp_server/http.py
2025-10-06 16:56:45 +02:00

282 lines
11 KiB
Python

"""Model Context Protocol transport protocol for Streamable HTTP and SSE.
This registers HTTP endpoints that support the Streamable HTTP protocol as
well as the older SSE as a transport layer.
The Streamable HTTP protocol uses a single HTTP endpoint:
- /api/mcp: The Streamable HTTP endpoint currently implements the
stateless protocol for simplicity. This receives client requests and
sends them to the MCP server, then waits for a response to send back to
the client.
The older SSE protocol has two HTTP endpoints:
- /mcp_server/sse: The SSE endpoint that is used to establish a session
with the client and glue to the MCP server. This is used to push responses
to the client.
- /mcp_server/messages: The endpoint that is used by the client to send
POST requests with new requests for the MCP server. The request contains
a session identifier. The response to the client is passed over the SSE
session started on the other endpoint.
See https://modelcontextprotocol.io/docs/concepts/transports
"""
import asyncio
from dataclasses import dataclass
from http import HTTPStatus
import logging
from aiohttp import web
from aiohttp.web_exceptions import HTTPBadRequest, HTTPNotFound
from aiohttp_sse import sse_response
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import JSONRPCRequest, types
from mcp.server import InitializationOptions, Server
from mcp.shared.message import SessionMessage
from homeassistant.components import conversation
from homeassistant.components.http import KEY_HASS, HomeAssistantView
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers import llm
from .const import DOMAIN
from .server import create_server
from .session import Session
from .types import MCPServerConfigEntry
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Streamable HTTP endpoint
# URL served by ModelContextProtocolStreamableView.
STREAMABLE_API = "/api/mcp"
# Upper bound passed to asyncio.timeout() while a Streamable HTTP request
# waits for the MCP server to produce its response.
TIMEOUT = 60  # Seconds
# Content types
# Required in both the Accept and Content-Type headers of Streamable requests.
CONTENT_TYPE_JSON = "application/json"
# Legacy SSE endpoint
# URL of the long-lived SSE connection that pushes server messages to the client.
SSE_API = f"/{DOMAIN}/sse"
# URL template for client POSTs. The doubled braces leave a literal
# "{session_id}" placeholder that is used as the route parameter and is filled
# in with str.format() when the endpoint is advertised to the client.
MESSAGES_API = f"/{DOMAIN}/messages/{{session_id}}"
@callback
def async_register(hass: HomeAssistant) -> None:
    """Register the MCP server HTTP views.

    One instance of each view is registered, in this order: the legacy SSE
    view, the legacy messages view, and the Streamable HTTP view.
    """
    view_classes = (
        ModelContextProtocolSSEView,
        ModelContextProtocolMessagesView,
        ModelContextProtocolStreamableView,
    )
    for view_class in view_classes:
        hass.http.register_view(view_class())
def async_get_config_entry(hass: HomeAssistant) -> MCPServerConfigEntry:
    """Return the single loaded MCP server config entry.

    The config entry carries the runtime data (session manager) and the
    options used to build the MCP server that serves the protocol.

    Raises:
        HTTPNotFound: If no MCP server config entry is loaded, or if more
            than one is loaded and the choice would be ambiguous.
    """
    loaded_entries: list[MCPServerConfigEntry] = (
        hass.config_entries.async_loaded_entries(DOMAIN)
    )
    entry_count = len(loaded_entries)
    if entry_count == 0:
        raise HTTPNotFound(text="Model Context Protocol server is not configured")
    if entry_count > 1:
        raise HTTPNotFound(text="Found multiple Model Context Protocol configurations")
    return loaded_entries[0]
@dataclass
class Streams:
    """Pairs of streams for MCP server communication.

    Holds both ends of the two in-memory streams that connect an HTTP
    handler to an MCP server: one stream carries client messages to the
    server, the other carries server messages back to the handler.
    """

    # The MCP server reads from the read stream. The HTTP handler receives
    # incoming client messages and writes them to the read_stream_writer.
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

    # The MCP server writes to the write stream. The HTTP handler reads from
    # the write_stream_reader and sends messages to the client.
    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
def create_streams() -> Streams:
    """Build the two in-memory streams that link an HTTP handler to an MCP server.

    Both streams are created with a max buffer size of 0. The returned
    ``Streams`` exposes the send and receive end of each direction.
    """
    # Client -> server direction: the handler sends, the MCP server receives.
    to_server_send, to_server_receive = anyio.create_memory_object_stream(0)
    # Server -> client direction: the MCP server sends, the handler receives.
    from_server_send, from_server_receive = anyio.create_memory_object_stream(0)
    return Streams(
        read_stream=to_server_receive,
        read_stream_writer=to_server_send,
        write_stream=from_server_send,
        write_stream_reader=from_server_receive,
    )
async def create_mcp_server(
    hass: HomeAssistant, context: Context, entry: MCPServerConfigEntry
) -> tuple[Server, InitializationOptions]:
    """Create an MCP server and its initialization options for one client.

    The server is built for the LLM API selected in the config entry, using
    an LLM context derived from the request ``context``.

    Returns:
        A ``(server, options)`` tuple ready to be passed to ``server.run``.
    """
    api_context = llm.LLMContext(
        platform=DOMAIN,
        context=context,
        language="*",
        assistant=conversation.DOMAIN,
        device_id=None,
    )
    selected_api_id = entry.data[CONF_LLM_HASS_API]
    server = await create_server(hass, selected_api_id, api_context)
    # Building the options reads package metadata for version info, so it is
    # run in the executor instead of on the event loop.
    init_options = await hass.async_add_executor_job(
        server.create_initialization_options
    )
    return server, init_options
class ModelContextProtocolSSEView(HomeAssistantView):
    """Legacy SSE transport endpoint for the Model Context Protocol."""

    name = f"{DOMAIN}:sse"
    url = SSE_API

    async def get(self, request: web.Request) -> web.StreamResponse:
        """Serve one client session over Server-Sent Events.

        The request stays open for the lifetime of the client session. A
        session is registered with the session manager so that POSTs to the
        messages endpoint can be routed to this connection, and its URL is
        announced to the client in an initial ``endpoint`` event.

        A pair of in-memory streams bridges the HTTP transport and the MCP
        SDK: the SDK runs the protocol on our MCP server while a background
        task forwards every server message to the client as a ``message``
        event.
        """
        hass = request.app[KEY_HASS]
        config_entry = async_get_config_entry(hass)
        sessions = config_entry.runtime_data
        server, init_options = await create_mcp_server(
            hass, self.context(request), config_entry
        )
        streams = create_streams()

        async with sse_response(request) as response:
            async with sessions.create(
                Session(streams.read_stream_writer)
            ) as session_id:
                # Tell the client where to POST its messages for this session.
                session_uri = MESSAGES_API.format(session_id=session_id)
                _LOGGER.debug("Sending SSE endpoint: %s", session_uri)
                await response.send(session_uri, event="endpoint")

                async def sse_reader() -> None:
                    """Forward MCP server responses to the client."""
                    async for session_message in streams.write_stream_reader:
                        _LOGGER.debug("Sending SSE message: %s", session_message)
                        payload = session_message.message.model_dump_json(
                            by_alias=True, exclude_none=True
                        )
                        await response.send(payload, event="message")

                # The forwarding task runs alongside the MCP server; the task
                # group is only left once both have finished.
                async with anyio.create_task_group() as task_group:
                    task_group.start_soon(sse_reader)
                    await server.run(
                        streams.read_stream, streams.write_stream, init_options
                    )
                    return response
class ModelContextProtocolMessagesView(HomeAssistantView):
    """Model Context Protocol messages endpoint (legacy SSE transport)."""

    name = f"{DOMAIN}:messages"
    url = MESSAGES_API

    async def post(
        self,
        request: web.Request,
        session_id: str,
    ) -> web.StreamResponse:
        """Process incoming messages for the Model Context Protocol.

        The request passes a session ID which is used to identify the original
        SSE connection. This view parses incoming messages from the transport
        layer then writes them to the MCP server stream for the session.

        Returns:
            An empty 200 response once the message has been handed to the
            session's stream. The MCP server's reply is not part of this
            response.

        Raises:
            HTTPNotFound: If the MCP server is not configured or the session
                ID is unknown.
            HTTPBadRequest: If the body is not valid JSON or is not a valid
                JSON-RPC message.
        """
        hass = request.app[KEY_HASS]
        config_entry = async_get_config_entry(hass)
        session_manager = config_entry.runtime_data
        if (session := session_manager.get(session_id)) is None:
            _LOGGER.info("Could not find session ID: '%s'", session_id)
            raise HTTPNotFound(text=f"Could not find session ID '{session_id}'")

        # Decode the body inside the ``try`` as well: a malformed JSON body
        # raises json.JSONDecodeError (a ValueError subclass), which must be
        # reported as a client error rather than escaping as an unhandled
        # exception.
        try:
            json_data = await request.json()
            message = types.JSONRPCMessage.model_validate(json_data)
        except ValueError as err:
            _LOGGER.info("Failed to parse message: %s", err)
            raise HTTPBadRequest(text="Could not parse message") from err

        _LOGGER.debug("Received client message: %s", message)
        await session.read_stream_writer.send(SessionMessage(message))
        return web.Response(status=HTTPStatus.OK)
class ModelContextProtocolStreamableView(HomeAssistantView):
    """Stateless Streamable HTTP endpoint for the Model Context Protocol."""

    name = f"{DOMAIN}:streamable"
    url = STREAMABLE_API

    async def get(self, request: web.Request) -> web.StreamResponse:
        """Reject GET requests; this endpoint only accepts POST."""
        return web.Response(
            text="Only POST method is supported",
            status=HTTPStatus.METHOD_NOT_ALLOWED,
        )

    async def post(self, request: web.Request) -> web.StreamResponse:
        """Handle one JSON-RPC message and return the MCP server's reply.

        Notifications and responses are acknowledged with 202 Accepted and
        are not forwarded to a server. For a JSON-RPC request, a fresh MCP
        server is run in stateless mode for the duration of the HTTP request
        and its first message is returned as the JSON response.

        Raises:
            HTTPNotFound: If the MCP server is not configured.
            HTTPBadRequest: If the Accept or Content-Type header does not
                name JSON, or the body is not a JSON-RPC message.
        """
        hass = request.app[KEY_HASS]
        config_entry = async_get_config_entry(hass)

        # The request must include a JSON-RPC message
        accept_header = request.headers.get("accept", "")
        if CONTENT_TYPE_JSON not in accept_header:
            raise HTTPBadRequest(text=f"Client must accept {CONTENT_TYPE_JSON}")
        if request.content_type != CONTENT_TYPE_JSON:
            raise HTTPBadRequest(text=f"Content-Type must be {CONTENT_TYPE_JSON}")

        try:
            body = await request.json()
            rpc_message = types.JSONRPCMessage.model_validate(body)
        except ValueError as err:
            _LOGGER.debug("Failed to parse message as JSON-RPC message: %s", err)
            raise HTTPBadRequest(text="Request must be a JSON-RPC message") from err

        _LOGGER.debug("Received client message: %s", rpc_message)

        # Only requests produce a reply; everything else is just acknowledged.
        if not isinstance(rpc_message.root, JSONRPCRequest):
            _LOGGER.debug("Notification or response received, returning 202")
            return web.Response(status=HTTPStatus.ACCEPTED)

        # The MCP server lives only for this HTTP request: it is started as a
        # background task, fed the single request over an in-memory stream
        # pair, and cancelled as soon as its first message has been read.
        server, init_options = await create_mcp_server(
            hass, self.context(request), config_entry
        )
        streams = create_streams()

        async def run_server() -> None:
            await server.run(
                streams.read_stream,
                streams.write_stream,
                init_options,
                stateless=True,
            )

        async with asyncio.timeout(TIMEOUT):
            async with anyio.create_task_group() as task_group:
                task_group.start_soon(run_server)

                await streams.read_stream_writer.send(SessionMessage(rpc_message))
                reply = await anext(streams.write_stream_reader)
                task_group.cancel_scope.cancel()

        _LOGGER.debug("Sending response: %s", reply)
        payload = reply.message.model_dump(by_alias=True, exclude_none=True)
        return web.json_response(data=payload)