mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 19:27:45 +00:00
Add prompts to MCP server (#134619)
* Add prompts to MCP server * Improve test coverage for get prompt error cases
This commit is contained in:
parent
c9a607aa45
commit
bb97a16756
@ -50,6 +50,37 @@ async def create_server(
|
|||||||
|
|
||||||
server = Server("home-assistant")
|
server = Server("home-assistant")
|
||||||
|
|
||||||
|
@server.list_prompts() # type: ignore[no-untyped-call, misc]
|
||||||
|
async def handle_list_prompts() -> list[types.Prompt]:
|
||||||
|
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||||
|
return [
|
||||||
|
types.Prompt(
|
||||||
|
name=llm_api.api.name,
|
||||||
|
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@server.get_prompt() # type: ignore[no-untyped-call, misc]
|
||||||
|
async def handle_get_prompt(
|
||||||
|
name: str, arguments: dict[str, str] | None
|
||||||
|
) -> types.GetPromptResult:
|
||||||
|
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
|
||||||
|
if name != llm_api.api.name:
|
||||||
|
raise ValueError(f"Unknown prompt: {name}")
|
||||||
|
|
||||||
|
return types.GetPromptResult(
|
||||||
|
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
|
||||||
|
messages=[
|
||||||
|
types.PromptMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=llm_api.api_prompt,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@server.list_tools() # type: ignore[no-untyped-call, misc]
|
@server.list_tools() # type: ignore[no-untyped-call, misc]
|
||||||
async def list_tools() -> list[types.Tool]:
|
async def list_tools() -> list[types.Tool]:
|
||||||
"""List available time tools."""
|
"""List available time tools."""
|
||||||
|
@ -10,6 +10,7 @@ import aiohttp
|
|||||||
import mcp
|
import mcp
|
||||||
import mcp.client.session
|
import mcp.client.session
|
||||||
import mcp.client.sse
|
import mcp.client.sse
|
||||||
|
from mcp.shared.exceptions import McpError
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN
|
from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN
|
||||||
@ -354,3 +355,51 @@ async def test_mcp_tool_call_failed(
|
|||||||
assert len(result.content) == 1
|
assert len(result.content) == 1
|
||||||
assert result.content[0].type == "text"
|
assert result.content[0].type == "text"
|
||||||
assert "Error calling tool" in result.content[0].text
|
assert "Error calling tool" in result.content[0].text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_prompt_list(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
setup_integration: None,
|
||||||
|
mcp_sse_url: str,
|
||||||
|
hass_supervisor_access_token: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test the list prompt endpoint."""
|
||||||
|
|
||||||
|
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||||
|
result = await session.list_prompts()
|
||||||
|
|
||||||
|
assert len(result.prompts) == 1
|
||||||
|
prompt = result.prompts[0]
|
||||||
|
assert prompt.name == "Assist"
|
||||||
|
assert prompt.description == "Default prompt for the Home Assistant LLM API Assist"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_prompt_get(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
setup_integration: None,
|
||||||
|
mcp_sse_url: str,
|
||||||
|
hass_supervisor_access_token: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test the get prompt endpoint."""
|
||||||
|
|
||||||
|
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||||
|
result = await session.get_prompt(name="Assist")
|
||||||
|
|
||||||
|
assert result.description == "Default prompt for the Home Assistant LLM API Assist"
|
||||||
|
assert len(result.messages) == 1
|
||||||
|
assert result.messages[0].role == "assistant"
|
||||||
|
assert result.messages[0].content.type == "text"
|
||||||
|
assert "When controlling Home Assistant" in result.messages[0].content.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_unknwon_prompt(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
setup_integration: None,
|
||||||
|
mcp_sse_url: str,
|
||||||
|
hass_supervisor_access_token: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test the get prompt endpoint."""
|
||||||
|
|
||||||
|
async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
|
||||||
|
with pytest.raises(McpError):
|
||||||
|
await session.get_prompt(name="Unknown")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user