mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 09:47:52 +00:00
Make MockChatLog reusable for other integrations (#138112)
* Make MockChatLog reusable for other integrations * Update tests/components/conversation/__init__.py
This commit is contained in:
parent
cabb406270
commit
fa3acde684
@ -2,7 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation.models import (
|
||||
@ -14,7 +18,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
|
||||
async_expose_entity,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.helpers import chat_session, intent
|
||||
|
||||
|
||||
class MockAgent(conversation.AbstractConversationAgent):
|
||||
@ -44,6 +48,53 @@ class MockAgent(conversation.AbstractConversationAgent):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_chat_log(hass: HomeAssistant) -> MockChatLog:
|
||||
"""Return mock chat logs."""
|
||||
# pylint: disable-next=contextmanager-generator-missing-cleanup
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.conversation.chat_log.ChatLog",
|
||||
MockChatLog,
|
||||
),
|
||||
chat_session.async_get_chat_session(hass, "mock-conversation-id") as session,
|
||||
conversation.async_get_chat_log(hass, session) as chat_log,
|
||||
):
|
||||
yield chat_log
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockChatLog(conversation.ChatLog):
|
||||
"""Mock chat log."""
|
||||
|
||||
_mock_tool_results: dict = field(default_factory=dict)
|
||||
|
||||
def mock_tool_results(self, results: dict) -> None:
|
||||
"""Set tool results."""
|
||||
self._mock_tool_results = results
|
||||
|
||||
@property
|
||||
def llm_api(self):
|
||||
"""Return LLM API."""
|
||||
return self._llm_api
|
||||
|
||||
@llm_api.setter
|
||||
def llm_api(self, value):
|
||||
"""Set LLM API."""
|
||||
self._llm_api = value
|
||||
|
||||
if not value:
|
||||
return
|
||||
|
||||
async def async_call_tool(tool_input):
|
||||
"""Call tool."""
|
||||
if tool_input.id not in self._mock_tool_results:
|
||||
raise ValueError(f"Tool {tool_input.id} not found")
|
||||
return self._mock_tool_results[tool_input.id]
|
||||
|
||||
self._llm_api.async_call_tool = async_call_tool
|
||||
|
||||
|
||||
def expose_new(hass: HomeAssistant, expose_new: bool) -> None:
|
||||
"""Enable exposing new entities to the default agent."""
|
||||
exposed_entities = hass.data[DATA_EXPOSED_ENTITIES]
|
||||
|
@ -1,20 +1,6 @@
|
||||
# serializer version: 1
|
||||
# name: test_function_call
|
||||
list([
|
||||
dict({
|
||||
'content': '''
|
||||
Current time is 16:00:00. Today's date is 2024-06-03.
|
||||
You are a voice assistant for Home Assistant.
|
||||
Answer questions about the world truthfully.
|
||||
Answer in plain text. Keep it simple and to the point.
|
||||
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
|
||||
''',
|
||||
'role': 'system',
|
||||
}),
|
||||
dict({
|
||||
'content': 'hello',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
'content': 'Please call the test function',
|
||||
'role': 'user',
|
||||
|
@ -1,10 +1,8 @@
|
||||
"""Tests for the OpenAI integration."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
from httpx import Response
|
||||
from openai import RateLimitError
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
@ -18,14 +16,17 @@ import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation import chat_log
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import chat_session, intent
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.components.conversation import (
|
||||
MockChatLog,
|
||||
mock_chat_log, # noqa: F401
|
||||
)
|
||||
|
||||
ASSIST_RESPONSE_FINISH = (
|
||||
# Assistant message
|
||||
@ -66,66 +67,6 @@ def mock_create_stream() -> Generator[AsyncMock]:
|
||||
yield mock_create
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockChatLog(chat_log.ChatLog):
|
||||
"""Mock chat log."""
|
||||
|
||||
_mock_tool_results: dict = field(default_factory=dict)
|
||||
|
||||
def mock_tool_results(self, results: dict) -> None:
|
||||
"""Set tool results."""
|
||||
self._mock_tool_results = results
|
||||
|
||||
@property
|
||||
def llm_api(self):
|
||||
"""Return LLM API."""
|
||||
return self._llm_api
|
||||
|
||||
@llm_api.setter
|
||||
def llm_api(self, value):
|
||||
"""Set LLM API."""
|
||||
self._llm_api = value
|
||||
|
||||
if not value:
|
||||
return
|
||||
|
||||
async def async_call_tool(tool_input):
|
||||
"""Call tool."""
|
||||
if tool_input.id not in self._mock_tool_results:
|
||||
raise ValueError(f"Tool {tool_input.id} not found")
|
||||
return self._mock_tool_results[tool_input.id]
|
||||
|
||||
self._llm_api.async_call_tool = async_call_tool
|
||||
|
||||
def latest_content(self) -> list[conversation.Content]:
|
||||
"""Return content from latest version chat log.
|
||||
|
||||
The chat log makes copies until it's committed. Helper to get latest content.
|
||||
"""
|
||||
with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, self.conversation_id
|
||||
) as session,
|
||||
conversation.async_get_chat_log(self.hass, session) as chat_log,
|
||||
):
|
||||
return chat_log.content
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_chat_log(hass: HomeAssistant) -> MockChatLog:
|
||||
"""Return mock chat logs."""
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.conversation.chat_log.ChatLog",
|
||||
MockChatLog,
|
||||
),
|
||||
chat_session.async_get_chat_session(hass, "mock-conversation-id") as session,
|
||||
conversation.async_get_chat_log(hass, session) as chat_log,
|
||||
):
|
||||
chat_log.async_add_user_content(conversation.UserContent("hello"))
|
||||
return chat_log
|
||||
|
||||
|
||||
async def test_entity(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
@ -189,7 +130,7 @@ async def test_function_call(
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
mock_init_component,
|
||||
mock_create_stream: AsyncMock,
|
||||
mock_chat_log: MockChatLog,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test function call from the assistant."""
|
||||
@ -309,17 +250,17 @@ async def test_function_call(
|
||||
}
|
||||
)
|
||||
|
||||
with freeze_time("2024-06-03 23:00:00"):
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
"mock-conversation-id",
|
||||
Context(),
|
||||
agent_id="conversation.openai",
|
||||
)
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
mock_chat_log.conversation_id,
|
||||
Context(),
|
||||
agent_id="conversation.openai",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert mock_chat_log.latest_content() == snapshot
|
||||
# Don't test the prompt, as it's not deterministic
|
||||
assert mock_chat_log.content[1:] == snapshot
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -430,7 +371,7 @@ async def test_function_call_invalid(
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
mock_init_component,
|
||||
mock_create_stream: AsyncMock,
|
||||
mock_chat_log: MockChatLog,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
description: str,
|
||||
messages: tuple[ChatCompletionChunk],
|
||||
) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user