mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Refactor conversation agent WS API for listing agents (#91590)
* Refactor conversation agent WS API for listing agents * Add conversation/agent/info back
This commit is contained in:
parent
9c784ac622
commit
5e9bbeb4ad
@ -337,7 +337,7 @@ class PipelineRun:
|
||||
message=f"Intent recognition engine {engine} is not found",
|
||||
)
|
||||
|
||||
self.intent_agent = agent_info["id"]
|
||||
self.intent_agent = agent_info.id
|
||||
|
||||
async def recognize_intent(
|
||||
self, intent_input: str, conversation_id: str | None
|
||||
|
@ -2,9 +2,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -16,6 +17,7 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv, intent, singleton
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import language as language_util
|
||||
|
||||
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .const import HOME_ASSISTANT_AGENT
|
||||
@ -229,24 +231,29 @@ async def websocket_get_agent_info(
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "conversation/agent/list",
|
||||
vol.Optional("language"): str,
|
||||
}
|
||||
)
|
||||
@core.callback
|
||||
def websocket_list_agents(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
@websocket_api.async_response
|
||||
async def websocket_list_agents(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""List available agents."""
|
||||
"""List conversation agents and, optionally, if they support a given language."""
|
||||
manager = _get_agent_manager(hass)
|
||||
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{
|
||||
"default_agent": manager.default_agent,
|
||||
"agents": manager.async_get_agent_info(),
|
||||
},
|
||||
)
|
||||
language = msg.get("language")
|
||||
agents = []
|
||||
|
||||
for agent_info in manager.async_get_agent_info():
|
||||
agent_dict: dict[str, Any] = {"id": agent_info.id, "name": agent_info.name}
|
||||
if language:
|
||||
agent = await manager.async_get_agent(agent_info.id)
|
||||
agent_dict["language_supported"] = bool(
|
||||
language_util.matches(language, agent.supported_languages)
|
||||
)
|
||||
agents.append(agent_dict)
|
||||
|
||||
connection.send_message(websocket_api.result_message(msg["id"], {"agents": agents}))
|
||||
|
||||
|
||||
class ConversationProcessView(http.HomeAssistantView):
|
||||
@ -281,8 +288,9 @@ class ConversationProcessView(http.HomeAssistantView):
|
||||
return self.json(result.as_dict())
|
||||
|
||||
|
||||
class AgentInfo(TypedDict):
|
||||
"""Dictionary holding agent info."""
|
||||
@dataclass(frozen=True)
|
||||
class AgentInfo:
|
||||
"""Container for conversation agent info."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
@ -300,7 +308,7 @@ def async_get_agent_info(
|
||||
agent_id = manager.default_agent
|
||||
|
||||
for agent_info in manager.async_get_agent_info():
|
||||
if agent_info["id"] == agent_id:
|
||||
if agent_info.id == agent_id:
|
||||
return agent_info
|
||||
|
||||
return None
|
||||
@ -375,10 +383,10 @@ class AgentManager:
|
||||
def async_get_agent_info(self) -> list[AgentInfo]:
|
||||
"""List all agents."""
|
||||
agents: list[AgentInfo] = [
|
||||
{
|
||||
"id": HOME_ASSISTANT_AGENT,
|
||||
"name": "Home Assistant",
|
||||
}
|
||||
AgentInfo(
|
||||
id=HOME_ASSISTANT_AGENT,
|
||||
name="Home Assistant",
|
||||
)
|
||||
]
|
||||
for agent_id, agent in self._agents.items():
|
||||
config_entry = self.hass.config_entries.async_get_entry(agent_id)
|
||||
@ -393,10 +401,10 @@ class AgentManager:
|
||||
continue
|
||||
|
||||
agents.append(
|
||||
{
|
||||
"id": agent_id,
|
||||
"name": config_entry.title,
|
||||
}
|
||||
AgentInfo(
|
||||
id=agent_id,
|
||||
name=config_entry.title,
|
||||
)
|
||||
)
|
||||
return agents
|
||||
|
||||
|
@ -29,6 +29,53 @@
|
||||
'name': 'Mock Title',
|
||||
}),
|
||||
]),
|
||||
'default_agent': 'mock-entry',
|
||||
})
|
||||
# ---
|
||||
# name: test_get_agent_list.1
|
||||
dict({
|
||||
'agents': list([
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'language_supported': False,
|
||||
'name': 'Home Assistant',
|
||||
}),
|
||||
dict({
|
||||
'id': 'mock-entry',
|
||||
'language_supported': True,
|
||||
'name': 'Mock Title',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_get_agent_list.2
|
||||
dict({
|
||||
'agents': list([
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'language_supported': True,
|
||||
'name': 'Home Assistant',
|
||||
}),
|
||||
dict({
|
||||
'id': 'mock-entry',
|
||||
'language_supported': False,
|
||||
'name': 'Mock Title',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
# name: test_get_agent_list.3
|
||||
dict({
|
||||
'agents': list([
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'language_supported': True,
|
||||
'name': 'Home Assistant',
|
||||
}),
|
||||
dict({
|
||||
'id': 'mock-entry',
|
||||
'language_supported': False,
|
||||
'name': 'Mock Title',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
|
@ -1582,10 +1582,32 @@ async def test_get_agent_list(
|
||||
"""Test getting agent info."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json({"id": 5, "type": "conversation/agent/list"})
|
||||
|
||||
await client.send_json_auto_id({"type": "conversation/agent/list"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["type"] == "result"
|
||||
assert msg["success"]
|
||||
assert msg["result"] == snapshot
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{"type": "conversation/agent/list", "language": "smurfish"}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["type"] == "result"
|
||||
assert msg["success"]
|
||||
assert msg["result"] == snapshot
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{"type": "conversation/agent/list", "language": "en"}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["type"] == "result"
|
||||
assert msg["success"]
|
||||
assert msg["result"] == snapshot
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{"type": "conversation/agent/list", "language": "en-UK"}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["id"] == 5
|
||||
assert msg["type"] == "result"
|
||||
assert msg["success"]
|
||||
assert msg["result"] == snapshot
|
||||
@ -1597,7 +1619,7 @@ async def test_get_agent_info(
|
||||
"""Test get agent info."""
|
||||
agent_info = conversation.async_get_agent_info(hass)
|
||||
# Test it's the default
|
||||
assert agent_info["id"] == mock_agent.agent_id
|
||||
assert agent_info.id == mock_agent.agent_id
|
||||
assert agent_info == snapshot
|
||||
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
|
||||
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot
|
||||
|
Loading…
x
Reference in New Issue
Block a user