diff --git a/.strict-typing b/.strict-typing index 1c0456a745d..62da6c5ca92 100644 --- a/.strict-typing +++ b/.strict-typing @@ -316,6 +316,7 @@ homeassistant.components.manual.* homeassistant.components.mastodon.* homeassistant.components.matrix.* homeassistant.components.matter.* +homeassistant.components.mcp.* homeassistant.components.mcp_server.* homeassistant.components.mealie.* homeassistant.components.media_extractor.* diff --git a/CODEOWNERS b/CODEOWNERS index f16b890d407..faded2af138 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -891,6 +891,8 @@ build.json @home-assistant/supervisor /tests/components/matrix/ @PaarthShah /homeassistant/components/matter/ @home-assistant/matter /tests/components/matter/ @home-assistant/matter +/homeassistant/components/mcp/ @allenporter +/tests/components/mcp/ @allenporter /homeassistant/components/mcp_server/ @allenporter /tests/components/mcp_server/ @allenporter /homeassistant/components/mealie/ @joostlek @andrew-codechimp diff --git a/homeassistant/components/mcp/__init__.py b/homeassistant/components/mcp/__init__.py new file mode 100644 index 00000000000..4a2b4da990d --- /dev/null +++ b/homeassistant/components/mcp/__init__.py @@ -0,0 +1,69 @@ +"""The Model Context Protocol integration.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from homeassistant.core import HomeAssistant +from homeassistant.helpers import llm + +from .const import DOMAIN +from .coordinator import ModelContextProtocolCoordinator +from .types import ModelContextProtocolConfigEntry + +__all__ = [ + "DOMAIN", + "async_setup_entry", + "async_unload_entry", +] + +API_PROMPT = "The following tools are available from a remote server named {name}." + + +async def async_setup_entry( + hass: HomeAssistant, entry: ModelContextProtocolConfigEntry +) -> bool: + """Set up Model Context Protocol from a config entry.""" + coordinator = ModelContextProtocolCoordinator(hass, entry) + await coordinator.async_config_entry_first_refresh() + + unsub = llm.async_register_api( + hass, + ModelContextProtocolAPI( + hass=hass, + id=f"{DOMAIN}-{entry.entry_id}", + name=entry.title, + coordinator=coordinator, + ), + ) + entry.async_on_unload(unsub) + + entry.runtime_data = coordinator + entry.async_on_unload(coordinator.close) + + return True + + +async def async_unload_entry( + hass: HomeAssistant, entry: ModelContextProtocolConfigEntry +) -> bool: + """Unload a config entry.""" + return True + + +@dataclass(kw_only=True) +class ModelContextProtocolAPI(llm.API): + """Define an object to hold the Model Context Protocol API.""" + + coordinator: ModelContextProtocolCoordinator + + async def async_get_api_instance( + self, llm_context: llm.LLMContext + ) -> llm.APIInstance: + """Return the instance of the API.""" + return llm.APIInstance( + self, + API_PROMPT.format(name=self.name), + llm_context, + tools=self.coordinator.data, + ) diff --git a/homeassistant/components/mcp/config_flow.py b/homeassistant/components/mcp/config_flow.py new file mode 100644 index 00000000000..92e0052c665 --- /dev/null +++ b/homeassistant/components/mcp/config_flow.py @@ -0,0 +1,111 @@ +"""Config flow for the Model Context Protocol integration.""" + +from __future__ import annotations + +import logging +from typing import Any + +import httpx +import voluptuous as vol + +from homeassistant.config_entries import ConfigFlow, ConfigFlowResult +from homeassistant.const import CONF_URL +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import config_validation as cv + +from .const import DOMAIN +from .coordinator import mcp_client + +_LOGGER = logging.getLogger(__name__) + +STEP_USER_DATA_SCHEMA = vol.Schema( + { + vol.Required(CONF_URL): str, + } +) + + +async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str, Any]: + """Validate the user input and connect to the MCP server.""" + url = data[CONF_URL] + try: + cv.url(url) # Cannot be added to schema directly + except vol.Invalid as error: + raise InvalidUrl from error + try: + async with mcp_client(url) as session: + response = await session.initialize() + except httpx.TimeoutException as error: + _LOGGER.info("Timeout connecting to MCP server: %s", error) + raise TimeoutConnectError from error + except httpx.HTTPStatusError as error: + _LOGGER.info("Cannot connect to MCP server: %s", error) + if error.response.status_code == 401: + raise InvalidAuth from error + raise CannotConnect from error + except httpx.HTTPError as error: + _LOGGER.info("Cannot connect to MCP server: %s", error) + raise CannotConnect from error + + if not response.capabilities.tools: + raise MissingCapabilities( + f"MCP Server {url} does not support 'Tools' capability" + ) + + return {"title": response.serverInfo.name} + + +class ModelContextProtocolConfigFlow(ConfigFlow, domain=DOMAIN): + """Handle a config flow for Model Context Protocol.""" + + VERSION = 1 + + async def async_step_user( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Handle the initial step.""" + errors: dict[str, str] = {} + if user_input is not None: + try: + info = await validate_input(self.hass, user_input) + except InvalidUrl: + errors[CONF_URL] = "invalid_url" + except TimeoutConnectError: + errors["base"] = "timeout_connect" + except CannotConnect: + errors["base"] = "cannot_connect" + except InvalidAuth: + return self.async_abort(reason="invalid_auth") + except MissingCapabilities: + return self.async_abort(reason="missing_capabilities") + except Exception: + _LOGGER.exception("Unexpected exception") + errors["base"] = "unknown" + else: + self._async_abort_entries_match({CONF_URL: user_input[CONF_URL]}) + return self.async_create_entry(title=info["title"], data=user_input) + + return self.async_show_form( + step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors + ) + + +class InvalidUrl(HomeAssistantError): + """Error to indicate the URL format is invalid.""" + + +class CannotConnect(HomeAssistantError): + """Error to indicate we cannot connect.""" + + +class TimeoutConnectError(HomeAssistantError): + """Error to indicate we cannot connect.""" + + +class InvalidAuth(HomeAssistantError): + """Error to indicate there is invalid auth.""" + + +class MissingCapabilities(HomeAssistantError): + """Error to indicate that the MCP server is missing required capabilities.""" diff --git a/homeassistant/components/mcp/const.py b/homeassistant/components/mcp/const.py new file mode 100644 index 00000000000..675b2d7031c --- /dev/null +++ b/homeassistant/components/mcp/const.py @@ -0,0 +1,3 @@ +"""Constants for the Model Context Protocol integration.""" + +DOMAIN = "mcp" diff --git a/homeassistant/components/mcp/coordinator.py b/homeassistant/components/mcp/coordinator.py new file mode 100644 index 00000000000..a5c5ee55dbf --- /dev/null +++ b/homeassistant/components/mcp/coordinator.py @@ -0,0 +1,171 @@ +"""Types for the Model Context Protocol integration.""" + +import asyncio +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +import datetime +import logging + +import httpx +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client +import voluptuous as vol +from voluptuous_openapi import convert_to_voluptuous + +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_URL +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import llm +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed +from homeassistant.util.json import JsonObjectType + +from .const import DOMAIN + +_LOGGER = logging.getLogger(__name__) + +UPDATE_INTERVAL = datetime.timedelta(minutes=30) +TIMEOUT = 10 + + +@asynccontextmanager +async def mcp_client(url: str) -> AsyncGenerator[ClientSession]: + """Create a server-sent event MCP client. + + This is an asynccontext manager that exists to wrap other async context managers + so that the coordinator has a single object to manage. + """ + try: + async with sse_client(url=url) as streams, ClientSession(*streams) as session: + await session.initialize() + yield session + except ExceptionGroup as err: + raise err.exceptions[0] from err + + +class ModelContextProtocolTool(llm.Tool): + """A Tool exposed over the Model Context Protocol.""" + + def __init__( + self, + name: str, + description: str | None, + parameters: vol.Schema, + session: ClientSession, + ) -> None: + """Initialize the tool.""" + self.name = name + self.description = description + self.parameters = parameters + self.session = session + + async def async_call( + self, + hass: HomeAssistant, + tool_input: llm.ToolInput, + llm_context: llm.LLMContext, + ) -> JsonObjectType: + """Call the tool.""" + try: + result = await self.session.call_tool( + tool_input.tool_name, tool_input.tool_args + ) + except httpx.HTTPStatusError as error: + raise HomeAssistantError(f"Error when calling tool: {error}") from error + return result.model_dump(exclude_unset=True, exclude_none=True) + + +class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]): + """Define an object to hold MCP data.""" + + config_entry: ConfigEntry + _session: ClientSession | None = None + _setup_error: Exception | None = None + + def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None: + """Initialize ModelContextProtocolCoordinator.""" + super().__init__( + hass, + logger=_LOGGER, + name=DOMAIN, + config_entry=config_entry, + update_interval=UPDATE_INTERVAL, + ) + self._stop = asyncio.Event() + + async def _async_setup(self) -> None: + """Set up the client connection.""" + connected = asyncio.Event() + stop = asyncio.Event() + self.config_entry.async_create_background_task( + self.hass, self._connect(connected, stop), "mcp-client" + ) + try: + async with asyncio.timeout(TIMEOUT): + await connected.wait() + self._stop = stop + finally: + if self._setup_error is not None: + raise self._setup_error + + async def _connect(self, connected: asyncio.Event, stop: asyncio.Event) -> None: + """Create a server-sent event MCP client.""" + url = self.config_entry.data[CONF_URL] + try: + async with ( + sse_client(url=url) as streams, + ClientSession(*streams) as session, + ): + await session.initialize() + self._session = session + connected.set() + await stop.wait() + except httpx.HTTPStatusError as err: + self._setup_error = err + _LOGGER.debug("Error connecting to MCP server: %s", err) + raise UpdateFailed(f"Error connecting to MCP server: {err}") from err + except ExceptionGroup as err: + self._setup_error = err.exceptions[0] + _LOGGER.debug("Error connecting to MCP server: %s", err) + raise UpdateFailed( + "Error connecting to MCP server: {err.exceptions[0]}" + ) from err.exceptions[0] + finally: + self._session = None + + async def close(self) -> None: + """Close the client connection.""" + if self._stop is not None: + self._stop.set() + + async def _async_update_data(self) -> list[llm.Tool]: + """Fetch data from API endpoint. + + This is the place to pre-process the data to lookup tables + so entities can quickly look up their data. + """ + if self._session is None: + raise UpdateFailed("No session available") + try: + result = await self._session.list_tools() + except httpx.HTTPError as err: + raise UpdateFailed(f"Error communicating with API: {err}") from err + + _LOGGER.debug("Received tools: %s", result.tools) + tools: list[llm.Tool] = [] + for tool in result.tools: + try: + parameters = convert_to_voluptuous(tool.inputSchema) + except Exception as err: + raise UpdateFailed( + f"Error converting schema {err}: {tool.inputSchema}" + ) from err + tools.append( + ModelContextProtocolTool( + tool.name, + tool.description, + parameters, + self._session, + ) + ) + return tools diff --git a/homeassistant/components/mcp/manifest.json b/homeassistant/components/mcp/manifest.json new file mode 100644 index 00000000000..ee4baf04802 --- /dev/null +++ b/homeassistant/components/mcp/manifest.json @@ -0,0 +1,10 @@ +{ + "domain": "mcp", + "name": "Model Context Protocol", + "codeowners": ["@allenporter"], + "config_flow": true, + "documentation": "https://www.home-assistant.io/integrations/mcp", + "iot_class": "local_polling", + "quality_scale": "silver", + "requirements": ["mcp==1.1.2"] +} diff --git a/homeassistant/components/mcp/quality_scale.yaml b/homeassistant/components/mcp/quality_scale.yaml new file mode 100644 index 00000000000..76afdf5860d --- /dev/null +++ b/homeassistant/components/mcp/quality_scale.yaml @@ -0,0 +1,88 @@ +rules: + # Bronze + action-setup: + status: exempt + comment: Integration does not have actions. + appropriate-polling: done + brands: done + common-modules: done + config-flow-test-coverage: done + config-flow: done + dependency-transparency: done + docs-actions: + status: exempt + comment: Integration does not have actions. + docs-high-level-description: done + docs-installation-instructions: done + docs-removal-instructions: done + entity-event-setup: + status: exempt + comment: Integration does not have entities. + entity-unique-id: + status: exempt + comment: Integration does not have entities. + has-entity-name: + status: exempt + comment: Integration does not have entities. + runtime-data: done + test-before-configure: done + test-before-setup: done + unique-config-entry: done + + # Silver + action-exceptions: + status: exempt + comment: Integration does not have actions. + config-entry-unloading: done + docs-configuration-parameters: done + docs-installation-parameters: done + entity-unavailable: + status: exempt + comment: Integration does not have entities. + integration-owner: done + log-when-unavailable: done + parallel-updates: + status: exempt + comment: Integration does not have platforms. + reauthentication-flow: + status: exempt + comment: Integration does not support authentication. + test-coverage: done + + # Gold + devices: + status: exempt + comment: Integration does not have devices. + diagnostics: todo + discovery-update-info: todo + discovery: todo + docs-data-update: done + docs-examples: done + docs-known-limitations: done + docs-supported-devices: done + docs-supported-functions: done + docs-troubleshooting: done + docs-use-cases: done + dynamic-devices: todo + entity-category: + status: exempt + comment: Integration does not have entities. + entity-device-class: + status: exempt + comment: Integration does not have entities. + entity-disabled-by-default: + status: exempt + comment: Integration does not have entities. + entity-translations: + status: exempt + comment: Integration does not have entities. + exception-translations: todo + icon-translations: todo + reconfiguration-flow: todo + repair-issues: todo + stale-devices: todo + + # Platinum + async-dependency: done + inject-websession: todo + strict-typing: done diff --git a/homeassistant/components/mcp/strings.json b/homeassistant/components/mcp/strings.json new file mode 100644 index 00000000000..97a75fc6f85 --- /dev/null +++ b/homeassistant/components/mcp/strings.json @@ -0,0 +1,25 @@ +{ + "config": { + "step": { + "user": { + "data": { + "url": "[%key:common::config_flow::data::url%]" + }, + "data_description": { + "url": "The remote MCP server URL for the SSE endpoint, for example http://example/sse" + } + } + }, + "error": { + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", + "unknown": "[%key:common::config_flow::error::unknown%]", + "timeout_connect": "[%key:common::config_flow::error::timeout_connect%]", + "invalid_url": "Must be a valid MCP server URL e.g. https://example.com/sse" + }, + "abort": { + "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", + "missing_capabilities": "The MCP server does not support a required capability (Tools)", + "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" + } + } +} diff --git a/homeassistant/components/mcp/types.py b/homeassistant/components/mcp/types.py new file mode 100644 index 00000000000..961c9ab3d18 --- /dev/null +++ b/homeassistant/components/mcp/types.py @@ -0,0 +1,7 @@ +"""Types for the Model Context Protocol integration.""" + +from homeassistant.config_entries import ConfigEntry + +from .coordinator import ModelContextProtocolCoordinator + +type ModelContextProtocolConfigEntry = ConfigEntry[ModelContextProtocolCoordinator] diff --git a/homeassistant/generated/config_flows.py b/homeassistant/generated/config_flows.py index b393e5c8851..7dea4598790 100644 --- a/homeassistant/generated/config_flows.py +++ b/homeassistant/generated/config_flows.py @@ -358,6 +358,7 @@ FLOWS = { "mailgun", "mastodon", "matter", + "mcp", "mcp_server", "mealie", "meater", diff --git a/homeassistant/generated/integrations.json b/homeassistant/generated/integrations.json index 9a7167f5367..6d2e784c583 100644 --- a/homeassistant/generated/integrations.json +++ b/homeassistant/generated/integrations.json @@ -3607,6 +3607,12 @@ "config_flow": true, "iot_class": "local_push" }, + "mcp": { + "name": "Model Context Protocol", + "integration_type": "hub", + "config_flow": true, + "iot_class": "local_polling" + }, "mcp_server": { "name": "Model Context Protocol Server", "integration_type": "service", diff --git a/mypy.ini b/mypy.ini index 7f7b66e238f..188f1f7bbd7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2916,6 +2916,16 @@ disallow_untyped_defs = true warn_return_any = true warn_unreachable = true +[mypy-homeassistant.components.mcp.*] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +warn_return_any = true +warn_unreachable = true + [mypy-homeassistant.components.mcp_server.*] check_untyped_defs = true disallow_incomplete_defs = true diff --git a/requirements_all.txt b/requirements_all.txt index 80890e8b612..87580b45ca9 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1364,6 +1364,7 @@ maxcube-api==0.4.3 # homeassistant.components.mythicbeastsdns mbddns==0.1.2 +# homeassistant.components.mcp # homeassistant.components.mcp_server mcp==1.1.2 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index a3bc80b736b..2894749732e 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1142,6 +1142,7 @@ maxcube-api==0.4.3 # homeassistant.components.mythicbeastsdns mbddns==0.1.2 +# homeassistant.components.mcp # homeassistant.components.mcp_server mcp==1.1.2 diff --git a/tests/components/mcp/__init__.py b/tests/components/mcp/__init__.py new file mode 100644 index 00000000000..e8e8635ab36 --- /dev/null +++ b/tests/components/mcp/__init__.py @@ -0,0 +1 @@ +"""Tests for the Model Context Protocol integration.""" diff --git a/tests/components/mcp/conftest.py b/tests/components/mcp/conftest.py new file mode 100644 index 00000000000..d86603a12ed --- /dev/null +++ b/tests/components/mcp/conftest.py @@ -0,0 +1,45 @@ +"""Common fixtures for the Model Context Protocol tests.""" + +from collections.abc import Generator +from unittest.mock import AsyncMock, patch + +import pytest + +from homeassistant.components.mcp.const import DOMAIN +from homeassistant.const import CONF_URL +from homeassistant.core import HomeAssistant + +from tests.common import MockConfigEntry + +TEST_API_NAME = "Memory Server" + + +@pytest.fixture +def mock_setup_entry() -> Generator[AsyncMock]: + """Override async_setup_entry.""" + with patch( + "homeassistant.components.mcp.async_setup_entry", return_value=True + ) as mock_setup_entry: + yield mock_setup_entry + + +@pytest.fixture +def mock_mcp_client() -> Generator[AsyncMock]: + """Fixture to mock the MCP client.""" + with ( + patch("homeassistant.components.mcp.coordinator.sse_client"), + patch("homeassistant.components.mcp.coordinator.ClientSession") as mock_session, + ): + yield mock_session.return_value.__aenter__ + + +@pytest.fixture(name="config_entry") +def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: + """Fixture to load the integration.""" + config_entry = MockConfigEntry( + domain=DOMAIN, + data={CONF_URL: "http://1.1.1.1/sse"}, + title=TEST_API_NAME, + ) + config_entry.add_to_hass(hass) + return config_entry diff --git a/tests/components/mcp/test_config_flow.py b/tests/components/mcp/test_config_flow.py new file mode 100644 index 00000000000..29733e653a6 --- /dev/null +++ b/tests/components/mcp/test_config_flow.py @@ -0,0 +1,234 @@ +"""Test the Model Context Protocol config flow.""" + +from typing import Any +from unittest.mock import AsyncMock, Mock + +import httpx +import pytest + +from homeassistant import config_entries +from homeassistant.components.mcp.const import DOMAIN +from homeassistant.const import CONF_URL +from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import FlowResultType + +from .conftest import TEST_API_NAME + +from tests.common import MockConfigEntry + + +async def test_form( + hass: HomeAssistant, mock_setup_entry: AsyncMock, mock_mcp_client: Mock +) -> None: + """Test the complete configuration flow.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {} + + response = Mock() + response.serverInfo.name = TEST_API_NAME + mock_mcp_client.return_value.initialize.return_value = response + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_URL: "http://1.1.1.1/sse", + }, + ) + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["title"] == TEST_API_NAME + assert result["data"] == { + CONF_URL: "http://1.1.1.1/sse", + } + assert len(mock_setup_entry.mock_calls) == 1 + + +@pytest.mark.parametrize( + ("side_effect", "expected_error"), + [ + (httpx.TimeoutException("Some timeout"), "timeout_connect"), + ( + httpx.HTTPStatusError("", request=None, response=httpx.Response(500)), + "cannot_connect", + ), + (httpx.HTTPError("Some HTTP error"), "cannot_connect"), + (Exception, "unknown"), + ], +) +async def test_form_mcp_client_error( + hass: HomeAssistant, + mock_setup_entry: AsyncMock, + mock_mcp_client: Mock, + side_effect: Exception, + expected_error: str, +) -> None: + """Test we handle different client library errors.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + mock_mcp_client.side_effect = side_effect + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_URL: "http://1.1.1.1/sse", + }, + ) + + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {"base": expected_error} + + # Reset the error and make sure the config flow can resume successfully. + mock_mcp_client.side_effect = None + response = Mock() + response.serverInfo.name = TEST_API_NAME + mock_mcp_client.return_value.initialize.return_value = response + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_URL: "http://1.1.1.1/sse", + }, + ) + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["title"] == TEST_API_NAME + assert result["data"] == { + CONF_URL: "http://1.1.1.1/sse", + } + assert len(mock_setup_entry.mock_calls) == 1 + + +@pytest.mark.parametrize( + ("side_effect", "expected_error"), + [ + ( + httpx.HTTPStatusError("", request=None, response=httpx.Response(401)), + "invalid_auth", + ), + ], +) +async def test_form_mcp_client_error_abort( + hass: HomeAssistant, + mock_setup_entry: AsyncMock, + mock_mcp_client: Mock, + side_effect: Exception, + expected_error: str, +) -> None: + """Test we handle different client library errors that end with an abort.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + mock_mcp_client.side_effect = side_effect + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_URL: "http://1.1.1.1/sse", + }, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == expected_error + + +@pytest.mark.parametrize( + "user_input", + [ + ({CONF_URL: "not a url"}), + ({CONF_URL: "rtsp://1.1.1.1"}), + ], +) +async def test_input_form_validation_error( + hass: HomeAssistant, + mock_setup_entry: AsyncMock, + mock_mcp_client: Mock, + user_input: dict[str, Any], +) -> None: + """Test we handle invalid auth.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input, + ) + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {CONF_URL: "invalid_url"} + + # Reset the error and make sure the config flow can resume successfully. + response = Mock() + response.serverInfo.name = TEST_API_NAME + mock_mcp_client.return_value.initialize.return_value = response + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_URL: "http://1.1.1.1/sse", + }, + ) + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["title"] == TEST_API_NAME + assert result["data"] == { + CONF_URL: "http://1.1.1.1/sse", + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_unique_url( + hass: HomeAssistant, mock_setup_entry: AsyncMock, mock_mcp_client: Mock +) -> None: + """Test that the same url cannot be configured twice.""" + config_entry = MockConfigEntry( + domain=DOMAIN, + data={CONF_URL: "http://1.1.1.1/sse"}, + title=TEST_API_NAME, + ) + config_entry.add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {} + + response = Mock() + response.serverInfo.name = TEST_API_NAME + mock_mcp_client.return_value.initialize.return_value = response + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_URL: "http://1.1.1.1/sse", + }, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "already_configured" + + +async def test_server_missing_capbilities( + hass: HomeAssistant, + mock_setup_entry: AsyncMock, + mock_mcp_client: Mock, +) -> None: + """Test we handle different client library errors.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + response = Mock() + response.serverInfo.name = TEST_API_NAME + response.capabilities.tools = None + mock_mcp_client.return_value.initialize.return_value = response + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_URL: "http://1.1.1.1/sse", + }, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "missing_capabilities" diff --git a/tests/components/mcp/test_init.py b/tests/components/mcp/test_init.py new file mode 100644 index 00000000000..460df2c5785 --- /dev/null +++ b/tests/components/mcp/test_init.py @@ -0,0 +1,225 @@ +"""Tests for the Model Context Protocol component.""" + +import re +from unittest.mock import Mock, patch + +import httpx +from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool +import pytest +import voluptuous as vol + +from homeassistant.config_entries import ConfigEntryState +from homeassistant.core import Context, HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import llm + +from .conftest import TEST_API_NAME + +from tests.common import MockConfigEntry + +SEARCH_MEMORY_TOOL = Tool( + name="search_memory", + description="Search memory for relevant context based on a query.", + inputSchema={ + "type": "object", + "required": ["query"], + "properties": { + "query": { + "type": "string", + "description": "A free text query to search context for.", + } + }, + }, +) +SAVE_MEMORY_TOOL = Tool( + name="save_memory", + description="Save a memory context.", + inputSchema={ + "type": "object", + "required": ["context"], + "properties": { + "context": { + "type": "object", + "description": "The context to save.", + "properties": { + "fact": { + "type": "string", + "description": "The key for the context.", + }, + }, + }, + }, + }, +) + + +def create_llm_context() -> llm.LLMContext: + """Create a test LLM context.""" + return llm.LLMContext( + platform="test_platform", + context=Context(), + user_prompt="test_text", + language="*", + assistant="conversation", + device_id=None, + ) + + +async def test_init( + hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock +) -> None: + """Test the integration is initialized and can be unloaded cleanly.""" + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.LOADED + + await hass.config_entries.async_unload(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.NOT_LOADED + + +async def test_mcp_server_failure( + hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock +) -> None: + """Test the integration fails to setup if the server fails initialization.""" + mock_mcp_client.side_effect = httpx.HTTPStatusError( + "", request=None, response=httpx.Response(500) + ) + + with patch("homeassistant.components.mcp.coordinator.TIMEOUT", 1): + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.SETUP_RETRY + + +async def test_list_tools_failure( + hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock +) -> None: + """Test the integration fails to load if the first data fetch returns an error.""" + mock_mcp_client.return_value.list_tools.side_effect = httpx.HTTPStatusError( + "", request=None, response=httpx.Response(500) + ) + + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.SETUP_RETRY + + +async def test_llm_get_api_tools( + hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock +) -> None: + """Test MCP tools are returned as LLM API tools.""" + mock_mcp_client.return_value.list_tools.return_value = ListToolsResult( + tools=[SEARCH_MEMORY_TOOL, SAVE_MEMORY_TOOL], + ) + + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.LOADED + + apis = llm.async_get_apis(hass) + api = next(iter([api for api in apis if api.name == TEST_API_NAME])) + assert api + + api_instance = await api.async_get_api_instance(create_llm_context()) + assert len(api_instance.tools) == 2 + tool = api_instance.tools[0] + assert tool.name == "search_memory" + assert tool.description == "Search memory for relevant context based on a query." + with pytest.raises( + vol.Invalid, match=re.escape("required key not provided @ data['query']") + ): + tool.parameters({}) + assert tool.parameters({"query": "frogs"}) == {"query": "frogs"} + + tool = api_instance.tools[1] + assert tool.name == "save_memory" + assert tool.description == "Save a memory context." + with pytest.raises( + vol.Invalid, match=re.escape("required key not provided @ data['context']") + ): + tool.parameters({}) + assert tool.parameters({"context": {"fact": "User was born in February"}}) == { + "context": {"fact": "User was born in February"} + } + + +async def test_call_tool( + hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock +) -> None: + """Test calling an MCP Tool through the LLM API.""" + mock_mcp_client.return_value.list_tools.return_value = ListToolsResult( + tools=[SEARCH_MEMORY_TOOL] + ) + + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.LOADED + + apis = llm.async_get_apis(hass) + api = next(iter([api for api in apis if api.name == TEST_API_NAME])) + assert api + + api_instance = await api.async_get_api_instance(create_llm_context()) + assert len(api_instance.tools) == 1 + tool = api_instance.tools[0] + assert tool.name == "search_memory" + + mock_mcp_client.return_value.call_tool.return_value = CallToolResult( + content=[TextContent(type="text", text="User was born in February")] + ) + result = await tool.async_call( + hass, + llm.ToolInput( + tool_name="search_memory", tool_args={"query": "User's birth month"} + ), + create_llm_context(), + ) + assert result == { + "content": [{"text": "User was born in February", "type": "text"}] + } + + +async def test_call_tool_fails( + hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock +) -> None: + """Test handling an MCP Tool call failure.""" + mock_mcp_client.return_value.list_tools.return_value = ListToolsResult( + tools=[SEARCH_MEMORY_TOOL] + ) + + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.LOADED + + apis = llm.async_get_apis(hass) + api = next(iter([api for api in apis if api.name == TEST_API_NAME])) + assert api + + api_instance = await api.async_get_api_instance(create_llm_context()) + assert len(api_instance.tools) == 1 + tool = api_instance.tools[0] + assert tool.name == "search_memory" + + mock_mcp_client.return_value.call_tool.side_effect = httpx.HTTPStatusError( + "Server error", request=None, response=httpx.Response(500) + ) + with pytest.raises( + HomeAssistantError, match="Error when calling tool: Server error" + ): + await tool.async_call( + hass, + llm.ToolInput( + tool_name="search_memory", tool_args={"query": "User's birth month"} + ), + create_llm_context(), + ) + + +async def test_convert_tool_schema_fails( + hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock +) -> None: + """Test a failure converting an MCP tool schema to a Home Assistant schema.""" + mock_mcp_client.return_value.list_tools.return_value = ListToolsResult( + tools=[SEARCH_MEMORY_TOOL] + ) + + with patch( + "homeassistant.components.mcp.coordinator.convert_to_voluptuous", + side_effect=ValueError, + ): + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.SETUP_RETRY