mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Conversation: allow getting agent info (#90540)
* Conversation: allow getting agent info * Add unset agenet back
This commit is contained in:
parent
8018be28ee
commit
ad26317b75
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -20,6 +20,15 @@ from homeassistant.loader import bind_hass
|
|||||||
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
|
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||||
from .default_agent import DefaultAgent
|
from .default_agent import DefaultAgent
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DOMAIN",
|
||||||
|
"async_converse",
|
||||||
|
"async_get_agent_info",
|
||||||
|
"async_set_agent",
|
||||||
|
"async_unset_agent",
|
||||||
|
"async_setup",
|
||||||
|
]
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
ATTR_TEXT = "text"
|
ATTR_TEXT = "text"
|
||||||
@ -270,6 +279,31 @@ class ConversationProcessView(http.HomeAssistantView):
|
|||||||
return self.json(result.as_dict())
|
return self.json(result.as_dict())
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInfo(TypedDict):
|
||||||
|
"""Dictionary holding agent info."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
@core.callback
|
||||||
|
def async_get_agent_info(
|
||||||
|
hass: core.HomeAssistant,
|
||||||
|
agent_id: str | None = None,
|
||||||
|
) -> AgentInfo | None:
|
||||||
|
"""Get information on the agent or None if not found."""
|
||||||
|
manager = _get_agent_manager(hass)
|
||||||
|
|
||||||
|
if agent_id is None:
|
||||||
|
agent_id = manager.default_agent
|
||||||
|
|
||||||
|
for agent_info in manager.async_get_agent_info():
|
||||||
|
if agent_info["id"] == agent_id:
|
||||||
|
return agent_info
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def async_converse(
|
async def async_converse(
|
||||||
hass: core.HomeAssistant,
|
hass: core.HomeAssistant,
|
||||||
text: str,
|
text: str,
|
||||||
@ -332,12 +366,15 @@ class AgentManager:
|
|||||||
|
|
||||||
return self._builtin_agent
|
return self._builtin_agent
|
||||||
|
|
||||||
|
if agent_id not in self._agents:
|
||||||
|
raise ValueError(f"Agent {agent_id} not found")
|
||||||
|
|
||||||
return self._agents[agent_id]
|
return self._agents[agent_id]
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
def async_get_agent_info(self) -> list[dict[str, Any]]:
|
def async_get_agent_info(self) -> list[AgentInfo]:
|
||||||
"""List all agents."""
|
"""List all agents."""
|
||||||
agents = [
|
agents: list[AgentInfo] = [
|
||||||
{
|
{
|
||||||
"id": AgentManager.HOME_ASSISTANT_AGENT,
|
"id": AgentManager.HOME_ASSISTANT_AGENT,
|
||||||
"name": "Home Assistant",
|
"name": "Home Assistant",
|
||||||
|
34
tests/components/conversation/snapshots/test_init.ambr
Normal file
34
tests/components/conversation/snapshots/test_init.ambr
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_get_agent_info
|
||||||
|
dict({
|
||||||
|
'id': 'mock-entry',
|
||||||
|
'name': 'Mock Title',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_get_agent_info.1
|
||||||
|
dict({
|
||||||
|
'id': 'homeassistant',
|
||||||
|
'name': 'Home Assistant',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_get_agent_info.2
|
||||||
|
dict({
|
||||||
|
'id': 'mock-entry',
|
||||||
|
'name': 'Mock Title',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_get_agent_list
|
||||||
|
dict({
|
||||||
|
'agents': list([
|
||||||
|
dict({
|
||||||
|
'id': 'homeassistant',
|
||||||
|
'name': 'Home Assistant',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 'mock-entry',
|
||||||
|
'name': 'Mock Title',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'default_agent': 'mock-entry',
|
||||||
|
})
|
||||||
|
# ---
|
@ -4,6 +4,7 @@ from typing import Any
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from syrupy.assertion import SnapshotAssertion
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
@ -929,7 +930,11 @@ async def test_agent_id_validator_invalid_agent(hass: HomeAssistant) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_agent_list(
|
async def test_get_agent_list(
|
||||||
hass: HomeAssistant, init_components, mock_agent, hass_ws_client: WebSocketGenerator
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
mock_agent,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test getting agent info."""
|
"""Test getting agent info."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
@ -940,10 +945,17 @@ async def test_get_agent_list(
|
|||||||
assert msg["id"] == 5
|
assert msg["id"] == 5
|
||||||
assert msg["type"] == "result"
|
assert msg["type"] == "result"
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == snapshot
|
||||||
"agents": [
|
|
||||||
{"id": "homeassistant", "name": "Home Assistant"},
|
|
||||||
{"id": "mock-entry", "name": "Mock Title"},
|
async def test_get_agent_info(
|
||||||
],
|
hass: HomeAssistant, init_components, mock_agent, snapshot: SnapshotAssertion
|
||||||
"default_agent": "mock-entry",
|
) -> None:
|
||||||
}
|
"""Test get agent info."""
|
||||||
|
agent_info = conversation.async_get_agent_info(hass)
|
||||||
|
# Test it's the default
|
||||||
|
assert agent_info["id"] == mock_agent.agent_id
|
||||||
|
assert agent_info == snapshot
|
||||||
|
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
|
||||||
|
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot
|
||||||
|
assert conversation.async_get_agent_info(hass, "not exist") is None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user