mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Refactor conversation agent WS API for listing agents (#91590)
* Refactor conversation agent WS API for listing agents * Add conversation/agent/info back
This commit is contained in:
parent
9c784ac622
commit
5e9bbeb4ad
@ -337,7 +337,7 @@ class PipelineRun:
|
|||||||
message=f"Intent recognition engine {engine} is not found",
|
message=f"Intent recognition engine {engine} is not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.intent_agent = agent_info["id"]
|
self.intent_agent = agent_info.id
|
||||||
|
|
||||||
async def recognize_intent(
|
async def recognize_intent(
|
||||||
self, intent_input: str, conversation_id: str | None
|
self, intent_input: str, conversation_id: str | None
|
||||||
|
@ -2,9 +2,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, TypedDict
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -16,6 +17,7 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.helpers import config_validation as cv, intent, singleton
|
from homeassistant.helpers import config_validation as cv, intent, singleton
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
|
from homeassistant.util import language as language_util
|
||||||
|
|
||||||
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
|
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||||
from .const import HOME_ASSISTANT_AGENT
|
from .const import HOME_ASSISTANT_AGENT
|
||||||
@ -229,24 +231,29 @@ async def websocket_get_agent_info(
|
|||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
{
|
{
|
||||||
vol.Required("type"): "conversation/agent/list",
|
vol.Required("type"): "conversation/agent/list",
|
||||||
|
vol.Optional("language"): str,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@core.callback
|
@websocket_api.async_response
|
||||||
def websocket_list_agents(
|
async def websocket_list_agents(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||||
connection: websocket_api.ActiveConnection,
|
|
||||||
msg: dict[str, Any],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List available agents."""
|
"""List conversation agents and, optionally, if they support a given language."""
|
||||||
manager = _get_agent_manager(hass)
|
manager = _get_agent_manager(hass)
|
||||||
|
|
||||||
connection.send_result(
|
language = msg.get("language")
|
||||||
msg["id"],
|
agents = []
|
||||||
{
|
|
||||||
"default_agent": manager.default_agent,
|
for agent_info in manager.async_get_agent_info():
|
||||||
"agents": manager.async_get_agent_info(),
|
agent_dict: dict[str, Any] = {"id": agent_info.id, "name": agent_info.name}
|
||||||
},
|
if language:
|
||||||
|
agent = await manager.async_get_agent(agent_info.id)
|
||||||
|
agent_dict["language_supported"] = bool(
|
||||||
|
language_util.matches(language, agent.supported_languages)
|
||||||
)
|
)
|
||||||
|
agents.append(agent_dict)
|
||||||
|
|
||||||
|
connection.send_message(websocket_api.result_message(msg["id"], {"agents": agents}))
|
||||||
|
|
||||||
|
|
||||||
class ConversationProcessView(http.HomeAssistantView):
|
class ConversationProcessView(http.HomeAssistantView):
|
||||||
@ -281,8 +288,9 @@ class ConversationProcessView(http.HomeAssistantView):
|
|||||||
return self.json(result.as_dict())
|
return self.json(result.as_dict())
|
||||||
|
|
||||||
|
|
||||||
class AgentInfo(TypedDict):
|
@dataclass(frozen=True)
|
||||||
"""Dictionary holding agent info."""
|
class AgentInfo:
|
||||||
|
"""Container for conversation agent info."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
@ -300,7 +308,7 @@ def async_get_agent_info(
|
|||||||
agent_id = manager.default_agent
|
agent_id = manager.default_agent
|
||||||
|
|
||||||
for agent_info in manager.async_get_agent_info():
|
for agent_info in manager.async_get_agent_info():
|
||||||
if agent_info["id"] == agent_id:
|
if agent_info.id == agent_id:
|
||||||
return agent_info
|
return agent_info
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@ -375,10 +383,10 @@ class AgentManager:
|
|||||||
def async_get_agent_info(self) -> list[AgentInfo]:
|
def async_get_agent_info(self) -> list[AgentInfo]:
|
||||||
"""List all agents."""
|
"""List all agents."""
|
||||||
agents: list[AgentInfo] = [
|
agents: list[AgentInfo] = [
|
||||||
{
|
AgentInfo(
|
||||||
"id": HOME_ASSISTANT_AGENT,
|
id=HOME_ASSISTANT_AGENT,
|
||||||
"name": "Home Assistant",
|
name="Home Assistant",
|
||||||
}
|
)
|
||||||
]
|
]
|
||||||
for agent_id, agent in self._agents.items():
|
for agent_id, agent in self._agents.items():
|
||||||
config_entry = self.hass.config_entries.async_get_entry(agent_id)
|
config_entry = self.hass.config_entries.async_get_entry(agent_id)
|
||||||
@ -393,10 +401,10 @@ class AgentManager:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
agents.append(
|
agents.append(
|
||||||
{
|
AgentInfo(
|
||||||
"id": agent_id,
|
id=agent_id,
|
||||||
"name": config_entry.title,
|
name=config_entry.title,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
return agents
|
return agents
|
||||||
|
|
||||||
|
@ -29,6 +29,53 @@
|
|||||||
'name': 'Mock Title',
|
'name': 'Mock Title',
|
||||||
}),
|
}),
|
||||||
]),
|
]),
|
||||||
'default_agent': 'mock-entry',
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_get_agent_list.1
|
||||||
|
dict({
|
||||||
|
'agents': list([
|
||||||
|
dict({
|
||||||
|
'id': 'homeassistant',
|
||||||
|
'language_supported': False,
|
||||||
|
'name': 'Home Assistant',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 'mock-entry',
|
||||||
|
'language_supported': True,
|
||||||
|
'name': 'Mock Title',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_get_agent_list.2
|
||||||
|
dict({
|
||||||
|
'agents': list([
|
||||||
|
dict({
|
||||||
|
'id': 'homeassistant',
|
||||||
|
'language_supported': True,
|
||||||
|
'name': 'Home Assistant',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 'mock-entry',
|
||||||
|
'language_supported': False,
|
||||||
|
'name': 'Mock Title',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_get_agent_list.3
|
||||||
|
dict({
|
||||||
|
'agents': list([
|
||||||
|
dict({
|
||||||
|
'id': 'homeassistant',
|
||||||
|
'language_supported': True,
|
||||||
|
'name': 'Home Assistant',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 'mock-entry',
|
||||||
|
'language_supported': False,
|
||||||
|
'name': 'Mock Title',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
@ -1582,10 +1582,32 @@ async def test_get_agent_list(
|
|||||||
"""Test getting agent info."""
|
"""Test getting agent info."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
await client.send_json({"id": 5, "type": "conversation/agent/list"})
|
await client.send_json_auto_id({"type": "conversation/agent/list"})
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["type"] == "result"
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == snapshot
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{"type": "conversation/agent/list", "language": "smurfish"}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["type"] == "result"
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == snapshot
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{"type": "conversation/agent/list", "language": "en"}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["type"] == "result"
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == snapshot
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{"type": "conversation/agent/list", "language": "en-UK"}
|
||||||
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["id"] == 5
|
|
||||||
assert msg["type"] == "result"
|
assert msg["type"] == "result"
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == snapshot
|
assert msg["result"] == snapshot
|
||||||
@ -1597,7 +1619,7 @@ async def test_get_agent_info(
|
|||||||
"""Test get agent info."""
|
"""Test get agent info."""
|
||||||
agent_info = conversation.async_get_agent_info(hass)
|
agent_info = conversation.async_get_agent_info(hass)
|
||||||
# Test it's the default
|
# Test it's the default
|
||||||
assert agent_info["id"] == mock_agent.agent_id
|
assert agent_info.id == mock_agent.agent_id
|
||||||
assert agent_info == snapshot
|
assert agent_info == snapshot
|
||||||
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
|
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
|
||||||
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot
|
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot
|
||||||
|
Loading…
x
Reference in New Issue
Block a user