Make MockChatLog reusable for other integrations (#138112)

* Make MockChatLog reusable for other integrations

* Update tests/components/conversation/__init__.py
Paulus Schoutsen 2025-02-09 20:19:28 -05:00 committed by GitHub
parent cabb406270
commit fa3acde684
3 changed files with 68 additions and 90 deletions

View File

@@ -2,7 +2,11 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal
from unittest.mock import patch
import pytest
from homeassistant.components import conversation
from homeassistant.components.conversation.models import (
@@ -14,7 +18,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
async_expose_entity,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
from homeassistant.helpers import chat_session, intent
class MockAgent(conversation.AbstractConversationAgent):
@@ -44,6 +48,53 @@ class MockAgent(conversation.AbstractConversationAgent):
)
@pytest.fixture
async def mock_chat_log(hass: HomeAssistant) -> MockChatLog:
"""Return mock chat logs."""
# pylint: disable-next=contextmanager-generator-missing-cleanup
with (
patch(
"homeassistant.components.conversation.chat_log.ChatLog",
MockChatLog,
),
chat_session.async_get_chat_session(hass, "mock-conversation-id") as session,
conversation.async_get_chat_log(hass, session) as chat_log,
):
yield chat_log
@dataclass
class MockChatLog(conversation.ChatLog):
"""Mock chat log."""
_mock_tool_results: dict = field(default_factory=dict)
def mock_tool_results(self, results: dict) -> None:
"""Set tool results."""
self._mock_tool_results = results
@property
def llm_api(self):
"""Return LLM API."""
return self._llm_api
@llm_api.setter
def llm_api(self, value):
"""Set LLM API."""
self._llm_api = value
if not value:
return
async def async_call_tool(tool_input):
"""Call tool."""
if tool_input.id not in self._mock_tool_results:
raise ValueError(f"Tool {tool_input.id} not found")
return self._mock_tool_results[tool_input.id]
self._llm_api.async_call_tool = async_call_tool
def expose_new(hass: HomeAssistant, expose_new: bool) -> None:
"""Enable exposing new entities to the default agent."""
exposed_entities = hass.data[DATA_EXPOSED_ENTITIES]
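
With MockChatLog and the mock_chat_log fixture now exported from the shared conversation test helpers, another integration's tests can reuse them by importing from tests.components.conversation, exactly as the OpenAI tests below do. A minimal sketch of the imports such a test module would need (the fixture import looks unused, but importing it is what registers the fixture with pytest for that module):

# Hypothetical test module of another conversation integration.
from tests.components.conversation import (
    MockChatLog,  # used for type hints on test parameters
    mock_chat_log,  # noqa: F401  # importing it registers the fixture here
)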

View File

@@ -1,20 +1,6 @@
# serializer version: 1
# name: test_function_call
list([
dict({
'content': '''
Current time is 16:00:00. Today's date is 2024-06-03.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
''',
'role': 'system',
}),
dict({
'content': 'hello',
'role': 'user',
}),
dict({
'content': 'Please call the test function',
'role': 'user',

View File

@@ -1,10 +1,8 @@
"""Tests for the OpenAI integration."""
from collections.abc import Generator
from dataclasses import dataclass, field
from unittest.mock import AsyncMock, patch
from freezegun import freeze_time
from httpx import Response
from openai import RateLimitError
from openai.types.chat.chat_completion_chunk import (
@@ -18,14 +16,17 @@ import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.components.conversation import chat_log
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import chat_session, intent
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
from tests.components.conversation import (
MockChatLog,
mock_chat_log, # noqa: F401
)
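
The mock_chat_log import above carries noqa: F401 because it looks unused: pytest only picks up fixtures defined in, or imported into, a test module or a conftest.py. A hypothetical alternative (not part of this commit) would be a single re-export from the integration's conftest.py, which would make the noqa: F401 / noqa: F811 markers in the test module unnecessary:

# tests/components/openai_conversation/conftest.py -- hypothetical alternative,
# not part of this change: re-exporting here registers the fixture for the
# whole test package.
from tests.components.conversation import mock_chat_log  # noqa: F401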
ASSIST_RESPONSE_FINISH = (
# Assistant message
@@ -66,66 +67,6 @@ def mock_create_stream() -> Generator[AsyncMock]:
yield mock_create
@dataclass
class MockChatLog(chat_log.ChatLog):
"""Mock chat log."""
_mock_tool_results: dict = field(default_factory=dict)
def mock_tool_results(self, results: dict) -> None:
"""Set tool results."""
self._mock_tool_results = results
@property
def llm_api(self):
"""Return LLM API."""
return self._llm_api
@llm_api.setter
def llm_api(self, value):
"""Set LLM API."""
self._llm_api = value
if not value:
return
async def async_call_tool(tool_input):
"""Call tool."""
if tool_input.id not in self._mock_tool_results:
raise ValueError(f"Tool {tool_input.id} not found")
return self._mock_tool_results[tool_input.id]
self._llm_api.async_call_tool = async_call_tool
def latest_content(self) -> list[conversation.Content]:
"""Return content from latest version chat log.
The chat log makes copies until it's committed. Helper to get latest content.
"""
with (
chat_session.async_get_chat_session(
self.hass, self.conversation_id
) as session,
conversation.async_get_chat_log(self.hass, session) as chat_log,
):
return chat_log.content
@pytest.fixture
async def mock_chat_log(hass: HomeAssistant) -> MockChatLog:
"""Return mock chat logs."""
with (
patch(
"homeassistant.components.conversation.chat_log.ChatLog",
MockChatLog,
),
chat_session.async_get_chat_session(hass, "mock-conversation-id") as session,
conversation.async_get_chat_log(hass, session) as chat_log,
):
chat_log.async_add_user_content(conversation.UserContent("hello"))
return chat_log
async def test_entity(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
@@ -189,7 +130,7 @@ async def test_function_call(
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
mock_create_stream: AsyncMock,
mock_chat_log: MockChatLog,
mock_chat_log: MockChatLog, # noqa: F811
snapshot: SnapshotAssertion,
) -> None:
"""Test function call from the assistant."""
@@ -309,17 +250,17 @@ async def test_function_call(
}
)
with freeze_time("2024-06-03 23:00:00"):
result = await conversation.async_converse(
hass,
"Please call the test function",
"mock-conversation-id",
Context(),
agent_id="conversation.openai",
)
result = await conversation.async_converse(
hass,
"Please call the test function",
mock_chat_log.conversation_id,
Context(),
agent_id="conversation.openai",
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert mock_chat_log.latest_content() == snapshot
# Don't test the prompt, as it's not deterministic
assert mock_chat_log.content[1:] == snapshot
@pytest.mark.parametrize(
@@ -430,7 +371,7 @@ async def test_function_call_invalid(
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
mock_create_stream: AsyncMock,
mock_chat_log: MockChatLog,
mock_chat_log: MockChatLog, # noqa: F811
description: str,
messages: tuple[ChatCompletionChunk],
) -> None:
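
Putting it together, a downstream test can mirror the updated test_function_call above. The sketch below assumes the same imports as this test module; the agent id, tool-call id, and result value are placeholders:

async def test_my_agent_function_call(
    hass: HomeAssistant,
    mock_chat_log: MockChatLog,  # noqa: F811
    snapshot: SnapshotAssertion,
) -> None:
    """Sketch of a downstream tool-call test using the shared fixture."""
    # Result is keyed by the tool call id the agent will report.
    mock_chat_log.mock_tool_results({"call_1": "Test response"})

    result = await conversation.async_converse(
        hass,
        "Please call the test function",
        # Reuse the fixture's conversation id so the agent attaches to the
        # patched MockChatLog instance.
        mock_chat_log.conversation_id,
        Context(),
        agent_id="conversation.my_agent",  # placeholder agent id
    )

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    # Skip the first entry: the system prompt is not deterministic.
    assert mock_chat_log.content[1:] == snapshot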