Make MockChatLog reusable for other integrations (#138112)

* Make MockChatLog reusable for other integrations

* Update tests/components/conversation/__init__.py
This commit is contained in:
Paulus Schoutsen 2025-02-09 20:19:28 -05:00 committed by GitHub
parent cabb406270
commit fa3acde684
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 68 additions and 90 deletions

View File

@ -2,7 +2,11 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal from typing import Literal
from unittest.mock import patch
import pytest
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.conversation.models import ( from homeassistant.components.conversation.models import (
@ -14,7 +18,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
async_expose_entity, async_expose_entity,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent from homeassistant.helpers import chat_session, intent
class MockAgent(conversation.AbstractConversationAgent): class MockAgent(conversation.AbstractConversationAgent):
@ -44,6 +48,53 @@ class MockAgent(conversation.AbstractConversationAgent):
) )
@pytest.fixture
async def mock_chat_log(hass: HomeAssistant) -> MockChatLog:
"""Return mock chat logs."""
# pylint: disable-next=contextmanager-generator-missing-cleanup
with (
patch(
"homeassistant.components.conversation.chat_log.ChatLog",
MockChatLog,
),
chat_session.async_get_chat_session(hass, "mock-conversation-id") as session,
conversation.async_get_chat_log(hass, session) as chat_log,
):
yield chat_log
@dataclass
class MockChatLog(conversation.ChatLog):
"""Mock chat log."""
_mock_tool_results: dict = field(default_factory=dict)
def mock_tool_results(self, results: dict) -> None:
"""Set tool results."""
self._mock_tool_results = results
@property
def llm_api(self):
"""Return LLM API."""
return self._llm_api
@llm_api.setter
def llm_api(self, value):
"""Set LLM API."""
self._llm_api = value
if not value:
return
async def async_call_tool(tool_input):
"""Call tool."""
if tool_input.id not in self._mock_tool_results:
raise ValueError(f"Tool {tool_input.id} not found")
return self._mock_tool_results[tool_input.id]
self._llm_api.async_call_tool = async_call_tool
def expose_new(hass: HomeAssistant, expose_new: bool) -> None: def expose_new(hass: HomeAssistant, expose_new: bool) -> None:
"""Enable exposing new entities to the default agent.""" """Enable exposing new entities to the default agent."""
exposed_entities = hass.data[DATA_EXPOSED_ENTITIES] exposed_entities = hass.data[DATA_EXPOSED_ENTITIES]

View File

@ -1,20 +1,6 @@
# serializer version: 1 # serializer version: 1
# name: test_function_call # name: test_function_call
list([ list([
dict({
'content': '''
Current time is 16:00:00. Today's date is 2024-06-03.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
''',
'role': 'system',
}),
dict({
'content': 'hello',
'role': 'user',
}),
dict({ dict({
'content': 'Please call the test function', 'content': 'Please call the test function',
'role': 'user', 'role': 'user',

View File

@ -1,10 +1,8 @@
"""Tests for the OpenAI integration.""" """Tests for the OpenAI integration."""
from collections.abc import Generator from collections.abc import Generator
from dataclasses import dataclass, field
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from freezegun import freeze_time
from httpx import Response from httpx import Response
from openai import RateLimitError from openai import RateLimitError
from openai.types.chat.chat_completion_chunk import ( from openai.types.chat.chat_completion_chunk import (
@ -18,14 +16,17 @@ import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.conversation import chat_log
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import chat_session, intent from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.components.conversation import (
MockChatLog,
mock_chat_log, # noqa: F401
)
ASSIST_RESPONSE_FINISH = ( ASSIST_RESPONSE_FINISH = (
# Assistant message # Assistant message
@ -66,66 +67,6 @@ def mock_create_stream() -> Generator[AsyncMock]:
yield mock_create yield mock_create
@dataclass
class MockChatLog(chat_log.ChatLog):
"""Mock chat log."""
_mock_tool_results: dict = field(default_factory=dict)
def mock_tool_results(self, results: dict) -> None:
"""Set tool results."""
self._mock_tool_results = results
@property
def llm_api(self):
"""Return LLM API."""
return self._llm_api
@llm_api.setter
def llm_api(self, value):
"""Set LLM API."""
self._llm_api = value
if not value:
return
async def async_call_tool(tool_input):
"""Call tool."""
if tool_input.id not in self._mock_tool_results:
raise ValueError(f"Tool {tool_input.id} not found")
return self._mock_tool_results[tool_input.id]
self._llm_api.async_call_tool = async_call_tool
def latest_content(self) -> list[conversation.Content]:
"""Return content from latest version chat log.
The chat log makes copies until it's committed. Helper to get latest content.
"""
with (
chat_session.async_get_chat_session(
self.hass, self.conversation_id
) as session,
conversation.async_get_chat_log(self.hass, session) as chat_log,
):
return chat_log.content
@pytest.fixture
async def mock_chat_log(hass: HomeAssistant) -> MockChatLog:
"""Return mock chat logs."""
with (
patch(
"homeassistant.components.conversation.chat_log.ChatLog",
MockChatLog,
),
chat_session.async_get_chat_session(hass, "mock-conversation-id") as session,
conversation.async_get_chat_log(hass, session) as chat_log,
):
chat_log.async_add_user_content(conversation.UserContent("hello"))
return chat_log
async def test_entity( async def test_entity(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
@ -189,7 +130,7 @@ async def test_function_call(
mock_config_entry_with_assist: MockConfigEntry, mock_config_entry_with_assist: MockConfigEntry,
mock_init_component, mock_init_component,
mock_create_stream: AsyncMock, mock_create_stream: AsyncMock,
mock_chat_log: MockChatLog, mock_chat_log: MockChatLog, # noqa: F811
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test function call from the assistant.""" """Test function call from the assistant."""
@ -309,17 +250,17 @@ async def test_function_call(
} }
) )
with freeze_time("2024-06-03 23:00:00"): result = await conversation.async_converse(
result = await conversation.async_converse( hass,
hass, "Please call the test function",
"Please call the test function", mock_chat_log.conversation_id,
"mock-conversation-id", Context(),
Context(), agent_id="conversation.openai",
agent_id="conversation.openai", )
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert mock_chat_log.latest_content() == snapshot # Don't test the prompt, as it's not deterministic
assert mock_chat_log.content[1:] == snapshot
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -430,7 +371,7 @@ async def test_function_call_invalid(
mock_config_entry_with_assist: MockConfigEntry, mock_config_entry_with_assist: MockConfigEntry,
mock_init_component, mock_init_component,
mock_create_stream: AsyncMock, mock_create_stream: AsyncMock,
mock_chat_log: MockChatLog, mock_chat_log: MockChatLog, # noqa: F811
description: str, description: str,
messages: tuple[ChatCompletionChunk], messages: tuple[ChatCompletionChunk],
) -> None: ) -> None: