Conversation: allow getting agent info (#90540)

* Conversation: allow getting agent info

* Add unset agenet back
This commit is contained in:
Paulus Schoutsen 2023-03-31 14:36:39 -04:00 committed by GitHub
parent 8018be28ee
commit ad26317b75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 11 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio
import logging
import re
from typing import Any
from typing import Any, TypedDict
import voluptuous as vol
@ -20,6 +20,15 @@ from homeassistant.loader import bind_hass
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .default_agent import DefaultAgent
__all__ = [
"DOMAIN",
"async_converse",
"async_get_agent_info",
"async_set_agent",
"async_unset_agent",
"async_setup",
]
_LOGGER = logging.getLogger(__name__)
ATTR_TEXT = "text"
@ -270,6 +279,31 @@ class ConversationProcessView(http.HomeAssistantView):
return self.json(result.as_dict())
class AgentInfo(TypedDict):
"""Dictionary holding agent info."""
id: str
name: str
@core.callback
def async_get_agent_info(
hass: core.HomeAssistant,
agent_id: str | None = None,
) -> AgentInfo | None:
"""Get information on the agent or None if not found."""
manager = _get_agent_manager(hass)
if agent_id is None:
agent_id = manager.default_agent
for agent_info in manager.async_get_agent_info():
if agent_info["id"] == agent_id:
return agent_info
return None
async def async_converse(
hass: core.HomeAssistant,
text: str,
@ -332,12 +366,15 @@ class AgentManager:
return self._builtin_agent
if agent_id not in self._agents:
raise ValueError(f"Agent {agent_id} not found")
return self._agents[agent_id]
@core.callback
def async_get_agent_info(self) -> list[dict[str, Any]]:
def async_get_agent_info(self) -> list[AgentInfo]:
"""List all agents."""
agents = [
agents: list[AgentInfo] = [
{
"id": AgentManager.HOME_ASSISTANT_AGENT,
"name": "Home Assistant",

View File

@ -0,0 +1,34 @@
# serializer version: 1
# name: test_get_agent_info
dict({
'id': 'mock-entry',
'name': 'Mock Title',
})
# ---
# name: test_get_agent_info.1
dict({
'id': 'homeassistant',
'name': 'Home Assistant',
})
# ---
# name: test_get_agent_info.2
dict({
'id': 'mock-entry',
'name': 'Mock Title',
})
# ---
# name: test_get_agent_list
dict({
'agents': list([
dict({
'id': 'homeassistant',
'name': 'Home Assistant',
}),
dict({
'id': 'mock-entry',
'name': 'Mock Title',
}),
]),
'default_agent': 'mock-entry',
})
# ---

View File

@ -4,6 +4,7 @@ from typing import Any
from unittest.mock import patch
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
@ -929,7 +930,11 @@ async def test_agent_id_validator_invalid_agent(hass: HomeAssistant) -> None:
async def test_get_agent_list(
hass: HomeAssistant, init_components, mock_agent, hass_ws_client: WebSocketGenerator
hass: HomeAssistant,
init_components,
mock_agent,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test getting agent info."""
client = await hass_ws_client(hass)
@ -940,10 +945,17 @@ async def test_get_agent_list(
assert msg["id"] == 5
assert msg["type"] == "result"
assert msg["success"]
assert msg["result"] == {
"agents": [
{"id": "homeassistant", "name": "Home Assistant"},
{"id": "mock-entry", "name": "Mock Title"},
],
"default_agent": "mock-entry",
}
assert msg["result"] == snapshot
async def test_get_agent_info(
hass: HomeAssistant, init_components, mock_agent, snapshot: SnapshotAssertion
) -> None:
"""Test get agent info."""
agent_info = conversation.async_get_agent_info(hass)
# Test it's the default
assert agent_info["id"] == mock_agent.agent_id
assert agent_info == snapshot
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot
assert conversation.async_get_agent_info(hass, "not exist") is None