mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Change conversation default agent behavior (#95225)
* Change conversation default agent behavior * Fix tests
This commit is contained in:
parent
c4288e7b1f
commit
d6cd5648b9
@ -529,12 +529,8 @@ class AgentManager:
|
||||
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
|
||||
"""Set the agent."""
|
||||
self._agents[agent_id] = agent
|
||||
if self.default_agent == HOME_ASSISTANT_AGENT:
|
||||
self.default_agent = agent_id
|
||||
|
||||
@core.callback
|
||||
def async_unset_agent(self, agent_id: str) -> None:
|
||||
"""Unset the agent."""
|
||||
if self.default_agent == agent_id:
|
||||
self.default_agent = HOME_ASSISTANT_AGENT
|
||||
self._agents.pop(agent_id, None)
|
||||
|
@ -1,22 +1,22 @@
|
||||
# serializer version: 1
|
||||
# name: test_get_agent_info
|
||||
dict({
|
||||
'id': 'mock-entry',
|
||||
'name': 'Mock Title',
|
||||
})
|
||||
# ---
|
||||
# name: test_get_agent_info.1
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'name': 'Home Assistant',
|
||||
})
|
||||
# ---
|
||||
# name: test_get_agent_info.2
|
||||
# name: test_get_agent_info.1
|
||||
dict({
|
||||
'id': 'mock-entry',
|
||||
'name': 'Mock Title',
|
||||
})
|
||||
# ---
|
||||
# name: test_get_agent_info.2
|
||||
dict({
|
||||
'id': 'homeassistant',
|
||||
'name': 'Home Assistant',
|
||||
})
|
||||
# ---
|
||||
# name: test_get_agent_info.3
|
||||
dict({
|
||||
'id': 'mock-entry',
|
||||
@ -344,10 +344,7 @@
|
||||
# ---
|
||||
# name: test_ws_get_agent_info
|
||||
dict({
|
||||
'attribution': dict({
|
||||
'name': 'Mock assistant',
|
||||
'url': 'https://assist.me',
|
||||
}),
|
||||
'attribution': None,
|
||||
})
|
||||
# ---
|
||||
# name: test_ws_get_agent_info.1
|
||||
|
@ -1054,16 +1054,16 @@ async def test_http_api_wrong_data(
|
||||
assert resp.status == HTTPStatus.BAD_REQUEST
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent_id", (None, "mock-entry"))
|
||||
async def test_custom_agent(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
hass_admin_user: MockUser,
|
||||
mock_agent,
|
||||
agent_id,
|
||||
) -> None:
|
||||
"""Test a custom conversation agent."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
|
||||
client = await hass_client()
|
||||
|
||||
@ -1071,9 +1071,8 @@ async def test_custom_agent(
|
||||
"text": "Test Text",
|
||||
"conversation_id": "test-conv-id",
|
||||
"language": "test-language",
|
||||
"agent_id": mock_agent.agent_id,
|
||||
}
|
||||
if agent_id is not None:
|
||||
data["agent_id"] = agent_id
|
||||
|
||||
resp = await client.post("/api/conversation/process", json=data)
|
||||
assert resp.status == HTTPStatus.OK
|
||||
@ -1599,8 +1598,7 @@ async def test_get_agent_info(
|
||||
"""Test get agent info."""
|
||||
agent_info = conversation.async_get_agent_info(hass)
|
||||
# Test it's the default
|
||||
assert agent_info.id == mock_agent.agent_id
|
||||
assert agent_info == snapshot
|
||||
assert conversation.async_get_agent_info(hass, "homeassistant") == agent_info
|
||||
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
|
||||
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot
|
||||
assert conversation.async_get_agent_info(hass, "not exist") is None
|
||||
|
@ -16,7 +16,7 @@ from homeassistant.util.dt import utcnow
|
||||
|
||||
from .conftest import ComponentSetup, ExpectedCredentials
|
||||
|
||||
from tests.common import async_fire_time_changed, async_mock_service
|
||||
from tests.common import MockConfigEntry, async_fire_time_changed, async_mock_service
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
@ -322,6 +322,7 @@ async def test_send_text_command_media_player(
|
||||
async def test_conversation_agent(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: ComponentSetup,
|
||||
config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test GoogleAssistantConversationAgent."""
|
||||
await setup_integration()
|
||||
@ -348,13 +349,13 @@ async def test_conversation_agent(
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{"text": text1},
|
||||
{"text": text1, "agent_id": config_entry.entry_id},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{"text": text2},
|
||||
{"text": text2, "agent_id": config_entry.entry_id},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
@ -367,6 +368,7 @@ async def test_conversation_agent(
|
||||
|
||||
async def test_conversation_agent_refresh_token(
|
||||
hass: HomeAssistant,
|
||||
config_entry: MockConfigEntry,
|
||||
setup_integration: ComponentSetup,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
) -> None:
|
||||
@ -392,7 +394,7 @@ async def test_conversation_agent_refresh_token(
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{"text": text1},
|
||||
{"text": text1, "agent_id": config_entry.entry_id},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
@ -412,7 +414,7 @@ async def test_conversation_agent_refresh_token(
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
{"text": text2},
|
||||
{"text": text2, "agent_id": config_entry.entry_id},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
|
@ -13,6 +13,7 @@ from tests.common import MockConfigEntry
|
||||
|
||||
async def test_default_prompt(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
area_registry: ar.AreaRegistry,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
@ -89,16 +90,22 @@ async def test_default_prompt(
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
with patch("google.generativeai.chat_async") as mock_chat:
|
||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert mock_chat.mock_calls[0][2] == snapshot
|
||||
|
||||
|
||||
async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None:
|
||||
async def test_error_handling(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test that the default prompt works."""
|
||||
with patch("google.generativeai.chat_async", side_effect=ClientError("")):
|
||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
@ -119,7 +126,9 @@ async def test_template_error(
|
||||
), patch("google.generativeai.chat_async"):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
|
@ -1026,15 +1026,19 @@ async def test_webhook_handle_conversation_process(
|
||||
"""Test that we can converse."""
|
||||
webhook_client.server.app.router._frozen = False
|
||||
|
||||
resp = await webhook_client.post(
|
||||
"/api/webhook/{}".format(create_registrations[1]["webhook_id"]),
|
||||
json={
|
||||
"type": "conversation_process",
|
||||
"data": {
|
||||
"text": "Turn the kitchen light off",
|
||||
with patch(
|
||||
"homeassistant.components.conversation.AgentManager.async_get_agent",
|
||||
return_value=mock_agent,
|
||||
):
|
||||
resp = await webhook_client.post(
|
||||
"/api/webhook/{}".format(create_registrations[1]["webhook_id"]),
|
||||
json={
|
||||
"type": "conversation_process",
|
||||
"data": {
|
||||
"text": "Turn the kitchen light off",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert resp.status == HTTPStatus.OK
|
||||
json = await resp.json()
|
||||
|
@ -13,6 +13,7 @@ from tests.common import MockConfigEntry
|
||||
|
||||
async def test_default_prompt(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
area_registry: ar.AreaRegistry,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
@ -101,18 +102,24 @@ async def test_default_prompt(
|
||||
]
|
||||
},
|
||||
) as mock_create:
|
||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert mock_create.mock_calls[0][2]["messages"] == snapshot
|
||||
|
||||
|
||||
async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None:
|
||||
async def test_error_handling(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test that the default prompt works."""
|
||||
with patch(
|
||||
"openai.ChatCompletion.acreate", side_effect=error.ServiceUnavailableError
|
||||
):
|
||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
@ -133,7 +140,9 @@ async def test_template_error(
|
||||
), patch("openai.ChatCompletion.acreate"):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
|
Loading…
x
Reference in New Issue
Block a user