mirror of
https://github.com/home-assistant/core.git
synced 2025-11-07 18:09:31 +00:00
282 lines
11 KiB
Python
282 lines
11 KiB
Python
"""Model Context Protocol transport protocol for Streamable HTTP and SSE.
|
|
|
|
This registers HTTP endpoints that support the Streamable HTTP protocol as
|
|
well as the older SSE as a transport layer.
|
|
|
|
The Streamable HTTP protocol uses a single HTTP endpoint:
|
|
|
|
- /api/mcp_server: The Streamable HTTP endpoint currently implements the
|
|
stateless protocol for simplicity. This receives client requests and
|
|
sends them to the MCP server, then waits for a response to send back to
|
|
the client.
|
|
|
|
The older SSE protocol has two HTTP endpoints:
|
|
|
|
- /mcp_server/sse: The SSE endpoint that is used to establish a session
|
|
with the client and glue to the MCP server. This is used to push responses
|
|
to the client.
|
|
- /mcp_server/messages: The endpoint that is used by the client to send
|
|
POST requests with new requests for the MCP server. The request contains
|
|
a session identifier. The response to the client is passed over the SSE
|
|
session started on the other endpoint.
|
|
|
|
See https://modelcontextprotocol.io/docs/concepts/transports
|
|
"""
|
|
|
|
import asyncio
|
|
from dataclasses import dataclass
|
|
from http import HTTPStatus
|
|
import logging
|
|
|
|
from aiohttp import web
|
|
from aiohttp.web_exceptions import HTTPBadRequest, HTTPNotFound
|
|
from aiohttp_sse import sse_response
|
|
import anyio
|
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
from mcp import JSONRPCRequest, types
|
|
from mcp.server import InitializationOptions, Server
|
|
from mcp.shared.message import SessionMessage
|
|
|
|
from homeassistant.components import conversation
|
|
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
|
from homeassistant.const import CONF_LLM_HASS_API
|
|
from homeassistant.core import Context, HomeAssistant, callback
|
|
from homeassistant.helpers import llm
|
|
|
|
from .const import DOMAIN
|
|
from .server import create_server
|
|
from .session import Session
|
|
from .types import MCPServerConfigEntry
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
# Streamable HTTP endpoint
|
|
STREAMABLE_API = "/api/mcp"
|
|
TIMEOUT = 60 # Seconds
|
|
|
|
# Content types
|
|
CONTENT_TYPE_JSON = "application/json"
|
|
|
|
# Legacy SSE endpoint
|
|
SSE_API = f"/{DOMAIN}/sse"
|
|
MESSAGES_API = f"/{DOMAIN}/messages/{{session_id}}"
|
|
|
|
|
|
@callback
|
|
def async_register(hass: HomeAssistant) -> None:
|
|
"""Register the websocket API."""
|
|
hass.http.register_view(ModelContextProtocolSSEView())
|
|
hass.http.register_view(ModelContextProtocolMessagesView())
|
|
hass.http.register_view(ModelContextProtocolStreamableView())
|
|
|
|
|
|
def async_get_config_entry(hass: HomeAssistant) -> MCPServerConfigEntry:
|
|
"""Get the first enabled MCP server config entry.
|
|
|
|
The ConfigEntry contains a reference to the actual MCP server used to
|
|
serve the Model Context Protocol.
|
|
|
|
Will raise an HTTP error if the expected configuration is not present.
|
|
"""
|
|
config_entries: list[MCPServerConfigEntry] = (
|
|
hass.config_entries.async_loaded_entries(DOMAIN)
|
|
)
|
|
if not config_entries:
|
|
raise HTTPNotFound(text="Model Context Protocol server is not configured")
|
|
if len(config_entries) > 1:
|
|
raise HTTPNotFound(text="Found multiple Model Context Protocol configurations")
|
|
return config_entries[0]
|
|
|
|
|
|
@dataclass
|
|
class Streams:
|
|
"""Pairs of streams for MCP server communication."""
|
|
|
|
# The MCP server reads from the read stream. The HTTP handler receives
|
|
# incoming client messages and writes the to the read_stream_writer.
|
|
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
|
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
|
|
|
# The MCP server writes to the write stream. The HTTP handler reads from
|
|
# the write stream and sends messages to the client.
|
|
write_stream: MemoryObjectSendStream[SessionMessage]
|
|
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
|
|
|
|
|
def create_streams() -> Streams:
|
|
"""Create a new pair of streams for MCP server communication."""
|
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
|
return Streams(
|
|
read_stream=read_stream,
|
|
read_stream_writer=read_stream_writer,
|
|
write_stream=write_stream,
|
|
write_stream_reader=write_stream_reader,
|
|
)
|
|
|
|
|
|
async def create_mcp_server(
|
|
hass: HomeAssistant, context: Context, entry: MCPServerConfigEntry
|
|
) -> tuple[Server, InitializationOptions]:
|
|
"""Initialize the MCP server to ensure it's ready to handle requests."""
|
|
llm_context = llm.LLMContext(
|
|
platform=DOMAIN,
|
|
context=context,
|
|
language="*",
|
|
assistant=conversation.DOMAIN,
|
|
device_id=None,
|
|
)
|
|
llm_api_id = entry.data[CONF_LLM_HASS_API]
|
|
server = await create_server(hass, llm_api_id, llm_context)
|
|
options = await hass.async_add_executor_job(
|
|
server.create_initialization_options # Reads package for version info
|
|
)
|
|
return server, options
|
|
|
|
|
|
class ModelContextProtocolSSEView(HomeAssistantView):
|
|
"""Model Context Protocol SSE endpoint."""
|
|
|
|
name = f"{DOMAIN}:sse"
|
|
url = SSE_API
|
|
|
|
async def get(self, request: web.Request) -> web.StreamResponse:
|
|
"""Process SSE messages for the Model Context Protocol.
|
|
|
|
This is a long running request for the lifetime of the client session
|
|
and is the primary transport layer between the client and server.
|
|
|
|
Pairs of buffered streams act as a bridge between the transport protocol
|
|
(SSE over HTTP views) and the Model Context Protocol. The MCP SDK
|
|
manages all protocol details and invokes commands on our MCP server.
|
|
"""
|
|
hass = request.app[KEY_HASS]
|
|
entry = async_get_config_entry(hass)
|
|
session_manager = entry.runtime_data
|
|
|
|
server, options = await create_mcp_server(hass, self.context(request), entry)
|
|
streams = create_streams()
|
|
|
|
async with (
|
|
sse_response(request) as response,
|
|
session_manager.create(Session(streams.read_stream_writer)) as session_id,
|
|
):
|
|
session_uri = MESSAGES_API.format(session_id=session_id)
|
|
_LOGGER.debug("Sending SSE endpoint: %s", session_uri)
|
|
await response.send(session_uri, event="endpoint")
|
|
|
|
async def sse_reader() -> None:
|
|
"""Forward MCP server responses to the client."""
|
|
async for session_message in streams.write_stream_reader:
|
|
_LOGGER.debug("Sending SSE message: %s", session_message)
|
|
await response.send(
|
|
session_message.message.model_dump_json(
|
|
by_alias=True, exclude_none=True
|
|
),
|
|
event="message",
|
|
)
|
|
|
|
async with anyio.create_task_group() as tg:
|
|
tg.start_soon(sse_reader)
|
|
await server.run(streams.read_stream, streams.write_stream, options)
|
|
|
|
return response
|
|
|
|
|
|
class ModelContextProtocolMessagesView(HomeAssistantView):
|
|
"""Model Context Protocol messages endpoint."""
|
|
|
|
name = f"{DOMAIN}:messages"
|
|
url = MESSAGES_API
|
|
|
|
async def post(
|
|
self,
|
|
request: web.Request,
|
|
session_id: str,
|
|
) -> web.StreamResponse:
|
|
"""Process incoming messages for the Model Context Protocol.
|
|
|
|
The request passes a session ID which is used to identify the original
|
|
SSE connection. This view parses incoming messages from the transport
|
|
layer then writes them to the MCP server stream for the session.
|
|
"""
|
|
hass = request.app[KEY_HASS]
|
|
config_entry = async_get_config_entry(hass)
|
|
|
|
session_manager = config_entry.runtime_data
|
|
if (session := session_manager.get(session_id)) is None:
|
|
_LOGGER.info("Could not find session ID: '%s'", session_id)
|
|
raise HTTPNotFound(text=f"Could not find session ID '{session_id}'")
|
|
|
|
json_data = await request.json()
|
|
try:
|
|
message = types.JSONRPCMessage.model_validate(json_data)
|
|
except ValueError as err:
|
|
_LOGGER.info("Failed to parse message: %s", err)
|
|
raise HTTPBadRequest(text="Could not parse message") from err
|
|
|
|
_LOGGER.debug("Received client message: %s", message)
|
|
await session.read_stream_writer.send(SessionMessage(message))
|
|
return web.Response(status=200)
|
|
|
|
|
|
class ModelContextProtocolStreamableView(HomeAssistantView):
|
|
"""Model Context Protocol Streamable HTTP endpoint."""
|
|
|
|
name = f"{DOMAIN}:streamable"
|
|
url = STREAMABLE_API
|
|
|
|
async def get(self, request: web.Request) -> web.StreamResponse:
|
|
"""Handle unsupported methods."""
|
|
return web.Response(
|
|
status=HTTPStatus.METHOD_NOT_ALLOWED, text="Only POST method is supported"
|
|
)
|
|
|
|
async def post(self, request: web.Request) -> web.StreamResponse:
|
|
"""Process JSON-RPC messages for the Model Context Protocol."""
|
|
hass = request.app[KEY_HASS]
|
|
entry = async_get_config_entry(hass)
|
|
|
|
# The request must include a JSON-RPC message
|
|
if CONTENT_TYPE_JSON not in request.headers.get("accept", ""):
|
|
raise HTTPBadRequest(text=f"Client must accept {CONTENT_TYPE_JSON}")
|
|
if request.content_type != CONTENT_TYPE_JSON:
|
|
raise HTTPBadRequest(text=f"Content-Type must be {CONTENT_TYPE_JSON}")
|
|
try:
|
|
json_data = await request.json()
|
|
message = types.JSONRPCMessage.model_validate(json_data)
|
|
except ValueError as err:
|
|
_LOGGER.debug("Failed to parse message as JSON-RPC message: %s", err)
|
|
raise HTTPBadRequest(text="Request must be a JSON-RPC message") from err
|
|
|
|
_LOGGER.debug("Received client message: %s", message)
|
|
|
|
# For notifications and responses only, return 202 Accepted
|
|
if not isinstance(message.root, JSONRPCRequest):
|
|
_LOGGER.debug("Notification or response received, returning 202")
|
|
return web.Response(status=HTTPStatus.ACCEPTED)
|
|
|
|
# The MCP server runs as a background task for the duration of the
|
|
# request. We open a buffered stream pair to communicate with it. The
|
|
# request is sent to the MCP server and we wait for a single response
|
|
# then shut down the server.
|
|
server, options = await create_mcp_server(hass, self.context(request), entry)
|
|
streams = create_streams()
|
|
|
|
async def run_server() -> None:
|
|
await server.run(
|
|
streams.read_stream, streams.write_stream, options, stateless=True
|
|
)
|
|
|
|
async with asyncio.timeout(TIMEOUT), anyio.create_task_group() as tg:
|
|
tg.start_soon(run_server)
|
|
|
|
await streams.read_stream_writer.send(SessionMessage(message))
|
|
session_message = await anext(streams.write_stream_reader)
|
|
tg.cancel_scope.cancel()
|
|
|
|
_LOGGER.debug("Sending response: %s", session_message)
|
|
return web.json_response(
|
|
data=session_message.message.model_dump(by_alias=True, exclude_none=True),
|
|
)
|