diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index e2e00a2652a..5009530dc31 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio import logging import re -from typing import Any +from typing import Any, TypedDict import voluptuous as vol @@ -20,6 +20,15 @@ from homeassistant.loader import bind_hass from .agent import AbstractConversationAgent, ConversationInput, ConversationResult from .default_agent import DefaultAgent +__all__ = [ + "DOMAIN", + "async_converse", + "async_get_agent_info", + "async_set_agent", + "async_unset_agent", + "async_setup", +] + _LOGGER = logging.getLogger(__name__) ATTR_TEXT = "text" @@ -270,6 +279,31 @@ class ConversationProcessView(http.HomeAssistantView): return self.json(result.as_dict()) +class AgentInfo(TypedDict): + """Dictionary holding agent info.""" + + id: str + name: str + + +@core.callback +def async_get_agent_info( + hass: core.HomeAssistant, + agent_id: str | None = None, +) -> AgentInfo | None: + """Get information on the agent or None if not found.""" + manager = _get_agent_manager(hass) + + if agent_id is None: + agent_id = manager.default_agent + + for agent_info in manager.async_get_agent_info(): + if agent_info["id"] == agent_id: + return agent_info + + return None + + async def async_converse( hass: core.HomeAssistant, text: str, @@ -332,12 +366,15 @@ class AgentManager: return self._builtin_agent + if agent_id not in self._agents: + raise ValueError(f"Agent {agent_id} not found") + return self._agents[agent_id] @core.callback - def async_get_agent_info(self) -> list[dict[str, Any]]: + def async_get_agent_info(self) -> list[AgentInfo]: """List all agents.""" - agents = [ + agents: list[AgentInfo] = [ { "id": AgentManager.HOME_ASSISTANT_AGENT, "name": "Home Assistant", diff --git a/tests/components/conversation/snapshots/test_init.ambr b/tests/components/conversation/snapshots/test_init.ambr new file mode 100644 index 00000000000..1547b5b5e88 --- /dev/null +++ b/tests/components/conversation/snapshots/test_init.ambr @@ -0,0 +1,34 @@ +# serializer version: 1 +# name: test_get_agent_info + dict({ + 'id': 'mock-entry', + 'name': 'Mock Title', + }) +# --- +# name: test_get_agent_info.1 + dict({ + 'id': 'homeassistant', + 'name': 'Home Assistant', + }) +# --- +# name: test_get_agent_info.2 + dict({ + 'id': 'mock-entry', + 'name': 'Mock Title', + }) +# --- +# name: test_get_agent_list + dict({ + 'agents': list([ + dict({ + 'id': 'homeassistant', + 'name': 'Home Assistant', + }), + dict({ + 'id': 'mock-entry', + 'name': 'Mock Title', + }), + ]), + 'default_agent': 'mock-entry', + }) +# --- diff --git a/tests/components/conversation/test_init.py b/tests/components/conversation/test_init.py index 55a345bd605..eb38d875bfa 100644 --- a/tests/components/conversation/test_init.py +++ b/tests/components/conversation/test_init.py @@ -4,6 +4,7 @@ from typing import Any from unittest.mock import patch import pytest +from syrupy.assertion import SnapshotAssertion import voluptuous as vol from homeassistant.components import conversation @@ -929,7 +930,11 @@ async def test_agent_id_validator_invalid_agent(hass: HomeAssistant) -> None: async def test_get_agent_list( - hass: HomeAssistant, init_components, mock_agent, hass_ws_client: WebSocketGenerator + hass: HomeAssistant, + init_components, + mock_agent, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, ) -> None: """Test getting agent info.""" client = await hass_ws_client(hass) @@ -940,10 +945,17 @@ async def test_get_agent_list( assert msg["id"] == 5 assert msg["type"] == "result" assert msg["success"] - assert msg["result"] == { - "agents": [ - {"id": "homeassistant", "name": "Home Assistant"}, - {"id": "mock-entry", "name": "Mock Title"}, - ], - "default_agent": "mock-entry", - } + assert msg["result"] == snapshot + + +async def test_get_agent_info( + hass: HomeAssistant, init_components, mock_agent, snapshot: SnapshotAssertion +) -> None: + """Test get agent info.""" + agent_info = conversation.async_get_agent_info(hass) + # Test it's the default + assert agent_info["id"] == mock_agent.agent_id + assert agent_info == snapshot + assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot + assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot + assert conversation.async_get_agent_info(hass, "not exist") is None