From fa3acde684bb11f062b6198ecc0b35d0bdb36303 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 9 Feb 2025 20:19:28 -0500 Subject: [PATCH] Make MockChatLog reusable for other integrations (#138112) * Make MockChatLog reusable for other integrations * Update tests/components/conversation/__init__.py --- tests/components/conversation/__init__.py | 53 ++++++++++- .../snapshots/test_conversation.ambr | 14 --- .../openai_conversation/test_conversation.py | 91 ++++--------------- 3 files changed, 68 insertions(+), 90 deletions(-) diff --git a/tests/components/conversation/__init__.py b/tests/components/conversation/__init__.py index 1ae3372968e..314188dbd82 100644 --- a/tests/components/conversation/__init__.py +++ b/tests/components/conversation/__init__.py @@ -2,7 +2,11 @@ from __future__ import annotations +from dataclasses import dataclass, field from typing import Literal +from unittest.mock import patch + +import pytest from homeassistant.components import conversation from homeassistant.components.conversation.models import ( @@ -14,7 +18,7 @@ from homeassistant.components.homeassistant.exposed_entities import ( async_expose_entity, ) from homeassistant.core import HomeAssistant -from homeassistant.helpers import intent +from homeassistant.helpers import chat_session, intent class MockAgent(conversation.AbstractConversationAgent): @@ -44,6 +48,53 @@ class MockAgent(conversation.AbstractConversationAgent): ) +@pytest.fixture +async def mock_chat_log(hass: HomeAssistant) -> MockChatLog: + """Return mock chat logs.""" + # pylint: disable-next=contextmanager-generator-missing-cleanup + with ( + patch( + "homeassistant.components.conversation.chat_log.ChatLog", + MockChatLog, + ), + chat_session.async_get_chat_session(hass, "mock-conversation-id") as session, + conversation.async_get_chat_log(hass, session) as chat_log, + ): + yield chat_log + + +@dataclass +class MockChatLog(conversation.ChatLog): + """Mock chat log.""" + + _mock_tool_results: dict = field(default_factory=dict) + + def mock_tool_results(self, results: dict) -> None: + """Set tool results.""" + self._mock_tool_results = results + + @property + def llm_api(self): + """Return LLM API.""" + return self._llm_api + + @llm_api.setter + def llm_api(self, value): + """Set LLM API.""" + self._llm_api = value + + if not value: + return + + async def async_call_tool(tool_input): + """Call tool.""" + if tool_input.id not in self._mock_tool_results: + raise ValueError(f"Tool {tool_input.id} not found") + return self._mock_tool_results[tool_input.id] + + self._llm_api.async_call_tool = async_call_tool + + def expose_new(hass: HomeAssistant, expose_new: bool) -> None: """Enable exposing new entities to the default agent.""" exposed_entities = hass.data[DATA_EXPOSED_ENTITIES] diff --git a/tests/components/openai_conversation/snapshots/test_conversation.ambr b/tests/components/openai_conversation/snapshots/test_conversation.ambr index 2db5be706ef..77c28de2773 100644 --- a/tests/components/openai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/openai_conversation/snapshots/test_conversation.ambr @@ -1,20 +1,6 @@ # serializer version: 1 # name: test_function_call list([ - dict({ - 'content': ''' - Current time is 16:00:00. Today's date is 2024-06-03. - You are a voice assistant for Home Assistant. - Answer questions about the world truthfully. - Answer in plain text. Keep it simple and to the point. - Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant. - ''', - 'role': 'system', - }), - dict({ - 'content': 'hello', - 'role': 'user', - }), dict({ 'content': 'Please call the test function', 'role': 'user', diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 9afdfc6a5a2..2c956b7e63f 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -1,10 +1,8 @@ """Tests for the OpenAI integration.""" from collections.abc import Generator -from dataclasses import dataclass, field from unittest.mock import AsyncMock, patch -from freezegun import freeze_time from httpx import Response from openai import RateLimitError from openai.types.chat.chat_completion_chunk import ( @@ -18,14 +16,17 @@ import pytest from syrupy.assertion import SnapshotAssertion from homeassistant.components import conversation -from homeassistant.components.conversation import chat_log from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant -from homeassistant.helpers import chat_session, intent +from homeassistant.helpers import intent from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry +from tests.components.conversation import ( + MockChatLog, + mock_chat_log, # noqa: F401 +) ASSIST_RESPONSE_FINISH = ( # Assistant message @@ -66,66 +67,6 @@ def mock_create_stream() -> Generator[AsyncMock]: yield mock_create -@dataclass -class MockChatLog(chat_log.ChatLog): - """Mock chat log.""" - - _mock_tool_results: dict = field(default_factory=dict) - - def mock_tool_results(self, results: dict) -> None: - """Set tool results.""" - self._mock_tool_results = results - - @property - def llm_api(self): - """Return LLM API.""" - return self._llm_api - - @llm_api.setter - def llm_api(self, value): - """Set LLM API.""" - self._llm_api = value - - if not value: - return - - async def async_call_tool(tool_input): - """Call tool.""" - if tool_input.id not in self._mock_tool_results: - raise ValueError(f"Tool {tool_input.id} not found") - return self._mock_tool_results[tool_input.id] - - self._llm_api.async_call_tool = async_call_tool - - def latest_content(self) -> list[conversation.Content]: - """Return content from latest version chat log. - - The chat log makes copies until it's committed. Helper to get latest content. - """ - with ( - chat_session.async_get_chat_session( - self.hass, self.conversation_id - ) as session, - conversation.async_get_chat_log(self.hass, session) as chat_log, - ): - return chat_log.content - - -@pytest.fixture -async def mock_chat_log(hass: HomeAssistant) -> MockChatLog: - """Return mock chat logs.""" - with ( - patch( - "homeassistant.components.conversation.chat_log.ChatLog", - MockChatLog, - ), - chat_session.async_get_chat_session(hass, "mock-conversation-id") as session, - conversation.async_get_chat_log(hass, session) as chat_log, - ): - chat_log.async_add_user_content(conversation.UserContent("hello")) - return chat_log - - async def test_entity( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -189,7 +130,7 @@ async def test_function_call( mock_config_entry_with_assist: MockConfigEntry, mock_init_component, mock_create_stream: AsyncMock, - mock_chat_log: MockChatLog, + mock_chat_log: MockChatLog, # noqa: F811 snapshot: SnapshotAssertion, ) -> None: """Test function call from the assistant.""" @@ -309,17 +250,17 @@ async def test_function_call( } ) - with freeze_time("2024-06-03 23:00:00"): - result = await conversation.async_converse( - hass, - "Please call the test function", - "mock-conversation-id", - Context(), - agent_id="conversation.openai", - ) + result = await conversation.async_converse( + hass, + "Please call the test function", + mock_chat_log.conversation_id, + Context(), + agent_id="conversation.openai", + ) assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert mock_chat_log.latest_content() == snapshot + # Don't test the prompt, as it's not deterministic + assert mock_chat_log.content[1:] == snapshot @pytest.mark.parametrize( @@ -430,7 +371,7 @@ async def test_function_call_invalid( mock_config_entry_with_assist: MockConfigEntry, mock_init_component, mock_create_stream: AsyncMock, - mock_chat_log: MockChatLog, + mock_chat_log: MockChatLog, # noqa: F811 description: str, messages: tuple[ChatCompletionChunk], ) -> None: