Refactor conversation agent WS API for listing agents (#91590)

* Refactor conversation agent WS API for listing agents

* Add conversation/agent/info back
This commit is contained in:
Erik Montnemery 2023-04-19 16:53:24 +02:00 committed by GitHub
parent 9c784ac622
commit 5e9bbeb4ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 108 additions and 31 deletions

View File

@ -337,7 +337,7 @@ class PipelineRun:
message=f"Intent recognition engine {engine} is not found",
)
self.intent_agent = agent_info["id"]
self.intent_agent = agent_info.id
async def recognize_intent(
self, intent_input: str, conversation_id: str | None

View File

@ -2,9 +2,10 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import logging
import re
from typing import Any, TypedDict
from typing import Any
import voluptuous as vol
@ -16,6 +17,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, intent, singleton
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from homeassistant.util import language as language_util
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import HOME_ASSISTANT_AGENT
@ -229,24 +231,29 @@ async def websocket_get_agent_info(
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/agent/list",
vol.Optional("language"): str,
}
)
@core.callback
def websocket_list_agents(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
@websocket_api.async_response
async def websocket_list_agents(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List available agents."""
"""List conversation agents and, optionally, if they support a given language."""
manager = _get_agent_manager(hass)
connection.send_result(
msg["id"],
{
"default_agent": manager.default_agent,
"agents": manager.async_get_agent_info(),
},
)
language = msg.get("language")
agents = []
for agent_info in manager.async_get_agent_info():
agent_dict: dict[str, Any] = {"id": agent_info.id, "name": agent_info.name}
if language:
agent = await manager.async_get_agent(agent_info.id)
agent_dict["language_supported"] = bool(
language_util.matches(language, agent.supported_languages)
)
agents.append(agent_dict)
connection.send_message(websocket_api.result_message(msg["id"], {"agents": agents}))
class ConversationProcessView(http.HomeAssistantView):
@ -281,8 +288,9 @@ class ConversationProcessView(http.HomeAssistantView):
return self.json(result.as_dict())
class AgentInfo(TypedDict):
"""Dictionary holding agent info."""
@dataclass(frozen=True)
class AgentInfo:
"""Container for conversation agent info."""
id: str
name: str
@ -300,7 +308,7 @@ def async_get_agent_info(
agent_id = manager.default_agent
for agent_info in manager.async_get_agent_info():
if agent_info["id"] == agent_id:
if agent_info.id == agent_id:
return agent_info
return None
@ -375,10 +383,10 @@ class AgentManager:
def async_get_agent_info(self) -> list[AgentInfo]:
"""List all agents."""
agents: list[AgentInfo] = [
{
"id": HOME_ASSISTANT_AGENT,
"name": "Home Assistant",
}
AgentInfo(
id=HOME_ASSISTANT_AGENT,
name="Home Assistant",
)
]
for agent_id, agent in self._agents.items():
config_entry = self.hass.config_entries.async_get_entry(agent_id)
@ -393,10 +401,10 @@ class AgentManager:
continue
agents.append(
{
"id": agent_id,
"name": config_entry.title,
}
AgentInfo(
id=agent_id,
name=config_entry.title,
)
)
return agents

View File

@ -29,6 +29,53 @@
'name': 'Mock Title',
}),
]),
'default_agent': 'mock-entry',
})
# ---
# name: test_get_agent_list.1
dict({
'agents': list([
dict({
'id': 'homeassistant',
'language_supported': False,
'name': 'Home Assistant',
}),
dict({
'id': 'mock-entry',
'language_supported': True,
'name': 'Mock Title',
}),
]),
})
# ---
# name: test_get_agent_list.2
dict({
'agents': list([
dict({
'id': 'homeassistant',
'language_supported': True,
'name': 'Home Assistant',
}),
dict({
'id': 'mock-entry',
'language_supported': False,
'name': 'Mock Title',
}),
]),
})
# ---
# name: test_get_agent_list.3
dict({
'agents': list([
dict({
'id': 'homeassistant',
'language_supported': True,
'name': 'Home Assistant',
}),
dict({
'id': 'mock-entry',
'language_supported': False,
'name': 'Mock Title',
}),
]),
})
# ---

View File

@ -1582,10 +1582,32 @@ async def test_get_agent_list(
"""Test getting agent info."""
client = await hass_ws_client(hass)
await client.send_json({"id": 5, "type": "conversation/agent/list"})
await client.send_json_auto_id({"type": "conversation/agent/list"})
msg = await client.receive_json()
assert msg["type"] == "result"
assert msg["success"]
assert msg["result"] == snapshot
await client.send_json_auto_id(
{"type": "conversation/agent/list", "language": "smurfish"}
)
msg = await client.receive_json()
assert msg["type"] == "result"
assert msg["success"]
assert msg["result"] == snapshot
await client.send_json_auto_id(
{"type": "conversation/agent/list", "language": "en"}
)
msg = await client.receive_json()
assert msg["type"] == "result"
assert msg["success"]
assert msg["result"] == snapshot
await client.send_json_auto_id(
{"type": "conversation/agent/list", "language": "en-UK"}
)
msg = await client.receive_json()
assert msg["id"] == 5
assert msg["type"] == "result"
assert msg["success"]
assert msg["result"] == snapshot
@ -1597,7 +1619,7 @@ async def test_get_agent_info(
"""Test get agent info."""
agent_info = conversation.async_get_agent_info(hass)
# Test it's the default
assert agent_info["id"] == mock_agent.agent_id
assert agent_info.id == mock_agent.agent_id
assert agent_info == snapshot
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot