mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Add the Model Context Protocol Server integration (#134122)
* Add the Model Context Protocol Server integration * Remove unusued code in init * Fix comment wording * Use util.uild for unique ids * Set config entry title to the LLM API name * Extract an SSE parser and update comments * Update comments and defend against already closed sessions * Shorten description * Update homeassistant/components/mcp_server/__init__.py Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> * Change integration type to service --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
5e981d00a4
commit
a5d0c3528c
@ -311,6 +311,7 @@ homeassistant.components.manual.*
|
||||
homeassistant.components.mastodon.*
|
||||
homeassistant.components.matrix.*
|
||||
homeassistant.components.matter.*
|
||||
homeassistant.components.mcp_server.*
|
||||
homeassistant.components.mealie.*
|
||||
homeassistant.components.media_extractor.*
|
||||
homeassistant.components.media_player.*
|
||||
|
@ -889,6 +889,8 @@ build.json @home-assistant/supervisor
|
||||
/tests/components/matrix/ @PaarthShah
|
||||
/homeassistant/components/matter/ @home-assistant/matter
|
||||
/tests/components/matter/ @home-assistant/matter
|
||||
/homeassistant/components/mcp_server/ @allenporter
|
||||
/tests/components/mcp_server/ @allenporter
|
||||
/homeassistant/components/mealie/ @joostlek @andrew-codechimp
|
||||
/tests/components/mealie/ @joostlek @andrew-codechimp
|
||||
/homeassistant/components/meater/ @Sotolotl @emontnemery
|
||||
|
43
homeassistant/components/mcp_server/__init__.py
Normal file
43
homeassistant/components/mcp_server/__init__.py
Normal file
@ -0,0 +1,43 @@
|
||||
"""The Model Context Protocol Server integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import http
|
||||
from .const import DOMAIN
|
||||
from .session import SessionManager
|
||||
from .types import MCPServerConfigEntry
|
||||
|
||||
__all__ = [
|
||||
"CONFIG_SCHEMA",
|
||||
"DOMAIN",
|
||||
"async_setup",
|
||||
"async_setup_entry",
|
||||
"async_unload_entry",
|
||||
]
|
||||
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the Model Context Protocol component."""
|
||||
http.async_register(hass)
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: MCPServerConfigEntry) -> bool:
|
||||
"""Set up Model Context Protocol Server from a config entry."""
|
||||
|
||||
entry.runtime_data = SessionManager()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: MCPServerConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
session_manager = entry.runtime_data
|
||||
session_manager.close()
|
||||
return True
|
63
homeassistant/components/mcp_server/config_flow.py
Normal file
63
homeassistant/components/mcp_server/config_flow.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Config flow for the Model Context Protocol Server integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.selector import (
|
||||
SelectOptionDict,
|
||||
SelectSelector,
|
||||
SelectSelectorConfig,
|
||||
)
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
MORE_INFO_URL = "https://www.home-assistant.io/integrations/mcp_server/#configuration"
|
||||
|
||||
|
||||
class ModelContextServerProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Model Context Protocol Server."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle the initial step."""
|
||||
llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)}
|
||||
|
||||
if user_input is not None:
|
||||
return self.async_create_entry(
|
||||
title=llm_apis[user_input[CONF_LLM_HASS_API]], data=user_input
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
default=llm.LLM_API_ASSIST,
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
options=[
|
||||
SelectOptionDict(
|
||||
label=name,
|
||||
value=llm_api_id,
|
||||
)
|
||||
for llm_api_id, name in llm_apis.items()
|
||||
]
|
||||
)
|
||||
),
|
||||
}
|
||||
),
|
||||
description_placeholders={"more_info_url": MORE_INFO_URL},
|
||||
)
|
4
homeassistant/components/mcp_server/const.py
Normal file
4
homeassistant/components/mcp_server/const.py
Normal file
@ -0,0 +1,4 @@
|
||||
"""Constants for the Model Context Protocol Server integration."""
|
||||
|
||||
DOMAIN = "mcp_server"
|
||||
TITLE = "Model Context Protocol Server"
|
170
homeassistant/components/mcp_server/http.py
Normal file
170
homeassistant/components/mcp_server/http.py
Normal file
@ -0,0 +1,170 @@
|
||||
"""Model Context Protocol transport portocol for Server Sent Events (SSE).
|
||||
|
||||
This registers HTTP endpoints that supports SSE as a transport layer
|
||||
for the Model Context Protocol. There are two HTTP endpoints:
|
||||
|
||||
- /mcp_server/sse: The SSE endpoint that is used to establish a session
|
||||
with the client and glue to the MCP server. This is used to push responses
|
||||
to the client.
|
||||
- /mcp_server/messages: The endpoint that is used by the client to send
|
||||
POST requests with new requests for the MCP server. The request contains
|
||||
a session identifier. The response to the client is passed over the SSE
|
||||
session started on the other endpoint.
|
||||
|
||||
See https://modelcontextprotocol.io/docs/concepts/transports
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.web_exceptions import HTTPBadRequest, HTTPNotFound
|
||||
from aiohttp_sse import sse_response
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from mcp import types
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from .const import DOMAIN
|
||||
from .server import create_server
|
||||
from .session import Session
|
||||
from .types import MCPServerConfigEntry
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SSE_API = f"/{DOMAIN}/sse"
|
||||
MESSAGES_API = f"/{DOMAIN}/messages/{{session_id}}"
|
||||
|
||||
|
||||
@callback
|
||||
def async_register(hass: HomeAssistant) -> None:
|
||||
"""Register the websocket API."""
|
||||
hass.http.register_view(ModelContextProtocolSSEView())
|
||||
hass.http.register_view(ModelContextProtocolMessagesView())
|
||||
|
||||
|
||||
def async_get_config_entry(hass: HomeAssistant) -> MCPServerConfigEntry:
|
||||
"""Get the first enabled MCP server config entry.
|
||||
|
||||
The ConfigEntry contains a reference to the actual MCP server used to
|
||||
serve the Model Context Protocol.
|
||||
|
||||
Will raise an HTTP error if the expected configuration is not present.
|
||||
"""
|
||||
config_entries: list[MCPServerConfigEntry] = [
|
||||
config_entry
|
||||
for config_entry in hass.config_entries.async_entries(DOMAIN)
|
||||
if config_entry.state == ConfigEntryState.LOADED
|
||||
]
|
||||
if not config_entries:
|
||||
raise HTTPNotFound(body="Model Context Protocol server is not configured")
|
||||
if len(config_entries) > 1:
|
||||
raise HTTPNotFound(body="Found multiple Model Context Protocol configurations")
|
||||
return config_entries[0]
|
||||
|
||||
|
||||
class ModelContextProtocolSSEView(HomeAssistantView):
|
||||
"""Model Context Protocol SSE endpoint."""
|
||||
|
||||
name = f"{DOMAIN}:sse"
|
||||
url = SSE_API
|
||||
|
||||
async def get(self, request: web.Request) -> web.StreamResponse:
|
||||
"""Process SSE messages for the Model Context Protocol.
|
||||
|
||||
This is a long running request for the lifetime of the client session
|
||||
and is the primary transport layer between the client and server.
|
||||
|
||||
Pairs of buffered streams act as a bridge between the transport protocol
|
||||
(SSE over HTTP views) and the Model Context Protocol. The MCP SDK
|
||||
manages all protocol details and invokes commands on our MCP server.
|
||||
"""
|
||||
hass = request.app[KEY_HASS]
|
||||
entry = async_get_config_entry(hass)
|
||||
session_manager = entry.runtime_data
|
||||
|
||||
context = llm.LLMContext(
|
||||
platform=DOMAIN,
|
||||
context=self.context(request),
|
||||
user_prompt=None,
|
||||
language="*",
|
||||
assistant=conversation.DOMAIN,
|
||||
device_id=None,
|
||||
)
|
||||
llm_api_id = entry.data[CONF_LLM_HASS_API]
|
||||
server = await create_server(hass, llm_api_id, context)
|
||||
options = await hass.async_add_executor_job(
|
||||
server.create_initialization_options # Reads package for version info
|
||||
)
|
||||
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async with (
|
||||
sse_response(request) as response,
|
||||
session_manager.create(Session(read_stream_writer)) as session_id,
|
||||
):
|
||||
session_uri = MESSAGES_API.format(session_id=session_id)
|
||||
_LOGGER.debug("Sending SSE endpoint: %s", session_uri)
|
||||
await response.send(session_uri, event="endpoint")
|
||||
|
||||
async def sse_reader() -> None:
|
||||
"""Forward MCP server responses to the client."""
|
||||
async for message in write_stream_reader:
|
||||
_LOGGER.debug("Sending SSE message: %s", message)
|
||||
await response.send(
|
||||
message.model_dump_json(by_alias=True, exclude_none=True),
|
||||
event="message",
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(sse_reader)
|
||||
await server.run(read_stream, write_stream, options)
|
||||
return response
|
||||
|
||||
|
||||
class ModelContextProtocolMessagesView(HomeAssistantView):
|
||||
"""Model Context Protocol messages endpoint."""
|
||||
|
||||
name = f"{DOMAIN}:messages"
|
||||
url = MESSAGES_API
|
||||
|
||||
async def post(
|
||||
self,
|
||||
request: web.Request,
|
||||
session_id: str,
|
||||
) -> web.StreamResponse:
|
||||
"""Process incoming messages for the Model Context Protocol.
|
||||
|
||||
The request passes a session ID which is used to identify the original
|
||||
SSE connection. This view parses incoming messagess from the transport
|
||||
layer then writes them to the MCP server stream for the session.
|
||||
"""
|
||||
hass = request.app[KEY_HASS]
|
||||
config_entry = async_get_config_entry(hass)
|
||||
|
||||
session_manager = config_entry.runtime_data
|
||||
if (session := session_manager.get(session_id)) is None:
|
||||
_LOGGER.info("Could not find session ID: '%s'", session_id)
|
||||
raise HTTPNotFound(body=f"Could not find session ID '{session_id}'")
|
||||
|
||||
json_data = await request.json()
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate(json_data)
|
||||
except ValueError as err:
|
||||
_LOGGER.info("Failed to parse message: %s", err)
|
||||
raise HTTPBadRequest(body="Could not parse message") from err
|
||||
|
||||
_LOGGER.debug("Received client message: %s", message)
|
||||
await session.read_stream_writer.send(message)
|
||||
return web.Response(status=200)
|
13
homeassistant/components/mcp_server/manifest.json
Normal file
13
homeassistant/components/mcp_server/manifest.json
Normal file
@ -0,0 +1,13 @@
|
||||
{
|
||||
"domain": "mcp_server",
|
||||
"name": "Model Context Protocol Server",
|
||||
"codeowners": ["@allenporter"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["homeassistant", "http", "conversation"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/mcp_server",
|
||||
"integration_type": "service",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "silver",
|
||||
"requirements": ["mcp==1.1.2", "aiohttp_sse==2.2.0", "anyio==4.7.0"],
|
||||
"single_config_entry": true
|
||||
}
|
118
homeassistant/components/mcp_server/quality_scale.yaml
Normal file
118
homeassistant/components/mcp_server/quality_scale.yaml
Normal file
@ -0,0 +1,118 @@
|
||||
rules:
|
||||
# Bronze
|
||||
action-setup:
|
||||
status: exempt
|
||||
comment: Service does not register actions
|
||||
appropriate-polling:
|
||||
status: exempt
|
||||
comment: Service is not polling
|
||||
brands: done
|
||||
common-modules:
|
||||
status: exempt
|
||||
comment: Service does not have entities or coordinators
|
||||
config-flow-test-coverage: done
|
||||
config-flow: done
|
||||
dependency-transparency: done
|
||||
docs-actions:
|
||||
status: exempt
|
||||
comment: Service does not register actions
|
||||
docs-high-level-description: done
|
||||
docs-installation-instructions: done
|
||||
docs-removal-instructions: done
|
||||
entity-event-setup:
|
||||
status: exempt
|
||||
comment: Service does not subscribe to events
|
||||
entity-unique-id:
|
||||
status: exempt
|
||||
comment: Service does not have entities
|
||||
has-entity-name:
|
||||
status: exempt
|
||||
comment: Service does not have entities
|
||||
runtime-data:
|
||||
status: exempt
|
||||
comment: No configuration state is used by the integration
|
||||
test-before-configure:
|
||||
status: exempt
|
||||
comment: Service does not a connection
|
||||
test-before-setup:
|
||||
status: exempt
|
||||
comment: Service does not a connection
|
||||
unique-config-entry:
|
||||
status: done
|
||||
comment: Integration requires a single config entry.
|
||||
|
||||
# Silver
|
||||
action-exceptions:
|
||||
status: exempt
|
||||
comment: Service does not register actions
|
||||
config-entry-unloading: done
|
||||
docs-configuration-parameters: done
|
||||
docs-installation-parameters: done
|
||||
entity-unavailable:
|
||||
status: exempt
|
||||
comment: Service does not have entities
|
||||
integration-owner: done
|
||||
log-when-unavailable:
|
||||
status: exempt
|
||||
comment: Service does not have entities
|
||||
parallel-updates:
|
||||
status: exempt
|
||||
comment: Service does not have entities
|
||||
reauthentication-flow:
|
||||
status: exempt
|
||||
comment: Service does not require authentication
|
||||
test-coverage: done
|
||||
|
||||
# Gold
|
||||
devices:
|
||||
status: exempt
|
||||
comment: Service does not have entities
|
||||
diagnostics: todo
|
||||
discovery-update-info:
|
||||
status: exempt
|
||||
comment: Service does not support discovery
|
||||
discovery:
|
||||
status: exempt
|
||||
comment: Service does not support discovery
|
||||
docs-data-update: done
|
||||
docs-examples: done
|
||||
docs-known-limitations: done
|
||||
docs-supported-devices: done
|
||||
docs-supported-functions: done
|
||||
docs-troubleshooting: todo
|
||||
docs-use-cases: done
|
||||
dynamic-devices:
|
||||
status: exempt
|
||||
comment: Service does not support devices
|
||||
entity-category:
|
||||
status: exempt
|
||||
comment: Service does not have entities
|
||||
entity-device-class:
|
||||
status: exempt
|
||||
comment: Service does not have entities
|
||||
entity-disabled-by-default:
|
||||
status: exempt
|
||||
comment: Service does not have entities
|
||||
entity-translations:
|
||||
status: exempt
|
||||
comment: Service does not have entities
|
||||
exception-translations: todo
|
||||
icon-translations:
|
||||
status: exempt
|
||||
comment: Service does not have entities
|
||||
reconfiguration-flow: todo
|
||||
repair-issues:
|
||||
status: exempt
|
||||
comment: Service does not have anything to repair
|
||||
stale-devices:
|
||||
status: exempt
|
||||
comment: Service does not have devices
|
||||
|
||||
# Platinum
|
||||
async-dependency:
|
||||
status: exempt
|
||||
comment: Service does not communicate with devices
|
||||
inject-websession:
|
||||
status: exempt
|
||||
comment: Service does not communicate with devices
|
||||
strict-typing: done
|
77
homeassistant/components/mcp_server/server.py
Normal file
77
homeassistant/components/mcp_server/server.py
Normal file
@ -0,0 +1,77 @@
|
||||
"""The Model Context Protocol Server implementation.
|
||||
|
||||
The Model Context Protocol python sdk defines a Server API that provides the
|
||||
MCP message handling logic and error handling. The server implementation provided
|
||||
here is independent of the lower level transport protocol.
|
||||
|
||||
See https://modelcontextprotocol.io/docs/concepts/architecture#implementation-example
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from mcp import types
|
||||
from mcp.server import Server
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_tool(
|
||||
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||
) -> types.Tool:
|
||||
"""Format tool specification."""
|
||||
input_schema = convert(tool.parameters, custom_serializer=custom_serializer)
|
||||
return types.Tool(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": input_schema["properties"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def create_server(
|
||||
hass: HomeAssistant, llm_api_id: str, llm_context: llm.LLMContext
|
||||
) -> Server:
|
||||
"""Create a new Model Context Protocol Server.
|
||||
|
||||
A Model Context Protocol Server object is associated with a single session.
|
||||
The MCP SDK handles the details of the protocol.
|
||||
"""
|
||||
|
||||
server = Server("home-assistant")
|
||||
|
||||
@server.list_tools() # type: ignore[no-untyped-call, misc]
|
||||
async def list_tools() -> list[types.Tool]:
|
||||
"""List available time tools."""
|
||||
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||
return [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
|
||||
|
||||
@server.call_tool() # type: ignore[no-untyped-call, misc]
|
||||
async def call_tool(name: str, arguments: dict) -> Sequence[types.TextContent]:
|
||||
"""Handle calling tools."""
|
||||
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||
tool_input = llm.ToolInput(tool_name=name, tool_args=arguments)
|
||||
_LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
|
||||
|
||||
try:
|
||||
tool_response = await llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
raise HomeAssistantError(f"Error calling tool: {e}") from e
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text=json.dumps(tool_response),
|
||||
)
|
||||
]
|
||||
|
||||
return server
|
60
homeassistant/components/mcp_server/session.py
Normal file
60
homeassistant/components/mcp_server/session.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""Model Context Protocol sessions.
|
||||
|
||||
A session is a long-lived connection between the client and server that is used
|
||||
to exchange messages. The server pushes messages to the client over the session
|
||||
and the client sends messages to the server over the session.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
from anyio.streams.memory import MemoryObjectSendStream
|
||||
from mcp import types
|
||||
|
||||
from homeassistant.util import ulid
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""A session for the Model Context Protocol."""
|
||||
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Manage SSE sessions for the MCP transport layer.
|
||||
|
||||
This class is used to manage the lifecycle of SSE sessions. It is responsible for
|
||||
creating new sessions, resuming existing sessions, and closing sessions.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the SSE server transport."""
|
||||
self._sessions: dict[str, Session] = {}
|
||||
|
||||
@asynccontextmanager
|
||||
async def create(self, session: Session) -> AsyncGenerator[str]:
|
||||
"""Context manager to create a new session ID and close when done."""
|
||||
session_id = ulid.ulid_now()
|
||||
_LOGGER.debug("Creating session: %s", session_id)
|
||||
self._sessions[session_id] = session
|
||||
try:
|
||||
yield session_id
|
||||
finally:
|
||||
_LOGGER.debug("Closing session: %s", session_id)
|
||||
if session_id in self._sessions: # close() may have already been called
|
||||
self._sessions.pop(session_id)
|
||||
|
||||
def get(self, session_id: str) -> Session | None:
|
||||
"""Get an existing session."""
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close any open sessions."""
|
||||
for session in self._sessions.values():
|
||||
session.read_stream_writer.close()
|
||||
self._sessions.clear()
|
18
homeassistant/components/mcp_server/strings.json
Normal file
18
homeassistant/components/mcp_server/strings.json
Normal file
@ -0,0 +1,18 @@
|
||||
{
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
"description": "See the [integration documentation]({more_info_url}) for setup instructions.",
|
||||
"data": {
|
||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
|
||||
},
|
||||
"data_description": {
|
||||
"llm_hass_api": "The method for controling Home Assistant to expose with the Model Context Protocol."
|
||||
}
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]"
|
||||
}
|
||||
}
|
||||
}
|
7
homeassistant/components/mcp_server/types.py
Normal file
7
homeassistant/components/mcp_server/types.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""Types for the MCP server integration."""
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
||||
from .session import SessionManager
|
||||
|
||||
type MCPServerConfigEntry = ConfigEntry[SessionManager]
|
@ -356,6 +356,7 @@ FLOWS = {
|
||||
"mailgun",
|
||||
"mastodon",
|
||||
"matter",
|
||||
"mcp_server",
|
||||
"mealie",
|
||||
"meater",
|
||||
"medcom_ble",
|
||||
|
@ -3590,6 +3590,13 @@
|
||||
"config_flow": true,
|
||||
"iot_class": "local_push"
|
||||
},
|
||||
"mcp_server": {
|
||||
"name": "Model Context Protocol Server",
|
||||
"integration_type": "service",
|
||||
"config_flow": true,
|
||||
"iot_class": "local_push",
|
||||
"single_config_entry": true
|
||||
},
|
||||
"mealie": {
|
||||
"name": "Mealie",
|
||||
"integration_type": "service",
|
||||
|
10
mypy.ini
10
mypy.ini
@ -2866,6 +2866,16 @@ disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.mcp_server.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.mealie.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
|
@ -266,6 +266,9 @@ aiohasupervisor==0.2.2b5
|
||||
# homeassistant.components.homekit_controller
|
||||
aiohomekit==3.2.7
|
||||
|
||||
# homeassistant.components.mcp_server
|
||||
aiohttp_sse==2.2.0
|
||||
|
||||
# homeassistant.components.hue
|
||||
aiohue==4.7.3
|
||||
|
||||
@ -466,6 +469,9 @@ anthemav==1.4.1
|
||||
# homeassistant.components.anthropic
|
||||
anthropic==0.31.2
|
||||
|
||||
# homeassistant.components.mcp_server
|
||||
anyio==4.7.0
|
||||
|
||||
# homeassistant.components.weatherkit
|
||||
apple_weatherkit==1.1.3
|
||||
|
||||
@ -1355,6 +1361,9 @@ maxcube-api==0.4.3
|
||||
# homeassistant.components.mythicbeastsdns
|
||||
mbddns==0.1.2
|
||||
|
||||
# homeassistant.components.mcp_server
|
||||
mcp==1.1.2
|
||||
|
||||
# homeassistant.components.minecraft_server
|
||||
mcstatus==11.1.1
|
||||
|
||||
|
@ -251,6 +251,9 @@ aiohasupervisor==0.2.2b5
|
||||
# homeassistant.components.homekit_controller
|
||||
aiohomekit==3.2.7
|
||||
|
||||
# homeassistant.components.mcp_server
|
||||
aiohttp_sse==2.2.0
|
||||
|
||||
# homeassistant.components.hue
|
||||
aiohue==4.7.3
|
||||
|
||||
@ -439,6 +442,9 @@ anthemav==1.4.1
|
||||
# homeassistant.components.anthropic
|
||||
anthropic==0.31.2
|
||||
|
||||
# homeassistant.components.mcp_server
|
||||
anyio==4.7.0
|
||||
|
||||
# homeassistant.components.weatherkit
|
||||
apple_weatherkit==1.1.3
|
||||
|
||||
@ -1133,6 +1139,9 @@ maxcube-api==0.4.3
|
||||
# homeassistant.components.mythicbeastsdns
|
||||
mbddns==0.1.2
|
||||
|
||||
# homeassistant.components.mcp_server
|
||||
mcp==1.1.2
|
||||
|
||||
# homeassistant.components.minecraft_server
|
||||
mcstatus==11.1.1
|
||||
|
||||
|
1
tests/components/mcp_server/__init__.py
Normal file
1
tests/components/mcp_server/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Tests for the Model Context Protocol Server integration."""
|
35
tests/components/mcp_server/conftest.py
Normal file
35
tests/components/mcp_server/conftest.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""Common fixtures for the Model Context Protocol Server tests."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.mcp_server.const import DOMAIN
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_setup_entry() -> Generator[AsyncMock]:
|
||||
"""Override async_setup_entry."""
|
||||
with patch(
|
||||
"homeassistant.components.mcp_server.async_setup_entry", return_value=True
|
||||
) as mock_setup_entry:
|
||||
yield mock_setup_entry
|
||||
|
||||
|
||||
@pytest.fixture(name="config_entry")
|
||||
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
"""Fixture to load the integration."""
|
||||
config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||
},
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
return config_entry
|
41
tests/components/mcp_server/test_config_flow.py
Normal file
41
tests/components/mcp_server/test_config_flow.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Test the Model Context Protocol Server config flow."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.mcp_server.const import DOMAIN
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"params",
|
||||
[
|
||||
{},
|
||||
{CONF_LLM_HASS_API: "assist"},
|
||||
],
|
||||
)
|
||||
async def test_form(
|
||||
hass: HomeAssistant, mock_setup_entry: AsyncMock, params: dict[str, Any]
|
||||
) -> None:
|
||||
"""Test we get the form."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert not result["errors"]
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
params,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == "Assist"
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
assert result["data"] == {CONF_LLM_HASS_API: "assist"}
|
356
tests/components/mcp_server/test_http.py
Normal file
356
tests/components/mcp_server/test_http.py
Normal file
@ -0,0 +1,356 @@
|
||||
"""Test the Model Context Protocol Server init module."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
import json
|
||||
import logging
|
||||
|
||||
import aiohttp
|
||||
import mcp
|
||||
import mcp.client.session
|
||||
import mcp.client.sse
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||
from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry, setup_test_component_platform
|
||||
from tests.components.light.common import MockLight
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
TEST_ENTITY = "light.kitchen"
|
||||
INITIALIZE_MESSAGE = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "request-id-1",
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "1.0",
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "test",
|
||||
"version": "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
EVENT_PREFIX = "event: "
|
||||
DATA_PREFIX = "data: "
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_integration(hass: HomeAssistant, config_entry: MockConfigEntry) -> None:
|
||||
"""Set up the config entry."""
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.LOADED
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def mock_entities(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
setup_integration: None,
|
||||
) -> None:
|
||||
"""Fixture to expose entities to the conversation agent."""
|
||||
entity = MockLight("kitchen", STATE_OFF)
|
||||
entity.entity_id = TEST_ENTITY
|
||||
setup_test_component_platform(hass, LIGHT_DOMAIN, [entity])
|
||||
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
LIGHT_DOMAIN,
|
||||
{LIGHT_DOMAIN: [{"platform": "test"}]},
|
||||
)
|
||||
|
||||
async_expose_entity(hass, CONVERSATION_DOMAIN, TEST_ENTITY, True)
|
||||
|
||||
|
||||
async def sse_response_reader(
|
||||
response: aiohttp.ClientResponse,
|
||||
) -> AsyncGenerator[tuple[str, str]]:
|
||||
"""Read SSE responses from the server and emit event messages.
|
||||
|
||||
SSE responses are formatted as:
|
||||
event: event-name
|
||||
data: event-data
|
||||
and this function emits each event-name and event-data as a tuple.
|
||||
"""
|
||||
it = aiter(response.content)
|
||||
while True:
|
||||
line = (await anext(it)).decode()
|
||||
if not line.startswith(EVENT_PREFIX):
|
||||
raise ValueError("Expected event")
|
||||
event = line[len(EVENT_PREFIX) :].strip()
|
||||
line = (await anext(it)).decode()
|
||||
if not line.startswith(DATA_PREFIX):
|
||||
raise ValueError("Expected data")
|
||||
data = line[len(DATA_PREFIX) :].strip()
|
||||
line = (await anext(it)).decode()
|
||||
assert line == "\r\n"
|
||||
yield event, data
|
||||
|
||||
|
||||
async def test_http_sse(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test SSE endpoint can be used to receive MCP messages."""
|
||||
|
||||
client = await hass_client()
|
||||
|
||||
# Start an SSE session
|
||||
response = await client.get(SSE_API)
|
||||
assert response.status == HTTPStatus.OK
|
||||
|
||||
# Decode a single SSE response that sends the messages endpoint
|
||||
reader = sse_response_reader(response)
|
||||
event, endpoint_url = await anext(reader)
|
||||
assert event == "endpoint"
|
||||
|
||||
# Send an initialize message on the messages endpoint
|
||||
response = await client.post(endpoint_url, json=INITIALIZE_MESSAGE)
|
||||
assert response.status == HTTPStatus.OK
|
||||
|
||||
# Decode the initialize response event message from the SSE stream
|
||||
event, data = await anext(reader)
|
||||
assert event == "message"
|
||||
message = json.loads(data)
|
||||
assert message.get("jsonrpc") == "2.0"
|
||||
assert message.get("id") == "request-id-1"
|
||||
assert "serverInfo" in message.get("result", {})
|
||||
assert "protocolVersion" in message.get("result", {})
|
||||
|
||||
|
||||
async def test_http_messages_missing_session_id(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test the tools list endpoint."""
|
||||
|
||||
client = await hass_client()
|
||||
response = await client.post(MESSAGES_API.format(session_id="invalid-session-id"))
|
||||
assert response.status == HTTPStatus.NOT_FOUND
|
||||
response_data = await response.text()
|
||||
assert response_data == "Could not find session ID 'invalid-session-id'"
|
||||
|
||||
|
||||
async def test_http_messages_invalid_message_format(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test the tools list endpoint."""
|
||||
|
||||
client = await hass_client()
|
||||
response = await client.get(SSE_API)
|
||||
assert response.status == HTTPStatus.OK
|
||||
reader = sse_response_reader(response)
|
||||
event, endpoint_url = await anext(reader)
|
||||
assert event == "endpoint"
|
||||
|
||||
response = await client.post(endpoint_url, json={"invalid": "message"})
|
||||
assert response.status == HTTPStatus.BAD_REQUEST
|
||||
response_data = await response.text()
|
||||
assert response_data == "Could not parse message"
|
||||
|
||||
|
||||
async def test_http_sse_multiple_config_entries(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test the SSE endpoint will fail with multiple config entries.
|
||||
|
||||
This cannot happen in practice as the integration only supports a single
|
||||
config entry, but this is added for test coverage.
|
||||
"""
|
||||
|
||||
config_entry = MockConfigEntry(
|
||||
domain="mcp_server", data={CONF_LLM_HASS_API: "llm-api-id"}
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
|
||||
client = await hass_client()
|
||||
|
||||
# Attempt to start an SSE session will fail
|
||||
response = await client.get(SSE_API)
|
||||
assert response.status == HTTPStatus.NOT_FOUND
|
||||
response_data = await response.text()
|
||||
assert "Found multiple Model Context Protocol" in response_data
|
||||
|
||||
|
||||
async def test_http_sse_no_config_entry(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
config_entry: MockConfigEntry,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test the SSE endpoint fails with a missing config entry."""
|
||||
|
||||
await hass.config_entries.async_unload(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.NOT_LOADED
|
||||
|
||||
client = await hass_client()
|
||||
|
||||
# Start an SSE session
|
||||
response = await client.get(SSE_API)
|
||||
assert response.status == HTTPStatus.NOT_FOUND
|
||||
response_data = await response.text()
|
||||
assert "Model Context Protocol server is not configured" in response_data
|
||||
|
||||
|
||||
async def test_http_messages_no_config_entry(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
config_entry: MockConfigEntry,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test the message endpoint will fail if the config entry is unloaded."""
|
||||
|
||||
client = await hass_client()
|
||||
|
||||
# Start an SSE session
|
||||
response = await client.get(SSE_API)
|
||||
assert response.status == HTTPStatus.OK
|
||||
reader = sse_response_reader(response)
|
||||
event, endpoint_url = await anext(reader)
|
||||
assert event == "endpoint"
|
||||
|
||||
# Invalidate the session by unloading the config entry
|
||||
await hass.config_entries.async_unload(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.NOT_LOADED
|
||||
|
||||
# Reload the config entry and ensure the session is not found
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.LOADED
|
||||
|
||||
response = await client.post(endpoint_url, json=INITIALIZE_MESSAGE)
|
||||
assert response.status == HTTPStatus.NOT_FOUND
|
||||
response_data = await response.text()
|
||||
assert "Could not find session ID" in response_data
|
||||
|
||||
|
||||
async def test_http_requires_authentication(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test the SSE endpoint requires authentication."""
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
|
||||
response = await client.get(SSE_API)
|
||||
assert response.status == HTTPStatus.UNAUTHORIZED
|
||||
|
||||
response = await client.post(MESSAGES_API.format(session_id="session-id"))
|
||||
assert response.status == HTTPStatus.UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_sse_url(hass_client: ClientSessionGenerator) -> str:
|
||||
"""Fixture to get the MCP integration SSE URL."""
|
||||
client = await hass_client()
|
||||
return str(client.make_url(SSE_API))
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def mcp_session(
|
||||
mcp_sse_url: str,
|
||||
hass_supervisor_access_token: str,
|
||||
) -> AsyncGenerator[mcp.client.session.ClientSession]:
|
||||
"""Create an MCP session."""
|
||||
|
||||
headers = {"Authorization": f"Bearer {hass_supervisor_access_token}"}
|
||||
|
||||
async with (
|
||||
mcp.client.sse.sse_client(mcp_sse_url, headers=headers) as streams,
|
||||
mcp.client.session.ClientSession(*streams) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
yield session
|
||||
|
||||
|
||||
async def test_mcp_tools_list(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
mcp_sse_url: str,
|
||||
hass_supervisor_access_token: str,
|
||||
) -> None:
|
||||
"""Test the tools list endpoint."""
|
||||
|
||||
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||
result = await session.list_tools()
|
||||
|
||||
# Pick a single arbitrary tool and test that description and parameters
|
||||
# are converted correctly.
|
||||
tool = next(iter(tool for tool in result.tools if tool.name == "HassTurnOn"))
|
||||
assert tool.name == "HassTurnOn"
|
||||
assert tool.description == "Turns on/opens a device or entity"
|
||||
assert tool.inputSchema
|
||||
assert tool.inputSchema.get("type") == "object"
|
||||
properties = tool.inputSchema.get("properties")
|
||||
assert properties.get("name") == {"type": "string"}
|
||||
|
||||
|
||||
async def test_mcp_tool_call(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
mcp_sse_url: str,
|
||||
hass_supervisor_access_token: str,
|
||||
) -> None:
|
||||
"""Test the tool call endpoint."""
|
||||
|
||||
state = hass.states.get("light.kitchen")
|
||||
assert state
|
||||
assert state.state == STATE_OFF
|
||||
|
||||
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||
result = await session.call_tool(
|
||||
name="HassTurnOn",
|
||||
arguments={"name": "kitchen"},
|
||||
)
|
||||
|
||||
assert not result.isError
|
||||
assert len(result.content) == 1
|
||||
assert result.content[0].type == "text"
|
||||
# The content is the raw tool call payload
|
||||
content = json.loads(result.content[0].text)
|
||||
assert content.get("data", {}).get("success")
|
||||
assert not content.get("data", {}).get("failed")
|
||||
|
||||
# Verify tool call invocation
|
||||
state = hass.states.get("light.kitchen")
|
||||
assert state
|
||||
assert state.state == STATE_ON
|
||||
|
||||
|
||||
async def test_mcp_tool_call_failed(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: None,
|
||||
mcp_sse_url: str,
|
||||
hass_supervisor_access_token: str,
|
||||
) -> None:
|
||||
"""Test the tool call endpoint with a failure."""
|
||||
|
||||
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||
result = await session.call_tool(
|
||||
name="HassTurnOn",
|
||||
arguments={"name": "backyard"},
|
||||
)
|
||||
|
||||
assert result.isError
|
||||
assert len(result.content) == 1
|
||||
assert result.content[0].type == "text"
|
||||
assert "Error calling tool" in result.content[0].text
|
15
tests/components/mcp_server/test_init.py
Normal file
15
tests/components/mcp_server/test_init.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Test the Model Context Protocol Server init module."""
|
||||
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_init(hass: HomeAssistant, config_entry: MockConfigEntry) -> None:
|
||||
"""Test the integration is initialized and can be unloaded cleanly."""
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.LOADED
|
||||
|
||||
await hass.config_entries.async_unload(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.NOT_LOADED
|
Loading…
x
Reference in New Issue
Block a user