mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 01:37:08 +00:00
Make MockChatLog reusable for other integrations (#138112)
* Make MockChatLog reusable for other integrations * Update tests/components/conversation/__init__.py
This commit is contained in:
parent
cabb406270
commit
fa3acde684
@ -2,7 +2,11 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.conversation.models import (
|
from homeassistant.components.conversation.models import (
|
||||||
@ -14,7 +18,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
|
|||||||
async_expose_entity,
|
async_expose_entity,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import chat_session, intent
|
||||||
|
|
||||||
|
|
||||||
class MockAgent(conversation.AbstractConversationAgent):
|
class MockAgent(conversation.AbstractConversationAgent):
|
||||||
@ -44,6 +48,53 @@ class MockAgent(conversation.AbstractConversationAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_chat_log(hass: HomeAssistant) -> MockChatLog:
|
||||||
|
"""Return mock chat logs."""
|
||||||
|
# pylint: disable-next=contextmanager-generator-missing-cleanup
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.conversation.chat_log.ChatLog",
|
||||||
|
MockChatLog,
|
||||||
|
),
|
||||||
|
chat_session.async_get_chat_session(hass, "mock-conversation-id") as session,
|
||||||
|
conversation.async_get_chat_log(hass, session) as chat_log,
|
||||||
|
):
|
||||||
|
yield chat_log
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockChatLog(conversation.ChatLog):
|
||||||
|
"""Mock chat log."""
|
||||||
|
|
||||||
|
_mock_tool_results: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
def mock_tool_results(self, results: dict) -> None:
|
||||||
|
"""Set tool results."""
|
||||||
|
self._mock_tool_results = results
|
||||||
|
|
||||||
|
@property
|
||||||
|
def llm_api(self):
|
||||||
|
"""Return LLM API."""
|
||||||
|
return self._llm_api
|
||||||
|
|
||||||
|
@llm_api.setter
|
||||||
|
def llm_api(self, value):
|
||||||
|
"""Set LLM API."""
|
||||||
|
self._llm_api = value
|
||||||
|
|
||||||
|
if not value:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def async_call_tool(tool_input):
|
||||||
|
"""Call tool."""
|
||||||
|
if tool_input.id not in self._mock_tool_results:
|
||||||
|
raise ValueError(f"Tool {tool_input.id} not found")
|
||||||
|
return self._mock_tool_results[tool_input.id]
|
||||||
|
|
||||||
|
self._llm_api.async_call_tool = async_call_tool
|
||||||
|
|
||||||
|
|
||||||
def expose_new(hass: HomeAssistant, expose_new: bool) -> None:
|
def expose_new(hass: HomeAssistant, expose_new: bool) -> None:
|
||||||
"""Enable exposing new entities to the default agent."""
|
"""Enable exposing new entities to the default agent."""
|
||||||
exposed_entities = hass.data[DATA_EXPOSED_ENTITIES]
|
exposed_entities = hass.data[DATA_EXPOSED_ENTITIES]
|
||||||
|
@ -1,20 +1,6 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_function_call
|
# name: test_function_call
|
||||||
list([
|
list([
|
||||||
dict({
|
|
||||||
'content': '''
|
|
||||||
Current time is 16:00:00. Today's date is 2024-06-03.
|
|
||||||
You are a voice assistant for Home Assistant.
|
|
||||||
Answer questions about the world truthfully.
|
|
||||||
Answer in plain text. Keep it simple and to the point.
|
|
||||||
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
|
|
||||||
''',
|
|
||||||
'role': 'system',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'content': 'hello',
|
|
||||||
'role': 'user',
|
|
||||||
}),
|
|
||||||
dict({
|
dict({
|
||||||
'content': 'Please call the test function',
|
'content': 'Please call the test function',
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
"""Tests for the OpenAI integration."""
|
"""Tests for the OpenAI integration."""
|
||||||
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
from freezegun import freeze_time
|
|
||||||
from httpx import Response
|
from httpx import Response
|
||||||
from openai import RateLimitError
|
from openai import RateLimitError
|
||||||
from openai.types.chat.chat_completion_chunk import (
|
from openai.types.chat.chat_completion_chunk import (
|
||||||
@ -18,14 +16,17 @@ import pytest
|
|||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.conversation import chat_log
|
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.helpers import chat_session, intent
|
from homeassistant.helpers import intent
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
from tests.components.conversation import (
|
||||||
|
MockChatLog,
|
||||||
|
mock_chat_log, # noqa: F401
|
||||||
|
)
|
||||||
|
|
||||||
ASSIST_RESPONSE_FINISH = (
|
ASSIST_RESPONSE_FINISH = (
|
||||||
# Assistant message
|
# Assistant message
|
||||||
@ -66,66 +67,6 @@ def mock_create_stream() -> Generator[AsyncMock]:
|
|||||||
yield mock_create
|
yield mock_create
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MockChatLog(chat_log.ChatLog):
|
|
||||||
"""Mock chat log."""
|
|
||||||
|
|
||||||
_mock_tool_results: dict = field(default_factory=dict)
|
|
||||||
|
|
||||||
def mock_tool_results(self, results: dict) -> None:
|
|
||||||
"""Set tool results."""
|
|
||||||
self._mock_tool_results = results
|
|
||||||
|
|
||||||
@property
|
|
||||||
def llm_api(self):
|
|
||||||
"""Return LLM API."""
|
|
||||||
return self._llm_api
|
|
||||||
|
|
||||||
@llm_api.setter
|
|
||||||
def llm_api(self, value):
|
|
||||||
"""Set LLM API."""
|
|
||||||
self._llm_api = value
|
|
||||||
|
|
||||||
if not value:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def async_call_tool(tool_input):
|
|
||||||
"""Call tool."""
|
|
||||||
if tool_input.id not in self._mock_tool_results:
|
|
||||||
raise ValueError(f"Tool {tool_input.id} not found")
|
|
||||||
return self._mock_tool_results[tool_input.id]
|
|
||||||
|
|
||||||
self._llm_api.async_call_tool = async_call_tool
|
|
||||||
|
|
||||||
def latest_content(self) -> list[conversation.Content]:
|
|
||||||
"""Return content from latest version chat log.
|
|
||||||
|
|
||||||
The chat log makes copies until it's committed. Helper to get latest content.
|
|
||||||
"""
|
|
||||||
with (
|
|
||||||
chat_session.async_get_chat_session(
|
|
||||||
self.hass, self.conversation_id
|
|
||||||
) as session,
|
|
||||||
conversation.async_get_chat_log(self.hass, session) as chat_log,
|
|
||||||
):
|
|
||||||
return chat_log.content
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def mock_chat_log(hass: HomeAssistant) -> MockChatLog:
|
|
||||||
"""Return mock chat logs."""
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.conversation.chat_log.ChatLog",
|
|
||||||
MockChatLog,
|
|
||||||
),
|
|
||||||
chat_session.async_get_chat_session(hass, "mock-conversation-id") as session,
|
|
||||||
conversation.async_get_chat_log(hass, session) as chat_log,
|
|
||||||
):
|
|
||||||
chat_log.async_add_user_content(conversation.UserContent("hello"))
|
|
||||||
return chat_log
|
|
||||||
|
|
||||||
|
|
||||||
async def test_entity(
|
async def test_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
@ -189,7 +130,7 @@ async def test_function_call(
|
|||||||
mock_config_entry_with_assist: MockConfigEntry,
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
mock_create_stream: AsyncMock,
|
mock_create_stream: AsyncMock,
|
||||||
mock_chat_log: MockChatLog,
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test function call from the assistant."""
|
"""Test function call from the assistant."""
|
||||||
@ -309,17 +250,17 @@ async def test_function_call(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
with freeze_time("2024-06-03 23:00:00"):
|
result = await conversation.async_converse(
|
||||||
result = await conversation.async_converse(
|
hass,
|
||||||
hass,
|
"Please call the test function",
|
||||||
"Please call the test function",
|
mock_chat_log.conversation_id,
|
||||||
"mock-conversation-id",
|
Context(),
|
||||||
Context(),
|
agent_id="conversation.openai",
|
||||||
agent_id="conversation.openai",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
assert mock_chat_log.latest_content() == snapshot
|
# Don't test the prompt, as it's not deterministic
|
||||||
|
assert mock_chat_log.content[1:] == snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -430,7 +371,7 @@ async def test_function_call_invalid(
|
|||||||
mock_config_entry_with_assist: MockConfigEntry,
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
mock_create_stream: AsyncMock,
|
mock_create_stream: AsyncMock,
|
||||||
mock_chat_log: MockChatLog,
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
description: str,
|
description: str,
|
||||||
messages: tuple[ChatCompletionChunk],
|
messages: tuple[ChatCompletionChunk],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user