mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Change conversation default agent behavior (#95225)
* Change conversation default agent behavior * Fix tests
This commit is contained in:
parent
c4288e7b1f
commit
d6cd5648b9
@ -529,12 +529,8 @@ class AgentManager:
|
|||||||
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
|
def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None:
|
||||||
"""Set the agent."""
|
"""Set the agent."""
|
||||||
self._agents[agent_id] = agent
|
self._agents[agent_id] = agent
|
||||||
if self.default_agent == HOME_ASSISTANT_AGENT:
|
|
||||||
self.default_agent = agent_id
|
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
def async_unset_agent(self, agent_id: str) -> None:
|
def async_unset_agent(self, agent_id: str) -> None:
|
||||||
"""Unset the agent."""
|
"""Unset the agent."""
|
||||||
if self.default_agent == agent_id:
|
|
||||||
self.default_agent = HOME_ASSISTANT_AGENT
|
|
||||||
self._agents.pop(agent_id, None)
|
self._agents.pop(agent_id, None)
|
||||||
|
@ -1,22 +1,22 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_get_agent_info
|
# name: test_get_agent_info
|
||||||
dict({
|
|
||||||
'id': 'mock-entry',
|
|
||||||
'name': 'Mock Title',
|
|
||||||
})
|
|
||||||
# ---
|
|
||||||
# name: test_get_agent_info.1
|
|
||||||
dict({
|
dict({
|
||||||
'id': 'homeassistant',
|
'id': 'homeassistant',
|
||||||
'name': 'Home Assistant',
|
'name': 'Home Assistant',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_get_agent_info.2
|
# name: test_get_agent_info.1
|
||||||
dict({
|
dict({
|
||||||
'id': 'mock-entry',
|
'id': 'mock-entry',
|
||||||
'name': 'Mock Title',
|
'name': 'Mock Title',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_get_agent_info.2
|
||||||
|
dict({
|
||||||
|
'id': 'homeassistant',
|
||||||
|
'name': 'Home Assistant',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
# name: test_get_agent_info.3
|
# name: test_get_agent_info.3
|
||||||
dict({
|
dict({
|
||||||
'id': 'mock-entry',
|
'id': 'mock-entry',
|
||||||
@ -344,10 +344,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_ws_get_agent_info
|
# name: test_ws_get_agent_info
|
||||||
dict({
|
dict({
|
||||||
'attribution': dict({
|
'attribution': None,
|
||||||
'name': 'Mock assistant',
|
|
||||||
'url': 'https://assist.me',
|
|
||||||
}),
|
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_ws_get_agent_info.1
|
# name: test_ws_get_agent_info.1
|
||||||
|
@ -1054,16 +1054,16 @@ async def test_http_api_wrong_data(
|
|||||||
assert resp.status == HTTPStatus.BAD_REQUEST
|
assert resp.status == HTTPStatus.BAD_REQUEST
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("agent_id", (None, "mock-entry"))
|
|
||||||
async def test_custom_agent(
|
async def test_custom_agent(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
hass_admin_user: MockUser,
|
hass_admin_user: MockUser,
|
||||||
mock_agent,
|
mock_agent,
|
||||||
agent_id,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test a custom conversation agent."""
|
"""Test a custom conversation agent."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
assert await async_setup_component(hass, "conversation", {})
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
assert await async_setup_component(hass, "intent", {})
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
|
||||||
@ -1071,9 +1071,8 @@ async def test_custom_agent(
|
|||||||
"text": "Test Text",
|
"text": "Test Text",
|
||||||
"conversation_id": "test-conv-id",
|
"conversation_id": "test-conv-id",
|
||||||
"language": "test-language",
|
"language": "test-language",
|
||||||
|
"agent_id": mock_agent.agent_id,
|
||||||
}
|
}
|
||||||
if agent_id is not None:
|
|
||||||
data["agent_id"] = agent_id
|
|
||||||
|
|
||||||
resp = await client.post("/api/conversation/process", json=data)
|
resp = await client.post("/api/conversation/process", json=data)
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
@ -1599,8 +1598,7 @@ async def test_get_agent_info(
|
|||||||
"""Test get agent info."""
|
"""Test get agent info."""
|
||||||
agent_info = conversation.async_get_agent_info(hass)
|
agent_info = conversation.async_get_agent_info(hass)
|
||||||
# Test it's the default
|
# Test it's the default
|
||||||
assert agent_info.id == mock_agent.agent_id
|
assert conversation.async_get_agent_info(hass, "homeassistant") == agent_info
|
||||||
assert agent_info == snapshot
|
|
||||||
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
|
assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot
|
||||||
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot
|
assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot
|
||||||
assert conversation.async_get_agent_info(hass, "not exist") is None
|
assert conversation.async_get_agent_info(hass, "not exist") is None
|
||||||
|
@ -16,7 +16,7 @@ from homeassistant.util.dt import utcnow
|
|||||||
|
|
||||||
from .conftest import ComponentSetup, ExpectedCredentials
|
from .conftest import ComponentSetup, ExpectedCredentials
|
||||||
|
|
||||||
from tests.common import async_fire_time_changed, async_mock_service
|
from tests.common import MockConfigEntry, async_fire_time_changed, async_mock_service
|
||||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
@ -322,6 +322,7 @@ async def test_send_text_command_media_player(
|
|||||||
async def test_conversation_agent(
|
async def test_conversation_agent(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
setup_integration: ComponentSetup,
|
setup_integration: ComponentSetup,
|
||||||
|
config_entry: MockConfigEntry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test GoogleAssistantConversationAgent."""
|
"""Test GoogleAssistantConversationAgent."""
|
||||||
await setup_integration()
|
await setup_integration()
|
||||||
@ -348,13 +349,13 @@ async def test_conversation_agent(
|
|||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
"conversation",
|
"conversation",
|
||||||
"process",
|
"process",
|
||||||
{"text": text1},
|
{"text": text1, "agent_id": config_entry.entry_id},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
"conversation",
|
"conversation",
|
||||||
"process",
|
"process",
|
||||||
{"text": text2},
|
{"text": text2, "agent_id": config_entry.entry_id},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -367,6 +368,7 @@ async def test_conversation_agent(
|
|||||||
|
|
||||||
async def test_conversation_agent_refresh_token(
|
async def test_conversation_agent_refresh_token(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
config_entry: MockConfigEntry,
|
||||||
setup_integration: ComponentSetup,
|
setup_integration: ComponentSetup,
|
||||||
aioclient_mock: AiohttpClientMocker,
|
aioclient_mock: AiohttpClientMocker,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -392,7 +394,7 @@ async def test_conversation_agent_refresh_token(
|
|||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
"conversation",
|
"conversation",
|
||||||
"process",
|
"process",
|
||||||
{"text": text1},
|
{"text": text1, "agent_id": config_entry.entry_id},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -412,7 +414,7 @@ async def test_conversation_agent_refresh_token(
|
|||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
"conversation",
|
"conversation",
|
||||||
"process",
|
"process",
|
||||||
{"text": text2},
|
{"text": text2, "agent_id": config_entry.entry_id},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ from tests.common import MockConfigEntry
|
|||||||
|
|
||||||
async def test_default_prompt(
|
async def test_default_prompt(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
area_registry: ar.AreaRegistry,
|
area_registry: ar.AreaRegistry,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
@ -89,16 +90,22 @@ async def test_default_prompt(
|
|||||||
suggested_area="Test Area 2",
|
suggested_area="Test Area 2",
|
||||||
)
|
)
|
||||||
with patch("google.generativeai.chat_async") as mock_chat:
|
with patch("google.generativeai.chat_async") as mock_chat:
|
||||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
assert mock_chat.mock_calls[0][2] == snapshot
|
assert mock_chat.mock_calls[0][2] == snapshot
|
||||||
|
|
||||||
|
|
||||||
async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None:
|
async def test_error_handling(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
|
) -> None:
|
||||||
"""Test that the default prompt works."""
|
"""Test that the default prompt works."""
|
||||||
with patch("google.generativeai.chat_async", side_effect=ClientError("")):
|
with patch("google.generativeai.chat_async", side_effect=ClientError("")):
|
||||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
@ -119,7 +126,9 @@ async def test_template_error(
|
|||||||
), patch("google.generativeai.chat_async"):
|
), patch("google.generativeai.chat_async"):
|
||||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
|
@ -1026,15 +1026,19 @@ async def test_webhook_handle_conversation_process(
|
|||||||
"""Test that we can converse."""
|
"""Test that we can converse."""
|
||||||
webhook_client.server.app.router._frozen = False
|
webhook_client.server.app.router._frozen = False
|
||||||
|
|
||||||
resp = await webhook_client.post(
|
with patch(
|
||||||
"/api/webhook/{}".format(create_registrations[1]["webhook_id"]),
|
"homeassistant.components.conversation.AgentManager.async_get_agent",
|
||||||
json={
|
return_value=mock_agent,
|
||||||
"type": "conversation_process",
|
):
|
||||||
"data": {
|
resp = await webhook_client.post(
|
||||||
"text": "Turn the kitchen light off",
|
"/api/webhook/{}".format(create_registrations[1]["webhook_id"]),
|
||||||
|
json={
|
||||||
|
"type": "conversation_process",
|
||||||
|
"data": {
|
||||||
|
"text": "Turn the kitchen light off",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
)
|
||||||
)
|
|
||||||
|
|
||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
json = await resp.json()
|
json = await resp.json()
|
||||||
|
@ -13,6 +13,7 @@ from tests.common import MockConfigEntry
|
|||||||
|
|
||||||
async def test_default_prompt(
|
async def test_default_prompt(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
area_registry: ar.AreaRegistry,
|
area_registry: ar.AreaRegistry,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
@ -101,18 +102,24 @@ async def test_default_prompt(
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
) as mock_create:
|
) as mock_create:
|
||||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
assert mock_create.mock_calls[0][2]["messages"] == snapshot
|
assert mock_create.mock_calls[0][2]["messages"] == snapshot
|
||||||
|
|
||||||
|
|
||||||
async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None:
|
async def test_error_handling(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
|
) -> None:
|
||||||
"""Test that the default prompt works."""
|
"""Test that the default prompt works."""
|
||||||
with patch(
|
with patch(
|
||||||
"openai.ChatCompletion.acreate", side_effect=error.ServiceUnavailableError
|
"openai.ChatCompletion.acreate", side_effect=error.ServiceUnavailableError
|
||||||
):
|
):
|
||||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
@ -133,7 +140,9 @@ async def test_template_error(
|
|||||||
), patch("openai.ChatCompletion.acreate"):
|
), patch("openai.ChatCompletion.acreate"):
|
||||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
|
Loading…
x
Reference in New Issue
Block a user