Refactor conversation agent WS API for listing agents (#91590)

* Refactor conversation agent WS API for listing agents

* Add conversation/agent/info back
This commit is contained in:
Erik Montnemery 2023-04-19 16:53:24 +02:00 committed by GitHub
parent 9c784ac622
commit 5e9bbeb4ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 108 additions and 31 deletions

View File

@ -337,7 +337,7 @@ class PipelineRun:
message=f"Intent recognition engine {engine} is not found", message=f"Intent recognition engine {engine} is not found",
) )
self.intent_agent = agent_info["id"] self.intent_agent = agent_info.id
async def recognize_intent( async def recognize_intent(
self, intent_input: str, conversation_id: str | None self, intent_input: str, conversation_id: str | None

View File

@ -2,9 +2,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from dataclasses import dataclass
import logging import logging
import re import re
from typing import Any, TypedDict from typing import Any
import voluptuous as vol import voluptuous as vol
@ -16,6 +17,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, intent, singleton from homeassistant.helpers import config_validation as cv, intent, singleton
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import language as language_util
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import HOME_ASSISTANT_AGENT from .const import HOME_ASSISTANT_AGENT
@ -229,24 +231,29 @@ async def websocket_get_agent_info(
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "conversation/agent/list", vol.Required("type"): "conversation/agent/list",
vol.Optional("language"): str,
} }
) )
@core.callback @websocket_api.async_response
def websocket_list_agents( async def websocket_list_agents(
hass: HomeAssistant, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None: ) -> None:
"""List available agents.""" """List conversation agents and, optionally, if they support a given language."""
manager = _get_agent_manager(hass) manager = _get_agent_manager(hass)
connection.send_result( language = msg.get("language")
msg["id"], agents = []
{
"default_agent": manager.default_agent, for agent_info in manager.async_get_agent_info():
"agents": manager.async_get_agent_info(), agent_dict: dict[str, Any] = {"id": agent_info.id, "name": agent_info.name}
}, if language:
agent = await manager.async_get_agent(agent_info.id)
agent_dict["language_supported"] = bool(
language_util.matches(language, agent.supported_languages)
) )
agents.append(agent_dict)
connection.send_message(websocket_api.result_message(msg["id"], {"agents": agents}))
class ConversationProcessView(http.HomeAssistantView): class ConversationProcessView(http.HomeAssistantView):
@ -281,8 +288,9 @@ class ConversationProcessView(http.HomeAssistantView):
return self.json(result.as_dict()) return self.json(result.as_dict())
class AgentInfo(TypedDict): @dataclass(frozen=True)
"""Dictionary holding agent info.""" class AgentInfo:
"""Container for conversation agent info."""
id: str id: str
name: str name: str
@ -300,7 +308,7 @@ def async_get_agent_info(
agent_id = manager.default_agent agent_id = manager.default_agent
for agent_info in manager.async_get_agent_info(): for agent_info in manager.async_get_agent_info():
if agent_info["id"] == agent_id: if agent_info.id == agent_id:
return agent_info return agent_info
return None return None
@ -375,10 +383,10 @@ class AgentManager:
def async_get_agent_info(self) -> list[AgentInfo]: def async_get_agent_info(self) -> list[AgentInfo]:
"""List all agents.""" """List all agents."""
agents: list[AgentInfo] = [ agents: list[AgentInfo] = [
{ AgentInfo(
"id": HOME_ASSISTANT_AGENT, id=HOME_ASSISTANT_AGENT,
"name": "Home Assistant", name="Home Assistant",
} )
] ]
for agent_id, agent in self._agents.items(): for agent_id, agent in self._agents.items():
config_entry = self.hass.config_entries.async_get_entry(agent_id) config_entry = self.hass.config_entries.async_get_entry(agent_id)
@ -393,10 +401,10 @@ class AgentManager:
continue continue
agents.append( agents.append(
{ AgentInfo(
"id": agent_id, id=agent_id,
"name": config_entry.title, name=config_entry.title,
} )
) )
return agents return agents

View File

@ -29,6 +29,53 @@
'name': 'Mock Title', 'name': 'Mock Title',
}), }),
]), ]),
'default_agent': 'mock-entry', })
# ---
# name: test_get_agent_list.1
dict({
'agents': list([
dict({
'id': 'homeassistant',
'language_supported': False,
'name': 'Home Assistant',
}),
dict({
'id': 'mock-entry',
'language_supported': True,
'name': 'Mock Title',
}),
]),
})
# ---
# name: test_get_agent_list.2
dict({
'agents': list([
dict({
'id': 'homeassistant',
'language_supported': True,
'name': 'Home Assistant',
}),
dict({
'id': 'mock-entry',
'language_supported': False,
'name': 'Mock Title',
}),
]),
})
# ---
# name: test_get_agent_list.3
dict({
'agents': list([
dict({
'id': 'homeassistant',
'language_supported': True,
'name': 'Home Assistant',
}),
dict({
'id': 'mock-entry',
'language_supported': False,
'name': 'Mock Title',
}),
]),
}) })
# --- # ---

View File

@ -1582,10 +1582,32 @@ async def test_get_agent_list(
"""Test getting agent info.""" """Test getting agent info."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json({"id": 5, "type": "conversation/agent/list"}) await client.send_json_auto_id({"type": "conversation/agent/list"})
msg = await client.receive_json()
assert msg["type"] == "result"
assert msg["success"]
assert msg["result"] == snapshot
await client.send_json_auto_id(
{"type": "conversation/agent/list", "language": "smurfish"}
)
msg = await client.receive_json()
assert msg["type"] == "result"
assert msg["success"]
assert msg["result"] == snapshot
await client.send_json_auto_id(
{"type": "conversation/agent/list", "language": "en"}
)
msg = await client.receive_json()
assert msg["type"] == "result"
assert msg["success"]
assert msg["result"] == snapshot
await client.send_json_auto_id(
{"type": "conversation/agent/list", "language": "en-UK"}
)
msg = await client.receive_json() msg = await client.receive_json()
assert msg["id"] == 5
assert msg["type"] == "result" assert msg["type"] == "result"
assert msg["success"] assert msg["success"]
assert msg["result"] == snapshot assert msg["result"] == snapshot
@ -1597,7 +1619,7 @@ async def test_get_agent_info(
"""Test get agent info.""" """Test get agent info."""
agent_info = conversation.async_get_agent_info(hass) agent_info = conversation.async_get_agent_info(hass)
# Test it's the default # Test it's the default
assert agent_info["id"] == mock_agent.agent_id assert agent_info.id == mock_agent.agent_id
assert agent_info == snapshot assert agent_info == snapshot
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot