Change conversation default agent behavior (#95225)

* Change conversation default agent behavior

* Fix tests
This commit is contained in:
Paulus Schoutsen 2023-06-26 22:10:17 -04:00 committed by GitHub
parent c4288e7b1f
commit d6cd5648b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 57 additions and 42 deletions

View File

@ -529,12 +529,8 @@ class AgentManager:
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
"""Set the agent."""
self._agents[agent_id] = agent
if self.default_agent == HOME_ASSISTANT_AGENT:
self.default_agent = agent_id
@core.callback
def async_unset_agent(self, agent_id: str) -> None:
"""Unset the agent."""
if self.default_agent == agent_id:
self.default_agent = HOME_ASSISTANT_AGENT
self._agents.pop(agent_id, None)

View File

@ -1,22 +1,22 @@
# serializer version: 1
# name: test_get_agent_info
dict({
'id': 'mock-entry',
'name': 'Mock Title',
})
# ---
# name: test_get_agent_info.1
dict({
'id': 'homeassistant',
'name': 'Home Assistant',
})
# ---
# name: test_get_agent_info.2
# name: test_get_agent_info.1
dict({
'id': 'mock-entry',
'name': 'Mock Title',
})
# ---
# name: test_get_agent_info.2
dict({
'id': 'homeassistant',
'name': 'Home Assistant',
})
# ---
# name: test_get_agent_info.3
dict({
'id': 'mock-entry',
@ -344,10 +344,7 @@
# ---
# name: test_ws_get_agent_info
dict({
'attribution': dict({
'name': 'Mock assistant',
'url': 'https://assist.me',
}),
'attribution': None,
})
# ---
# name: test_ws_get_agent_info.1

View File

@ -1054,16 +1054,16 @@ async def test_http_api_wrong_data(
assert resp.status == HTTPStatus.BAD_REQUEST
@pytest.mark.parametrize("agent_id", (None, "mock-entry"))
async def test_custom_agent(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
hass_admin_user: MockUser,
mock_agent,
agent_id,
) -> None:
"""Test a custom conversation agent."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(hass, "intent", {})
client = await hass_client()
@ -1071,9 +1071,8 @@ async def test_custom_agent(
"text": "Test Text",
"conversation_id": "test-conv-id",
"language": "test-language",
"agent_id": mock_agent.agent_id,
}
if agent_id is not None:
data["agent_id"] = agent_id
resp = await client.post("/api/conversation/process", json=data)
assert resp.status == HTTPStatus.OK
@ -1599,8 +1598,7 @@ async def test_get_agent_info(
"""Test get agent info."""
agent_info = conversation.async_get_agent_info(hass)
# Test it's the default
assert agent_info.id == mock_agent.agent_id
assert agent_info == snapshot
assert conversation.async_get_agent_info(hass, "homeassistant") == agent_info
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot
assert conversation.async_get_agent_info(hass, "not exist") is None

View File

@ -16,7 +16,7 @@ from homeassistant.util.dt import utcnow
from .conftest import ComponentSetup, ExpectedCredentials
from tests.common import async_fire_time_changed, async_mock_service
from tests.common import MockConfigEntry, async_fire_time_changed, async_mock_service
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
@ -322,6 +322,7 @@ async def test_send_text_command_media_player(
async def test_conversation_agent(
hass: HomeAssistant,
setup_integration: ComponentSetup,
config_entry: MockConfigEntry,
) -> None:
"""Test GoogleAssistantConversationAgent."""
await setup_integration()
@ -348,13 +349,13 @@ async def test_conversation_agent(
await hass.services.async_call(
"conversation",
"process",
{"text": text1},
{"text": text1, "agent_id": config_entry.entry_id},
blocking=True,
)
await hass.services.async_call(
"conversation",
"process",
{"text": text2},
{"text": text2, "agent_id": config_entry.entry_id},
blocking=True,
)
@ -367,6 +368,7 @@ async def test_conversation_agent(
async def test_conversation_agent_refresh_token(
hass: HomeAssistant,
config_entry: MockConfigEntry,
setup_integration: ComponentSetup,
aioclient_mock: AiohttpClientMocker,
) -> None:
@ -392,7 +394,7 @@ async def test_conversation_agent_refresh_token(
await hass.services.async_call(
"conversation",
"process",
{"text": text1},
{"text": text1, "agent_id": config_entry.entry_id},
blocking=True,
)
@ -412,7 +414,7 @@ async def test_conversation_agent_refresh_token(
await hass.services.async_call(
"conversation",
"process",
{"text": text2},
{"text": text2, "agent_id": config_entry.entry_id},
blocking=True,
)

View File

@ -13,6 +13,7 @@ from tests.common import MockConfigEntry
async def test_default_prompt(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
@ -89,16 +90,22 @@ async def test_default_prompt(
suggested_area="Test Area 2",
)
with patch("google.generativeai.chat_async") as mock_chat:
result = await conversation.async_converse(hass, "hello", None, Context())
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert mock_chat.mock_calls[0][2] == snapshot
async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None:
async def test_error_handling(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test that the default prompt works."""
with patch("google.generativeai.chat_async", side_effect=ClientError("")):
result = await conversation.async_converse(hass, "hello", None, Context())
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
@ -119,7 +126,9 @@ async def test_template_error(
), patch("google.generativeai.chat_async"):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
result = await conversation.async_converse(hass, "hello", None, Context())
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result

View File

@ -1026,15 +1026,19 @@ async def test_webhook_handle_conversation_process(
"""Test that we can converse."""
webhook_client.server.app.router._frozen = False
resp = await webhook_client.post(
"/api/webhook/{}".format(create_registrations[1]["webhook_id"]),
json={
"type": "conversation_process",
"data": {
"text": "Turn the kitchen light off",
with patch(
"homeassistant.components.conversation.AgentManager.async_get_agent",
return_value=mock_agent,
):
resp = await webhook_client.post(
"/api/webhook/{}".format(create_registrations[1]["webhook_id"]),
json={
"type": "conversation_process",
"data": {
"text": "Turn the kitchen light off",
},
},
},
)
)
assert resp.status == HTTPStatus.OK
json = await resp.json()

View File

@ -13,6 +13,7 @@ from tests.common import MockConfigEntry
async def test_default_prompt(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
@ -101,18 +102,24 @@ async def test_default_prompt(
]
},
) as mock_create:
result = await conversation.async_converse(hass, "hello", None, Context())
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert mock_create.mock_calls[0][2]["messages"] == snapshot
async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None:
async def test_error_handling(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test that the default prompt works."""
with patch(
"openai.ChatCompletion.acreate", side_effect=error.ServiceUnavailableError
):
result = await conversation.async_converse(hass, "hello", None, Context())
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
@ -133,7 +140,9 @@ async def test_template_error(
), patch("openai.ChatCompletion.acreate"):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
result = await conversation.async_converse(hass, "hello", None, Context())
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result