mirror of
https://github.com/home-assistant/core.git
synced 2025-07-20 11:47:06 +00:00
Update MCP server to make the stateless API implicit (#140753)
* Update MCP server to not register the stateless API, but use it implicitly as an Assist API replacement * Ensure backwards compatibility with old registration
This commit is contained in:
parent
c9276aedde
commit
412705302d
@ -6,7 +6,7 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from . import http, llm_api
|
from . import http
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .session import SessionManager
|
from .session import SessionManager
|
||||||
from .types import MCPServerConfigEntry
|
from .types import MCPServerConfigEntry
|
||||||
@ -25,7 +25,6 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
|||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up the Model Context Protocol component."""
|
"""Set up the Model Context Protocol component."""
|
||||||
http.async_register(hass)
|
http.async_register(hass)
|
||||||
llm_api.async_register_api(hass)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ from homeassistant.helpers.selector import (
|
|||||||
SelectSelectorConfig,
|
SelectSelectorConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .const import DOMAIN, LLM_API, LLM_API_NAME
|
from .const import DOMAIN
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -33,13 +33,6 @@ class ModelContextServerProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Handle the initial step."""
|
"""Handle the initial step."""
|
||||||
llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)}
|
llm_apis = {api.id: api.name for api in llm.async_get_apis(self.hass)}
|
||||||
if LLM_API not in llm_apis:
|
|
||||||
# MCP server component is not loaded yet, so make the LLM API a choice.
|
|
||||||
llm_apis = {
|
|
||||||
LLM_API: LLM_API_NAME,
|
|
||||||
**llm_apis,
|
|
||||||
}
|
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=llm_apis[user_input[CONF_LLM_HASS_API]], data=user_input
|
title=llm_apis[user_input[CONF_LLM_HASS_API]], data=user_input
|
||||||
|
@ -2,5 +2,6 @@
|
|||||||
|
|
||||||
DOMAIN = "mcp_server"
|
DOMAIN = "mcp_server"
|
||||||
TITLE = "Model Context Protocol Server"
|
TITLE = "Model Context Protocol Server"
|
||||||
LLM_API = "stateless_assist"
|
# The Stateless API is no longer registered explicitly, but this name may still exist in the
|
||||||
LLM_API_NAME = "Stateless Assist"
|
# users config entry.
|
||||||
|
STATELESS_LLM_API = "stateless_assist"
|
||||||
|
@ -1,19 +1,18 @@
|
|||||||
"""LLM API for MCP Server."""
|
"""LLM API for MCP Server.
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, callback
|
This is a modified version of the AssistAPI that does not include the home state
|
||||||
|
in the prompt. This API is not registered with the LLM API registry since it is
|
||||||
|
only used by the MCP Server. The MCP server will substitute this API when the
|
||||||
|
user selects the Assist API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from homeassistant.core import callback
|
||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import llm
|
||||||
from homeassistant.util import yaml as yaml_util
|
from homeassistant.util import yaml as yaml_util
|
||||||
|
|
||||||
from .const import LLM_API, LLM_API_NAME
|
|
||||||
|
|
||||||
EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}
|
EXPOSED_ENTITY_FIELDS = {"name", "domain", "description", "areas", "names"}
|
||||||
|
|
||||||
|
|
||||||
def async_register_api(hass: HomeAssistant) -> None:
|
|
||||||
"""Register the LLM API."""
|
|
||||||
llm.async_register_api(hass, StatelessAssistAPI(hass))
|
|
||||||
|
|
||||||
|
|
||||||
class StatelessAssistAPI(llm.AssistAPI):
|
class StatelessAssistAPI(llm.AssistAPI):
|
||||||
"""LLM API for MCP Server that provides the Assist API without state information in the prompt.
|
"""LLM API for MCP Server that provides the Assist API without state information in the prompt.
|
||||||
|
|
||||||
@ -22,12 +21,6 @@ class StatelessAssistAPI(llm.AssistAPI):
|
|||||||
actions don't care about the current state, there is little quality loss.
|
actions don't care about the current state, there is little quality loss.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
|
||||||
"""Initialize the StatelessAssistAPI."""
|
|
||||||
super().__init__(hass)
|
|
||||||
self.id = LLM_API
|
|
||||||
self.name = LLM_API_NAME
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_get_exposed_entities_prompt(
|
def _async_get_exposed_entities_prompt(
|
||||||
self, llm_context: llm.LLMContext, exposed_entities: dict | None
|
self, llm_context: llm.LLMContext, exposed_entities: dict | None
|
||||||
|
@ -21,6 +21,9 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import llm
|
||||||
|
|
||||||
|
from .const import STATELESS_LLM_API
|
||||||
|
from .llm_api import StatelessAssistAPI
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -50,13 +53,21 @@ async def create_server(
|
|||||||
|
|
||||||
server = Server("home-assistant")
|
server = Server("home-assistant")
|
||||||
|
|
||||||
|
async def get_api_instance() -> llm.APIInstance:
|
||||||
|
"""Substitute the StatelessAssistAPI for the Assist API if selected."""
|
||||||
|
if llm_api_id in (STATELESS_LLM_API, llm.LLM_API_ASSIST):
|
||||||
|
api = StatelessAssistAPI(hass)
|
||||||
|
return await api.async_get_api_instance(llm_context)
|
||||||
|
|
||||||
|
return await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||||
|
|
||||||
@server.list_prompts() # type: ignore[no-untyped-call, misc]
|
@server.list_prompts() # type: ignore[no-untyped-call, misc]
|
||||||
async def handle_list_prompts() -> list[types.Prompt]:
|
async def handle_list_prompts() -> list[types.Prompt]:
|
||||||
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
llm_api = await get_api_instance()
|
||||||
return [
|
return [
|
||||||
types.Prompt(
|
types.Prompt(
|
||||||
name=llm_api.api.name,
|
name=llm_api.api.name,
|
||||||
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
|
description=f"Default prompt for Home Assistant {llm_api.api.name} API",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -64,12 +75,12 @@ async def create_server(
|
|||||||
async def handle_get_prompt(
|
async def handle_get_prompt(
|
||||||
name: str, arguments: dict[str, str] | None
|
name: str, arguments: dict[str, str] | None
|
||||||
) -> types.GetPromptResult:
|
) -> types.GetPromptResult:
|
||||||
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
llm_api = await get_api_instance()
|
||||||
if name != llm_api.api.name:
|
if name != llm_api.api.name:
|
||||||
raise ValueError(f"Unknown prompt: {name}")
|
raise ValueError(f"Unknown prompt: {name}")
|
||||||
|
|
||||||
return types.GetPromptResult(
|
return types.GetPromptResult(
|
||||||
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
|
description=f"Default prompt for Home Assistant {llm_api.api.name} API",
|
||||||
messages=[
|
messages=[
|
||||||
types.PromptMessage(
|
types.PromptMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
@ -84,13 +95,13 @@ async def create_server(
|
|||||||
@server.list_tools() # type: ignore[no-untyped-call, misc]
|
@server.list_tools() # type: ignore[no-untyped-call, misc]
|
||||||
async def list_tools() -> list[types.Tool]:
|
async def list_tools() -> list[types.Tool]:
|
||||||
"""List available time tools."""
|
"""List available time tools."""
|
||||||
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
llm_api = await get_api_instance()
|
||||||
return [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
|
return [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
|
||||||
|
|
||||||
@server.call_tool() # type: ignore[no-untyped-call, misc]
|
@server.call_tool() # type: ignore[no-untyped-call, misc]
|
||||||
async def call_tool(name: str, arguments: dict) -> Sequence[types.TextContent]:
|
async def call_tool(name: str, arguments: dict) -> Sequence[types.TextContent]:
|
||||||
"""Handle calling tools."""
|
"""Handle calling tools."""
|
||||||
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
llm_api = await get_api_instance()
|
||||||
tool_input = llm.ToolInput(tool_name=name, tool_args=arguments)
|
tool_input = llm.ToolInput(tool_name=name, tool_args=arguments)
|
||||||
_LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
|
_LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
|
||||||
|
|
||||||
|
@ -5,9 +5,10 @@ from unittest.mock import AsyncMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.mcp_server.const import DOMAIN, LLM_API
|
from homeassistant.components.mcp_server.const import DOMAIN
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import llm
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
@ -21,13 +22,19 @@ def mock_setup_entry() -> Generator[AsyncMock]:
|
|||||||
yield mock_setup_entry
|
yield mock_setup_entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="llm_hass_api")
|
||||||
|
def llm_hass_api_fixture() -> str:
|
||||||
|
"""Fixture for the config entry llm_hass_api."""
|
||||||
|
return llm.LLM_API_ASSIST
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="config_entry")
|
@pytest.fixture(name="config_entry")
|
||||||
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
def mock_config_entry(hass: HomeAssistant, llm_hass_api: str) -> MockConfigEntry:
|
||||||
"""Fixture to load the integration."""
|
"""Fixture to load the integration."""
|
||||||
config_entry = MockConfigEntry(
|
config_entry = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
data={
|
data={
|
||||||
CONF_LLM_HASS_API: LLM_API,
|
CONF_LLM_HASS_API: llm_hass_api,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
config_entry.add_to_hass(hass)
|
config_entry.add_to_hass(hass)
|
||||||
|
@ -16,6 +16,7 @@ import pytest
|
|||||||
from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN
|
from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||||
|
from homeassistant.components.mcp_server.const import STATELESS_LLM_API
|
||||||
from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API
|
from homeassistant.components.mcp_server.http import MESSAGES_API, SSE_API
|
||||||
from homeassistant.config_entries import ConfigEntryState
|
from homeassistant.config_entries import ConfigEntryState
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON
|
from homeassistant.const import CONF_LLM_HASS_API, STATE_OFF, STATE_ON
|
||||||
@ -24,6 +25,7 @@ from homeassistant.helpers import (
|
|||||||
area_registry as ar,
|
area_registry as ar,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
entity_registry as er,
|
entity_registry as er,
|
||||||
|
llm,
|
||||||
)
|
)
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
@ -297,6 +299,7 @@ async def mcp_session(
|
|||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
|
||||||
async def test_mcp_tools_list(
|
async def test_mcp_tools_list(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
setup_integration: None,
|
setup_integration: None,
|
||||||
@ -319,6 +322,7 @@ async def test_mcp_tools_list(
|
|||||||
assert properties.get("name") == {"type": "string"}
|
assert properties.get("name") == {"type": "string"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
|
||||||
async def test_mcp_tool_call(
|
async def test_mcp_tool_call(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
setup_integration: None,
|
setup_integration: None,
|
||||||
@ -371,6 +375,7 @@ async def test_mcp_tool_call_failed(
|
|||||||
assert "Error calling tool" in result.content[0].text
|
assert "Error calling tool" in result.content[0].text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
|
||||||
async def test_prompt_list(
|
async def test_prompt_list(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
setup_integration: None,
|
setup_integration: None,
|
||||||
@ -384,13 +389,11 @@ async def test_prompt_list(
|
|||||||
|
|
||||||
assert len(result.prompts) == 1
|
assert len(result.prompts) == 1
|
||||||
prompt = result.prompts[0]
|
prompt = result.prompts[0]
|
||||||
assert prompt.name == "Stateless Assist"
|
assert prompt.name == "Assist"
|
||||||
assert (
|
assert prompt.description == "Default prompt for Home Assistant Assist API"
|
||||||
prompt.description
|
|
||||||
== "Default prompt for the Home Assistant LLM API Stateless Assist"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("llm_hass_api", [llm.LLM_API_ASSIST, STATELESS_LLM_API])
|
||||||
async def test_prompt_get(
|
async def test_prompt_get(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
setup_integration: None,
|
setup_integration: None,
|
||||||
@ -400,12 +403,9 @@ async def test_prompt_get(
|
|||||||
"""Test the get prompt endpoint."""
|
"""Test the get prompt endpoint."""
|
||||||
|
|
||||||
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||||
result = await session.get_prompt(name="Stateless Assist")
|
result = await session.get_prompt(name="Assist")
|
||||||
|
|
||||||
assert (
|
assert result.description == "Default prompt for Home Assistant Assist API"
|
||||||
result.description
|
|
||||||
== "Default prompt for the Home Assistant LLM API Stateless Assist"
|
|
||||||
)
|
|
||||||
assert len(result.messages) == 1
|
assert len(result.messages) == 1
|
||||||
assert result.messages[0].role == "assistant"
|
assert result.messages[0].role == "assistant"
|
||||||
assert result.messages[0].content.type == "text"
|
assert result.messages[0].content.type == "text"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user