Conversation: allow getting agent info (#90540)

* Conversation: allow getting agent info

* Add unset agenet back
This commit is contained in:
Paulus Schoutsen 2023-03-31 14:36:39 -04:00 committed by GitHub
parent 8018be28ee
commit ad26317b75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 11 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import re import re
from typing import Any from typing import Any, TypedDict
import voluptuous as vol import voluptuous as vol
@ -20,6 +20,15 @@ from homeassistant.loader import bind_hass
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .default_agent import DefaultAgent from .default_agent import DefaultAgent
__all__ = [
"DOMAIN",
"async_converse",
"async_get_agent_info",
"async_set_agent",
"async_unset_agent",
"async_setup",
]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ATTR_TEXT = "text" ATTR_TEXT = "text"
@ -270,6 +279,31 @@ class ConversationProcessView(http.HomeAssistantView):
return self.json(result.as_dict()) return self.json(result.as_dict())
class AgentInfo(TypedDict):
"""Dictionary holding agent info."""
id: str
name: str
@core.callback
def async_get_agent_info(
hass: core.HomeAssistant,
agent_id: str | None = None,
) -> AgentInfo | None:
"""Get information on the agent or None if not found."""
manager = _get_agent_manager(hass)
if agent_id is None:
agent_id = manager.default_agent
for agent_info in manager.async_get_agent_info():
if agent_info["id"] == agent_id:
return agent_info
return None
async def async_converse( async def async_converse(
hass: core.HomeAssistant, hass: core.HomeAssistant,
text: str, text: str,
@ -332,12 +366,15 @@ class AgentManager:
return self._builtin_agent return self._builtin_agent
if agent_id not in self._agents:
raise ValueError(f"Agent {agent_id} not found")
return self._agents[agent_id] return self._agents[agent_id]
@core.callback @core.callback
def async_get_agent_info(self) -> list[dict[str, Any]]: def async_get_agent_info(self) -> list[AgentInfo]:
"""List all agents.""" """List all agents."""
agents = [ agents: list[AgentInfo] = [
{ {
"id": AgentManager.HOME_ASSISTANT_AGENT, "id": AgentManager.HOME_ASSISTANT_AGENT,
"name": "Home Assistant", "name": "Home Assistant",

View File

@ -0,0 +1,34 @@
# serializer version: 1
# name: test_get_agent_info
dict({
'id': 'mock-entry',
'name': 'Mock Title',
})
# ---
# name: test_get_agent_info.1
dict({
'id': 'homeassistant',
'name': 'Home Assistant',
})
# ---
# name: test_get_agent_info.2
dict({
'id': 'mock-entry',
'name': 'Mock Title',
})
# ---
# name: test_get_agent_list
dict({
'agents': list([
dict({
'id': 'homeassistant',
'name': 'Home Assistant',
}),
dict({
'id': 'mock-entry',
'name': 'Mock Title',
}),
]),
'default_agent': 'mock-entry',
})
# ---

View File

@ -4,6 +4,7 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
@ -929,7 +930,11 @@ async def test_agent_id_validator_invalid_agent(hass: HomeAssistant) -> None:
async def test_get_agent_list( async def test_get_agent_list(
hass: HomeAssistant, init_components, mock_agent, hass_ws_client: WebSocketGenerator hass: HomeAssistant,
init_components,
mock_agent,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test getting agent info.""" """Test getting agent info."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -940,10 +945,17 @@ async def test_get_agent_list(
assert msg["id"] == 5 assert msg["id"] == 5
assert msg["type"] == "result" assert msg["type"] == "result"
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == snapshot
"agents": [
{"id": "homeassistant", "name": "Home Assistant"},
{"id": "mock-entry", "name": "Mock Title"}, async def test_get_agent_info(
], hass: HomeAssistant, init_components, mock_agent, snapshot: SnapshotAssertion
"default_agent": "mock-entry", ) -> None:
} """Test get agent info."""
agent_info = conversation.async_get_agent_info(hass)
# Test it's the default
assert agent_info["id"] == mock_agent.agent_id
assert agent_info == snapshot
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot
assert conversation.async_get_agent_info(hass, "not exist") is None