mirror of
https://github.com/home-assistant/core.git
synced 2025-11-12 20:40:18 +00:00
Add the Model Context Protocol integration (#135058)
* Add the Model Context Protocol integration * Improvements to mcp integration * Move the API prompt constant * Update config flow error handling * Update test descriptions * Update tests/components/mcp/test_config_flow.py Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com> * Update tests/components/mcp/test_config_flow.py Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com> * Address PR feedback * Update homeassistant/components/mcp/coordinator.py Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> * Move tool parsing to the coordinator * Update session handling not to use a context manager --------- Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com> Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
225
tests/components/mcp/test_init.py
Normal file
225
tests/components/mcp/test_init.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Tests for the Model Context Protocol component."""
|
||||
|
||||
import re
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from .conftest import TEST_API_NAME
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
SEARCH_MEMORY_TOOL = Tool(
|
||||
name="search_memory",
|
||||
description="Search memory for relevant context based on a query.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"required": ["query"],
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "A free text query to search context for.",
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
SAVE_MEMORY_TOOL = Tool(
|
||||
name="save_memory",
|
||||
description="Save a memory context.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"required": ["context"],
|
||||
"properties": {
|
||||
"context": {
|
||||
"type": "object",
|
||||
"description": "The context to save.",
|
||||
"properties": {
|
||||
"fact": {
|
||||
"type": "string",
|
||||
"description": "The key for the context.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def create_llm_context() -> llm.LLMContext:
|
||||
"""Create a test LLM context."""
|
||||
return llm.LLMContext(
|
||||
platform="test_platform",
|
||||
context=Context(),
|
||||
user_prompt="test_text",
|
||||
language="*",
|
||||
assistant="conversation",
|
||||
device_id=None,
|
||||
)
|
||||
|
||||
|
||||
async def test_init(
|
||||
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
|
||||
) -> None:
|
||||
"""Test the integration is initialized and can be unloaded cleanly."""
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.LOADED
|
||||
|
||||
await hass.config_entries.async_unload(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.NOT_LOADED
|
||||
|
||||
|
||||
async def test_mcp_server_failure(
|
||||
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
|
||||
) -> None:
|
||||
"""Test the integration fails to setup if the server fails initialization."""
|
||||
mock_mcp_client.side_effect = httpx.HTTPStatusError(
|
||||
"", request=None, response=httpx.Response(500)
|
||||
)
|
||||
|
||||
with patch("homeassistant.components.mcp.coordinator.TIMEOUT", 1):
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.SETUP_RETRY
|
||||
|
||||
|
||||
async def test_list_tools_failure(
|
||||
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
|
||||
) -> None:
|
||||
"""Test the integration fails to load if the first data fetch returns an error."""
|
||||
mock_mcp_client.return_value.list_tools.side_effect = httpx.HTTPStatusError(
|
||||
"", request=None, response=httpx.Response(500)
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.SETUP_RETRY
|
||||
|
||||
|
||||
async def test_llm_get_api_tools(
|
||||
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
|
||||
) -> None:
|
||||
"""Test MCP tools are returned as LLM API tools."""
|
||||
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
|
||||
tools=[SEARCH_MEMORY_TOOL, SAVE_MEMORY_TOOL],
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.LOADED
|
||||
|
||||
apis = llm.async_get_apis(hass)
|
||||
api = next(iter([api for api in apis if api.name == TEST_API_NAME]))
|
||||
assert api
|
||||
|
||||
api_instance = await api.async_get_api_instance(create_llm_context())
|
||||
assert len(api_instance.tools) == 2
|
||||
tool = api_instance.tools[0]
|
||||
assert tool.name == "search_memory"
|
||||
assert tool.description == "Search memory for relevant context based on a query."
|
||||
with pytest.raises(
|
||||
vol.Invalid, match=re.escape("required key not provided @ data['query']")
|
||||
):
|
||||
tool.parameters({})
|
||||
assert tool.parameters({"query": "frogs"}) == {"query": "frogs"}
|
||||
|
||||
tool = api_instance.tools[1]
|
||||
assert tool.name == "save_memory"
|
||||
assert tool.description == "Save a memory context."
|
||||
with pytest.raises(
|
||||
vol.Invalid, match=re.escape("required key not provided @ data['context']")
|
||||
):
|
||||
tool.parameters({})
|
||||
assert tool.parameters({"context": {"fact": "User was born in February"}}) == {
|
||||
"context": {"fact": "User was born in February"}
|
||||
}
|
||||
|
||||
|
||||
async def test_call_tool(
|
||||
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
|
||||
) -> None:
|
||||
"""Test calling an MCP Tool through the LLM API."""
|
||||
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
|
||||
tools=[SEARCH_MEMORY_TOOL]
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.LOADED
|
||||
|
||||
apis = llm.async_get_apis(hass)
|
||||
api = next(iter([api for api in apis if api.name == TEST_API_NAME]))
|
||||
assert api
|
||||
|
||||
api_instance = await api.async_get_api_instance(create_llm_context())
|
||||
assert len(api_instance.tools) == 1
|
||||
tool = api_instance.tools[0]
|
||||
assert tool.name == "search_memory"
|
||||
|
||||
mock_mcp_client.return_value.call_tool.return_value = CallToolResult(
|
||||
content=[TextContent(type="text", text="User was born in February")]
|
||||
)
|
||||
result = await tool.async_call(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
tool_name="search_memory", tool_args={"query": "User's birth month"}
|
||||
),
|
||||
create_llm_context(),
|
||||
)
|
||||
assert result == {
|
||||
"content": [{"text": "User was born in February", "type": "text"}]
|
||||
}
|
||||
|
||||
|
||||
async def test_call_tool_fails(
|
||||
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
|
||||
) -> None:
|
||||
"""Test handling an MCP Tool call failure."""
|
||||
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
|
||||
tools=[SEARCH_MEMORY_TOOL]
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.LOADED
|
||||
|
||||
apis = llm.async_get_apis(hass)
|
||||
api = next(iter([api for api in apis if api.name == TEST_API_NAME]))
|
||||
assert api
|
||||
|
||||
api_instance = await api.async_get_api_instance(create_llm_context())
|
||||
assert len(api_instance.tools) == 1
|
||||
tool = api_instance.tools[0]
|
||||
assert tool.name == "search_memory"
|
||||
|
||||
mock_mcp_client.return_value.call_tool.side_effect = httpx.HTTPStatusError(
|
||||
"Server error", request=None, response=httpx.Response(500)
|
||||
)
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match="Error when calling tool: Server error"
|
||||
):
|
||||
await tool.async_call(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
tool_name="search_memory", tool_args={"query": "User's birth month"}
|
||||
),
|
||||
create_llm_context(),
|
||||
)
|
||||
|
||||
|
||||
async def test_convert_tool_schema_fails(
|
||||
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
|
||||
) -> None:
|
||||
"""Test a failure converting an MCP tool schema to a Home Assistant schema."""
|
||||
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
|
||||
tools=[SEARCH_MEMORY_TOOL]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.mcp.coordinator.convert_to_voluptuous",
|
||||
side_effect=ValueError,
|
||||
):
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.SETUP_RETRY
|
||||
Reference in New Issue
Block a user