From d6cd5648b9e6dc7d7fd2892df1a17289157488a0 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 26 Jun 2023 22:10:17 -0400 Subject: [PATCH] Change conversation default agent behavior (#95225) * Change conversation default agent behavior * Fix tests --- .../components/conversation/__init__.py | 4 ---- .../conversation/snapshots/test_init.ambr | 19 ++++++++---------- tests/components/conversation/test_init.py | 10 ++++------ .../google_assistant_sdk/test_init.py | 12 ++++++----- .../test_init.py | 17 ++++++++++++---- tests/components/mobile_app/test_webhook.py | 20 +++++++++++-------- .../openai_conversation/test_init.py | 17 ++++++++++++---- 7 files changed, 57 insertions(+), 42 deletions(-) diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index f3d883b1565..f704a8baa33 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -529,12 +529,8 @@ class AgentManager: def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> None: """Set the agent.""" self._agents[agent_id] = agent - if self.default_agent == HOME_ASSISTANT_AGENT: - self.default_agent = agent_id @core.callback def async_unset_agent(self, agent_id: str) -> None: """Unset the agent.""" - if self.default_agent == agent_id: - self.default_agent = HOME_ASSISTANT_AGENT self._agents.pop(agent_id, None) diff --git a/tests/components/conversation/snapshots/test_init.ambr b/tests/components/conversation/snapshots/test_init.ambr index f4325e2f291..afc2d2e4418 100644 --- a/tests/components/conversation/snapshots/test_init.ambr +++ b/tests/components/conversation/snapshots/test_init.ambr @@ -1,22 +1,22 @@ # serializer version: 1 # name: test_get_agent_info - dict({ - 'id': 'mock-entry', - 'name': 'Mock Title', - }) -# --- -# name: test_get_agent_info.1 dict({ 'id': 'homeassistant', 'name': 'Home Assistant', }) # --- -# name: test_get_agent_info.2 +# name: test_get_agent_info.1 dict({ 'id': 'mock-entry', 'name': 'Mock Title', }) # --- +# name: test_get_agent_info.2 + dict({ + 'id': 'homeassistant', + 'name': 'Home Assistant', + }) +# --- # name: test_get_agent_info.3 dict({ 'id': 'mock-entry', @@ -344,10 +344,7 @@ # --- # name: test_ws_get_agent_info dict({ - 'attribution': dict({ - 'name': 'Mock assistant', - 'url': 'https://assist.me', - }), + 'attribution': None, }) # --- # name: test_ws_get_agent_info.1 diff --git a/tests/components/conversation/test_init.py b/tests/components/conversation/test_init.py index b55bd651b9e..ec2128e3bd7 100644 --- a/tests/components/conversation/test_init.py +++ b/tests/components/conversation/test_init.py @@ -1054,16 +1054,16 @@ async def test_http_api_wrong_data( assert resp.status == HTTPStatus.BAD_REQUEST -@pytest.mark.parametrize("agent_id", (None, "mock-entry")) async def test_custom_agent( hass: HomeAssistant, hass_client: ClientSessionGenerator, hass_admin_user: MockUser, mock_agent, - agent_id, ) -> None: """Test a custom conversation agent.""" + assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "conversation", {}) + assert await async_setup_component(hass, "intent", {}) client = await hass_client() @@ -1071,9 +1071,8 @@ async def test_custom_agent( "text": "Test Text", "conversation_id": "test-conv-id", "language": "test-language", + "agent_id": mock_agent.agent_id, } - if agent_id is not None: - data["agent_id"] = agent_id resp = await client.post("/api/conversation/process", json=data) assert resp.status == HTTPStatus.OK @@ -1599,8 +1598,7 @@ async def test_get_agent_info( """Test get agent info.""" agent_info = conversation.async_get_agent_info(hass) # Test it's the default - assert agent_info.id == mock_agent.agent_id - assert agent_info == snapshot + assert conversation.async_get_agent_info(hass, "homeassistant") == agent_info assert conversation.async_get_agent_info(hass, "homeassistant") == snapshot assert conversation.async_get_agent_info(hass, mock_agent.agent_id) == snapshot assert conversation.async_get_agent_info(hass, "not exist") is None diff --git a/tests/components/google_assistant_sdk/test_init.py b/tests/components/google_assistant_sdk/test_init.py index 4cfdd42bcdd..25066f73b6d 100644 --- a/tests/components/google_assistant_sdk/test_init.py +++ b/tests/components/google_assistant_sdk/test_init.py @@ -16,7 +16,7 @@ from homeassistant.util.dt import utcnow from .conftest import ComponentSetup, ExpectedCredentials -from tests.common import async_fire_time_changed, async_mock_service +from tests.common import MockConfigEntry, async_fire_time_changed, async_mock_service from tests.test_util.aiohttp import AiohttpClientMocker from tests.typing import ClientSessionGenerator @@ -322,6 +322,7 @@ async def test_send_text_command_media_player( async def test_conversation_agent( hass: HomeAssistant, setup_integration: ComponentSetup, + config_entry: MockConfigEntry, ) -> None: """Test GoogleAssistantConversationAgent.""" await setup_integration() @@ -348,13 +349,13 @@ async def test_conversation_agent( await hass.services.async_call( "conversation", "process", - {"text": text1}, + {"text": text1, "agent_id": config_entry.entry_id}, blocking=True, ) await hass.services.async_call( "conversation", "process", - {"text": text2}, + {"text": text2, "agent_id": config_entry.entry_id}, blocking=True, ) @@ -367,6 +368,7 @@ async def test_conversation_agent( async def test_conversation_agent_refresh_token( hass: HomeAssistant, + config_entry: MockConfigEntry, setup_integration: ComponentSetup, aioclient_mock: AiohttpClientMocker, ) -> None: @@ -392,7 +394,7 @@ async def test_conversation_agent_refresh_token( await hass.services.async_call( "conversation", "process", - {"text": text1}, + {"text": text1, "agent_id": config_entry.entry_id}, blocking=True, ) @@ -412,7 +414,7 @@ async def test_conversation_agent_refresh_token( await hass.services.async_call( "conversation", "process", - {"text": text2}, + {"text": text2, "agent_id": config_entry.entry_id}, blocking=True, ) diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index 7335903b43b..e8da4cf3920 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -13,6 +13,7 @@ from tests.common import MockConfigEntry async def test_default_prompt( hass: HomeAssistant, + mock_config_entry: MockConfigEntry, mock_init_component, area_registry: ar.AreaRegistry, device_registry: dr.DeviceRegistry, @@ -89,16 +90,22 @@ async def test_default_prompt( suggested_area="Test Area 2", ) with patch("google.generativeai.chat_async") as mock_chat: - result = await conversation.async_converse(hass, "hello", None, Context()) + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert mock_chat.mock_calls[0][2] == snapshot -async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None: +async def test_error_handling( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: """Test that the default prompt works.""" with patch("google.generativeai.chat_async", side_effect=ClientError("")): - result = await conversation.async_converse(hass, "hello", None, Context()) + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result @@ -119,7 +126,9 @@ async def test_template_error( ), patch("google.generativeai.chat_async"): await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() - result = await conversation.async_converse(hass, "hello", None, Context()) + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result diff --git a/tests/components/mobile_app/test_webhook.py b/tests/components/mobile_app/test_webhook.py index 02c9ace7cd4..ce1dc19319a 100644 --- a/tests/components/mobile_app/test_webhook.py +++ b/tests/components/mobile_app/test_webhook.py @@ -1026,15 +1026,19 @@ async def test_webhook_handle_conversation_process( """Test that we can converse.""" webhook_client.server.app.router._frozen = False - resp = await webhook_client.post( - "/api/webhook/{}".format(create_registrations[1]["webhook_id"]), - json={ - "type": "conversation_process", - "data": { - "text": "Turn the kitchen light off", + with patch( + "homeassistant.components.conversation.AgentManager.async_get_agent", + return_value=mock_agent, + ): + resp = await webhook_client.post( + "/api/webhook/{}".format(create_registrations[1]["webhook_id"]), + json={ + "type": "conversation_process", + "data": { + "text": "Turn the kitchen light off", + }, }, - }, - ) + ) assert resp.status == HTTPStatus.OK json = await resp.json() diff --git a/tests/components/openai_conversation/test_init.py b/tests/components/openai_conversation/test_init.py index 4016ac03c97..fe23bbac56c 100644 --- a/tests/components/openai_conversation/test_init.py +++ b/tests/components/openai_conversation/test_init.py @@ -13,6 +13,7 @@ from tests.common import MockConfigEntry async def test_default_prompt( hass: HomeAssistant, + mock_config_entry: MockConfigEntry, mock_init_component, area_registry: ar.AreaRegistry, device_registry: dr.DeviceRegistry, @@ -101,18 +102,24 @@ async def test_default_prompt( ] }, ) as mock_create: - result = await conversation.async_converse(hass, "hello", None, Context()) + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert mock_create.mock_calls[0][2]["messages"] == snapshot -async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None: +async def test_error_handling( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: """Test that the default prompt works.""" with patch( "openai.ChatCompletion.acreate", side_effect=error.ServiceUnavailableError ): - result = await conversation.async_converse(hass, "hello", None, Context()) + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result @@ -133,7 +140,9 @@ async def test_template_error( ), patch("openai.ChatCompletion.acreate"): await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() - result = await conversation.async_converse(hass, "hello", None, Context()) + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result