From 0c68854fdfeed9facd8caab2fb9002b5493cf881 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 19 Jan 2025 14:32:59 -0500 Subject: [PATCH] Migrate tests from OpenAI to conversation integration (#135963) --- .../components/conversation/session.py | 6 +- .../openai_conversation/conversation.py | 4 +- .../conversation/snapshots/test_session.ambr | 41 +++ tests/components/conversation/test_session.py | 244 +++++++++++++++++ .../openai_conversation/test_conversation.py | 249 +----------------- 5 files changed, 292 insertions(+), 252 deletions(-) create mode 100644 tests/components/conversation/snapshots/test_session.ambr diff --git a/homeassistant/components/conversation/session.py b/homeassistant/components/conversation/session.py index f9db80afa63..ba9e0ad6292 100644 --- a/homeassistant/components/conversation/session.py +++ b/homeassistant/components/conversation/session.py @@ -155,7 +155,7 @@ class ConverseError(HomeAssistantError): self.conversation_id = conversation_id self.response = response - def as_converstation_result(self) -> ConversationResult: + def as_conversation_result(self) -> ConversationResult: """Return the error as a conversation result.""" return ConversationResult( response=self.response, @@ -220,14 +220,14 @@ class ChatSession(Generic[_NativeT]): if message.role != "native" or message.agent_id == agent_id ] - async def async_process_llm_message( + async def async_update_llm_data( self, conversing_domain: str, user_input: ConversationInput, user_llm_hass_api: str | None = None, user_llm_prompt: str | None = None, ) -> None: - """Process an incoming message for an LLM.""" + """Set the LLM system prompt.""" llm_context = llm.LLMContext( platform=conversing_domain, context=user_input.context, diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 9a6b61e4c43..c89574bf3bd 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -157,14 +157,14 @@ class OpenAIConversationEntity( options = self.entry.options try: - await session.async_process_llm_message( + await session.async_update_llm_data( DOMAIN, user_input, options.get(CONF_LLM_HASS_API), options.get(CONF_PROMPT), ) except conversation.ConverseError as err: - return err.as_converstation_result() + return err.as_conversation_result() tools: list[ChatCompletionToolParam] | None = None if session.llm_api: diff --git a/tests/components/conversation/snapshots/test_session.ambr b/tests/components/conversation/snapshots/test_session.ambr new file mode 100644 index 00000000000..4e94157c601 --- /dev/null +++ b/tests/components/conversation/snapshots/test_session.ambr @@ -0,0 +1,41 @@ +# serializer version: 1 +# name: test_template_error + dict({ + 'conversation_id': , + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'unknown', + }), + 'language': 'en', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': 'Sorry, I had a problem with my template', + }), + }), + }), + }) +# --- +# name: test_unknown_llm_api + dict({ + 'conversation_id': , + 'response': dict({ + 'card': dict({ + }), + 'data': dict({ + 'code': 'unknown', + }), + 'language': 'en', + 'response_type': 'error', + 'speech': dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': 'Error preparing LLM API', + }), + }), + }), + }) +# --- diff --git a/tests/components/conversation/test_session.py b/tests/components/conversation/test_session.py index 45cb517528d..feb6ca2a9e8 100644 --- a/tests/components/conversation/test_session.py +++ b/tests/components/conversation/test_session.py @@ -5,9 +5,11 @@ from datetime import timedelta from unittest.mock import Mock, patch import pytest +from syrupy.assertion import SnapshotAssertion from homeassistant.components.conversation import ConversationInput, session from homeassistant.core import Context, HomeAssistant +from homeassistant.helpers import llm from homeassistant.util import dt as dt_util from tests.common import async_fire_time_changed @@ -94,10 +96,27 @@ async def test_cleanup( assert len(chat_session.messages) == 4 assert chat_session.conversation_id == conversation_id + # Set the last updated to be older than the timeout + hass.data[session.DATA_CHAT_HISTORY][conversation_id].last_updated = ( + dt_util.utcnow() + session.CONVERSATION_TIMEOUT + ) + async_fire_time_changed( hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT + timedelta(seconds=1) ) + # Should not be cleaned up, but it should have scheduled another cleanup + mock_conversation_input.conversation_id = conversation_id + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + assert len(chat_session.messages) == 4 + assert chat_session.conversation_id == conversation_id + + async_fire_time_changed( + hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1) + ) + # It should be cleaned up now and we start a new conversation async with session.async_get_chat_session( hass, mock_conversation_input @@ -106,6 +125,47 @@ async def test_cleanup( assert len(chat_session.messages) == 2 +def test_chat_message() -> None: + """Test chat message.""" + with pytest.raises(ValueError): + session.ChatMessage(role="native", agent_id=None, content="", native=None) + + +async def test_add_message( + hass: HomeAssistant, mock_conversation_input: ConversationInput +) -> None: + """Test filtering of messages.""" + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + assert len(chat_session.messages) == 2 + + with pytest.raises(ValueError): + chat_session.async_add_message( + session.ChatMessage(role="system", agent_id=None, content="") + ) + + # No 2 user messages in a row + assert chat_session.messages[1].role == "user" + + with pytest.raises(ValueError): + chat_session.async_add_message( + session.ChatMessage(role="user", agent_id=None, content="") + ) + + # No 2 assistant messages in a row + chat_session.async_add_message( + session.ChatMessage(role="assistant", agent_id=None, content="") + ) + assert len(chat_session.messages) == 3 + assert chat_session.messages[-1].role == "assistant" + + with pytest.raises(ValueError): + chat_session.async_add_message( + session.ChatMessage(role="assistant", agent_id=None, content="") + ) + + async def test_message_filtering( hass: HomeAssistant, mock_conversation_input: ConversationInput ) -> None: @@ -169,3 +229,187 @@ async def test_message_filtering( assert messages[3] == session.ChatMessage( role="native", agent_id="mock-agent-id", content="", native=1 ) + + +async def test_llm_api( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, +) -> None: + """Test when we reference an LLM API.""" + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + await chat_session.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api="assist", + user_llm_prompt=None, + ) + + assert isinstance(chat_session.llm_api, llm.APIInstance) + assert chat_session.llm_api.api.id == "assist" + + +async def test_unknown_llm_api( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, + snapshot: SnapshotAssertion, +) -> None: + """Test when we reference an LLM API that does not exists.""" + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + with pytest.raises(session.ConverseError) as exc_info: + await chat_session.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api="unknown-api", + user_llm_prompt=None, + ) + + assert str(exc_info.value) == "Error getting LLM API unknown-api" + assert exc_info.value.as_conversation_result().as_dict() == snapshot + + +async def test_template_error( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, + snapshot: SnapshotAssertion, +) -> None: + """Test that template error handling works.""" + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + with pytest.raises(session.ConverseError) as exc_info: + await chat_session.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api=None, + user_llm_prompt="{{ invalid_syntax", + ) + + assert str(exc_info.value) == "Error rendering prompt" + assert exc_info.value.as_conversation_result().as_dict() == snapshot + + +async def test_template_variables( + hass: HomeAssistant, mock_conversation_input: ConversationInput +) -> None: + """Test that template variables work.""" + mock_user = Mock() + mock_user.id = "12345" + mock_user.name = "Test User" + mock_conversation_input.context = Context(user_id=mock_user.id) + + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + with patch( + "homeassistant.auth.AuthManager.async_get_user", return_value=mock_user + ): + await chat_session.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api=None, + user_llm_prompt=( + "The instance name is {{ ha_name }}. " + "The user name is {{ user_name }}. " + "The user id is {{ llm_context.context.user_id }}." + "The calling platform is {{ llm_context.platform }}." + ), + ) + + assert chat_session.user_name == "Test User" + + assert "The instance name is test home." in chat_session.messages[0].content + assert "The user name is Test User." in chat_session.messages[0].content + assert "The user id is 12345." in chat_session.messages[0].content + assert "The calling platform is test." in chat_session.messages[0].content + + +async def test_extra_systen_prompt( + hass: HomeAssistant, mock_conversation_input: ConversationInput +) -> None: + """Test that extra system prompt works.""" + extra_system_prompt = "Garage door cover.garage_door has been left open for 30 minutes. We asked the user if they want to close it." + extra_system_prompt2 = ( + "User person.paulus came home. Asked him what he wants to do." + ) + mock_conversation_input.extra_system_prompt = extra_system_prompt + + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + await chat_session.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api=None, + user_llm_prompt=None, + ) + chat_session.async_add_message( + session.ChatMessage( + role="assistant", + agent_id="mock-agent-id", + content="Hey!", + ) + ) + + assert chat_session.extra_system_prompt == extra_system_prompt + assert chat_session.messages[0].content.endswith(extra_system_prompt) + + # Verify that follow-up conversations with no system prompt take previous one + mock_conversation_input.conversation_id = chat_session.conversation_id + mock_conversation_input.extra_system_prompt = None + + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + await chat_session.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api=None, + user_llm_prompt=None, + ) + + assert chat_session.extra_system_prompt == extra_system_prompt + assert chat_session.messages[0].content.endswith(extra_system_prompt) + + # Verify that we take new system prompts + mock_conversation_input.extra_system_prompt = extra_system_prompt2 + + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + await chat_session.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api=None, + user_llm_prompt=None, + ) + chat_session.async_add_message( + session.ChatMessage( + role="assistant", + agent_id="mock-agent-id", + content="Hey!", + ) + ) + + assert chat_session.extra_system_prompt == extra_system_prompt2 + assert chat_session.messages[0].content.endswith(extra_system_prompt2) + assert extra_system_prompt not in chat_session.messages[0].content + + # Verify that follow-up conversations with no system prompt take previous one + mock_conversation_input.extra_system_prompt = None + + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + await chat_session.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api=None, + user_llm_prompt=None, + ) + + assert chat_session.extra_system_prompt == extra_system_prompt2 + assert chat_session.messages[0].content.endswith(extra_system_prompt2) diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index b89ddcd8921..9ee19cd330c 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -1,6 +1,6 @@ """Tests for the OpenAI integration.""" -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, patch from freezegun import freeze_time from httpx import Response @@ -12,7 +12,6 @@ from openai.types.chat.chat_completion_message_tool_call import ( Function, ) from openai.types.completion_usage import CompletionUsage -from syrupy.assertion import SnapshotAssertion import voluptuous as vol from homeassistant.components import conversation @@ -22,7 +21,6 @@ from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import intent, llm from homeassistant.setup import async_setup_component -from homeassistant.util import ulid from tests.common import MockConfigEntry @@ -57,7 +55,7 @@ async def test_entity( async def test_error_handling( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component ) -> None: - """Test that the default prompt works.""" + """Test that we handle errors when calling completion API.""" with patch( "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, @@ -73,183 +71,6 @@ async def test_error_handling( assert result.response.error_code == "unknown", result -async def test_template_error( - hass: HomeAssistant, mock_config_entry: MockConfigEntry -) -> None: - """Test that template error handling works.""" - hass.config_entries.async_update_entry( - mock_config_entry, - options={ - "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", - }, - ) - with ( - patch( - "openai.resources.models.AsyncModels.list", - ), - patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - ), - ): - await hass.config_entries.async_setup(mock_config_entry.entry_id) - await hass.async_block_till_done() - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ERROR, result - assert result.response.error_code == "unknown", result - - -async def test_template_variables( - hass: HomeAssistant, mock_config_entry: MockConfigEntry -) -> None: - """Test that template variables work.""" - context = Context(user_id="12345") - mock_user = Mock() - mock_user.id = "12345" - mock_user.name = "Test User" - - hass.config_entries.async_update_entry( - mock_config_entry, - options={ - "prompt": ( - "The user name is {{ user_name }}. " - "The user id is {{ llm_context.context.user_id }}." - ), - }, - ) - with ( - patch( - "openai.resources.models.AsyncModels.list", - ), - patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - ) as mock_create, - patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user), - ): - await hass.config_entries.async_setup(mock_config_entry.entry_id) - await hass.async_block_till_done() - result = await conversation.async_converse( - hass, "hello", None, context, agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, ( - result - ) - assert ( - "The user name is Test User." - in mock_create.mock_calls[0][2]["messages"][0]["content"] - ) - assert ( - "The user id is 12345." - in mock_create.mock_calls[0][2]["messages"][0]["content"] - ) - - -async def test_extra_systen_prompt( - hass: HomeAssistant, mock_config_entry: MockConfigEntry -) -> None: - """Test that template variables work.""" - extra_system_prompt = "Garage door cover.garage_door has been left open for 30 minutes. We asked the user if they want to close it." - extra_system_prompt2 = ( - "User person.paulus came home. Asked him what he wants to do." - ) - - with ( - patch( - "openai.resources.models.AsyncModels.list", - ), - patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - ) as mock_create, - ): - await hass.config_entries.async_setup(mock_config_entry.entry_id) - await hass.async_block_till_done() - result = await conversation.async_converse( - hass, - "hello", - None, - Context(), - agent_id=mock_config_entry.entry_id, - extra_system_prompt=extra_system_prompt, - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, ( - result - ) - assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith( - extra_system_prompt - ) - - conversation_id = result.conversation_id - - # Verify that follow-up conversations with no system prompt take previous one - with patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - ) as mock_create: - result = await conversation.async_converse( - hass, - "hello", - conversation_id, - Context(), - agent_id=mock_config_entry.entry_id, - extra_system_prompt=None, - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, ( - result - ) - assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith( - extra_system_prompt - ) - - # Verify that we take new system prompts - with patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - ) as mock_create: - result = await conversation.async_converse( - hass, - "hello", - conversation_id, - Context(), - agent_id=mock_config_entry.entry_id, - extra_system_prompt=extra_system_prompt2, - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, ( - result - ) - assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith( - extra_system_prompt2 - ) - - # Verify that follow-up conversations with no system prompt take previous one - with patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - ) as mock_create: - result = await conversation.async_converse( - hass, - "hello", - conversation_id, - Context(), - agent_id=mock_config_entry.entry_id, - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, ( - result - ) - assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith( - extra_system_prompt2 - ) - - async def test_conversation_agent( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -605,69 +426,3 @@ async def test_assist_api_tools_conversion( tools = mock_create.mock_calls[0][2]["tools"] assert tools - - -async def test_unknown_hass_api( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - snapshot: SnapshotAssertion, - mock_init_component, -) -> None: - """Test when we reference an API that no longer exists.""" - hass.config_entries.async_update_entry( - mock_config_entry, - options={ - **mock_config_entry.options, - CONF_LLM_HASS_API: "non-existing", - }, - ) - - await hass.async_block_till_done() - - result = await conversation.async_converse( - hass, - "hello", - "my-conversation-id", - Context(), - agent_id=mock_config_entry.entry_id, - ) - - assert result == snapshot - - -@patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, -) -async def test_conversation_id( - mock_create, - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, -) -> None: - """Test conversation ID is honored.""" - result = await conversation.async_converse( - hass, "hello", None, None, agent_id=mock_config_entry.entry_id - ) - - conversation_id = result.conversation_id - - result = await conversation.async_converse( - hass, "hello", conversation_id, None, agent_id=mock_config_entry.entry_id - ) - - assert result.conversation_id == conversation_id - - unknown_id = ulid.ulid() - - result = await conversation.async_converse( - hass, "hello", unknown_id, None, agent_id=mock_config_entry.entry_id - ) - - assert result.conversation_id != unknown_id - - result = await conversation.async_converse( - hass, "hello", "koala", None, agent_id=mock_config_entry.entry_id - ) - - assert result.conversation_id == "koala"