Chat session rev2 (#137209)

* Chat Session rev 2

* Rename session to chat_log

* Simplify typing

* Typing

* Address comments

* Fix anthropic and ollama
Paulus Schoutsen 2025-02-03 00:05:20 -05:00 committed by GitHub
parent ce93cb9467
commit 9679fc7878
14 changed files with 388 additions and 330 deletions

View File

@@ -272,6 +272,7 @@ class AnthropicConversationEntity(
continue
tool_input = llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.name,
tool_args=cast(dict[str, Any], tool_call.input),
)

View File

@@ -1063,11 +1063,11 @@ class PipelineRun:
agent_id=self.intent_agent,
extra_system_prompt=conversation_extra_system_prompt,
)
processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT
agent_id = user_input.agent_id
agent_id = self.intent_agent
processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
intent_response: intent.IntentResponse | None = None
if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT:
if not processed_locally:
# Sentence triggers override conversation agent
if (
trigger_response_text
@@ -1105,13 +1105,13 @@ class PipelineRun:
speech: str = intent_response.speech.get("plain", {}).get(
"speech", ""
)
chat_log.async_add_message(
conversation.Content(
role="assistant",
async for _ in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=agent_id,
content=speech,
)
)
):
pass
conversation_result = conversation.ConversationResult(
response=intent_response,
conversation_id=session.conversation_id,
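
A note on the new pattern above: async_add_assistant_content replaces the synchronous async_add_message with an async generator that yields one ToolResultContent per tool call. Async generators only run when iterated, so even a reply with no tool calls must be drained before the content lands in the log, hence the `async for _ in ...: pass`. A minimal sketch of the pattern (the helper name add_plain_reply is illustrative):

from homeassistant.components import conversation


async def add_plain_reply(
    chat_log: conversation.ChatLog, agent_id: str, speech: str
) -> None:
    """Append an assistant reply that carries no tool calls."""
    # With tool_calls unset the generator yields nothing, but iterating it
    # is still required so the content is appended to chat_log.content.
    async for _tool_result in chat_log.async_add_assistant_content(
        conversation.AssistantContent(agent_id=agent_id, content=speech)
    ):
        pass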

View File

@@ -30,6 +30,16 @@ from .agent_manager import (
async_get_agent,
get_agent_manager,
)
from .chat_log import (
AssistantContent,
ChatLog,
Content,
ConverseError,
SystemContent,
ToolResultContent,
UserContent,
async_get_chat_log,
)
from .const import (
ATTR_AGENT_ID,
ATTR_CONVERSATION_ID,
@@ -48,13 +58,13 @@ from .default_agent import DefaultAgent, async_setup_default_agent
from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log
from .trace import ConversationTraceEventType, async_conversation_trace_append
__all__ = [
"DOMAIN",
"HOME_ASSISTANT_AGENT",
"OLD_HOME_ASSISTANT_AGENT",
"AssistantContent",
"ChatLog",
"Content",
"ConversationEntity",
@@ -63,7 +73,9 @@ __all__ = [
"ConversationResult",
"ConversationTraceEventType",
"ConverseError",
"NativeContent",
"SystemContent",
"ToolResultContent",
"UserContent",
"async_conversation_trace_append",
"async_converse",
"async_get_agent_info",

View File

@@ -2,19 +2,16 @@
from __future__ import annotations
from collections.abc import Generator
from collections.abc import AsyncGenerator, Generator
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from datetime import datetime
import logging
from typing import Literal
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import chat_session, intent, llm, template
from homeassistant.util import dt as dt_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
@@ -31,7 +28,7 @@ LOGGER = logging.getLogger(__name__)
def async_get_chat_log(
hass: HomeAssistant,
session: chat_session.ChatSession,
user_input: ConversationInput,
user_input: ConversationInput | None = None,
) -> Generator[ChatLog]:
"""Return chat log for a specific chat session."""
all_history = hass.data.get(DATA_CHAT_HISTORY)
@@ -42,9 +39,9 @@ def async_get_chat_log(
history = all_history.get(session.conversation_id)
if history:
history = replace(history, messages=history.messages.copy())
history = replace(history, content=history.content.copy())
else:
history = ChatLog(hass, session.conversation_id, user_input.agent_id)
history = ChatLog(hass, session.conversation_id)
@callback
def do_cleanup() -> None:
@@ -53,22 +50,19 @@ def async_get_chat_log(
session.async_on_cleanup(do_cleanup)
message: Content = Content(
role="user",
agent_id=user_input.agent_id,
content=user_input.text,
)
history.async_add_message(message)
if user_input is not None:
history.async_add_user_content(UserContent(content=user_input.text))
last_message = history.content[-1]
yield history
if history.messages[-1] is message:
if history.content[-1] is last_message:
LOGGER.debug(
"History opened but no assistant message was added, ignoring update"
)
return
history.last_updated = dt_util.utcnow()
all_history[session.conversation_id] = history
@@ -94,63 +88,94 @@ class ConverseError(HomeAssistantError):
)
@dataclass
class Content:
@dataclass(frozen=True)
class SystemContent:
"""Base class for chat messages."""
role: Literal["system", "assistant", "user"]
agent_id: str | None
role: str = field(init=False, default="system")
content: str
@dataclass(frozen=True)
class NativeContent[_NativeT]:
"""Native content."""
class UserContent:
"""Assistant content."""
role: str = field(init=False, default="native")
role: str = field(init=False, default="user")
content: str
@dataclass(frozen=True)
class AssistantContent:
"""Assistant content."""
role: str = field(init=False, default="assistant")
agent_id: str
content: _NativeT
content: str
tool_calls: list[llm.ToolInput] | None = None
@dataclass(frozen=True)
class ToolResultContent:
"""Tool result content."""
role: str = field(init=False, default="tool_result")
agent_id: str
tool_call_id: str
tool_name: str
tool_result: JsonObjectType
Content = SystemContent | UserContent | AssistantContent | ToolResultContent
@dataclass
class ChatLog[_NativeT]:
class ChatLog:
"""Class holding the chat history of a specific conversation."""
hass: HomeAssistant
conversation_id: str
agent_id: str | None
user_name: str | None = None
messages: list[Content | NativeContent[_NativeT]] = field(
default_factory=lambda: [Content(role="system", agent_id=None, content="")]
)
content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
extra_system_prompt: str | None = None
llm_api: llm.APIInstance | None = None
last_updated: datetime = field(default_factory=dt_util.utcnow)
@callback
def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None:
"""Process intent."""
if message.role == "system":
raise ValueError("Cannot add system messages to history")
if message.role != "native" and self.messages[-1].role == message.role:
raise ValueError("Cannot add two assistant or user messages in a row")
def async_add_user_content(self, content: UserContent) -> None:
"""Add user content to the log."""
self.content.append(content)
self.messages.append(message)
async def async_add_assistant_content(
self, content: AssistantContent
) -> AsyncGenerator[ToolResultContent]:
"""Add assistant content."""
self.content.append(content)
@callback
def async_get_messages(
self, agent_id: str | None = None
) -> list[Content | NativeContent[_NativeT]]:
"""Get messages for a specific agent ID.
if content.tool_calls is None:
return
This will filter out any native message tied to other agent IDs.
It can still include assistant/user messages generated by other agents.
"""
return [
message
for message in self.messages
if message.role != "native" or message.agent_id == agent_id
]
if self.llm_api is None:
raise ValueError("No LLM API configured")
for tool_input in content.tool_calls:
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
try:
tool_result = await self.llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_result = {"error": type(e).__name__}
if str(e):
tool_result["error_text"] = str(e)
LOGGER.debug("Tool response: %s", tool_result)
response_content = ToolResultContent(
agent_id=content.agent_id,
tool_call_id=tool_input.id,
tool_name=tool_input.tool_name,
tool_result=tool_result,
)
self.content.append(response_content)
yield response_content
async def async_update_llm_data(
self,
@@ -250,36 +275,16 @@ class ChatLog[_NativeT]:
prompt = "\n".join(prompt_parts)
self.llm_api = llm_api
self.user_name = user_name
self.extra_system_prompt = extra_system_prompt
self.messages[0] = Content(
role="system",
agent_id=user_input.agent_id,
content=prompt,
)
self.content[0] = SystemContent(content=prompt)
LOGGER.debug("Prompt: %s", self.messages)
LOGGER.debug("Prompt: %s", self.content)
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{
"messages": self.messages,
"messages": self.content,
"tools": self.llm_api.tools if self.llm_api else None,
},
)
async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType:
"""Invoke LLM tool for the configured LLM API."""
if not self.llm_api:
raise ValueError("No LLM API configured")
LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
try:
tool_response = await self.llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", tool_response)
return tool_response
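
Tool execution now lives inside the chat log: async_call_tool is removed, and async_add_assistant_content runs every ToolInput in tool_calls against the configured llm_api, appending and yielding a ToolResultContent per call (tool failures come back as {"error": ..., "error_text": ...} results rather than raising). A caller-side sketch, assuming async_update_llm_data has already configured an LLM API (otherwise ValueError is raised); the tool name and args are illustrative:

from homeassistant.components import conversation
from homeassistant.helpers import llm


async def call_one_tool(
    chat_log: conversation.ChatLog, agent_id: str
) -> list[conversation.ToolResultContent]:
    """Run one illustrative tool call through the chat log."""
    content = conversation.AssistantContent(
        agent_id=agent_id,
        content="",
        # id defaults to a fresh ULID when omitted (see helpers/llm.py below).
        tool_calls=[llm.ToolInput(tool_name="test_tool", tool_args={"param1": "x"})],
    )
    results: list[conversation.ToolResultContent] = []
    async for tool_result in chat_log.async_add_assistant_content(content):
        # Each yielded result has already been appended to chat_log.content.
        results.append(tool_result)
    return results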

View File

@@ -55,6 +55,7 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_state_added_domain
from homeassistant.util.json import JsonObjectType, json_loads_object
from .chat_log import AssistantContent, async_get_chat_log
from .const import (
DATA_DEFAULT_ENTITY,
DEFAULT_EXPOSED_ATTRIBUTES,
@@ -63,7 +64,6 @@ from .const import (
)
from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult
from .session import Content, async_get_chat_log
from .trace import ConversationTraceEventType, async_conversation_trace_append
_LOGGER = logging.getLogger(__name__)
@@ -379,13 +379,13 @@ class DefaultAgent(ConversationEntity):
)
speech: str = response.speech.get("plain", {}).get("speech", "")
chat_log.async_add_message(
Content(
role="assistant",
agent_id=user_input.agent_id,
async for _tool_result in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=user_input.agent_id, # type: ignore[arg-type]
content=speech,
)
)
):
pass
return ConversationResult(
response=response, conversation_id=session.conversation_id

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import codecs
from collections.abc import Callable
from typing import Any, Literal
from typing import Any, Literal, cast
from google.api_core.exceptions import GoogleAPIError
import google.generativeai as genai
@@ -149,15 +149,53 @@ def _escape_decode(value: Any) -> Any:
return value
def _chat_message_convert(
message: conversation.Content | conversation.NativeContent[genai_types.ContentDict],
) -> genai_types.ContentDict:
"""Convert any native chat message for this agent to the native format."""
if message.role == "native":
return message.content
def _create_google_tool_response_content(
content: list[conversation.ToolResultContent],
) -> protos.Content:
"""Create a Google tool response content."""
return protos.Content(
parts=[
protos.Part(
function_response=protos.FunctionResponse(
name=tool_result.tool_name, response=tool_result.tool_result
)
)
for tool_result in content
]
)
role = "model" if message.role == "assistant" else message.role
return {"role": role, "parts": message.content}
def _convert_content(
content: conversation.UserContent
| conversation.AssistantContent
| conversation.SystemContent,
) -> genai_types.ContentDict:
"""Convert HA content to Google content."""
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
role = "model" if content.role == "assistant" else content.role
return {"role": role, "parts": content.content}
# Handle the Assistant content with tool calls.
assert type(content) is conversation.AssistantContent
parts = []
if content.content:
parts.append(protos.Part(text=content.content))
if content.tool_calls:
parts.extend(
[
protos.Part(
function_call=protos.FunctionCall(
name=tool_call.tool_name,
args=_escape_decode(tool_call.tool_args),
)
)
for tool_call in content.tool_calls
]
)
return protos.Content({"role": "model", "parts": parts})
class GoogleGenerativeAIConversationEntity(
@@ -220,7 +258,7 @@ class GoogleGenerativeAIConversationEntity(
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
session: conversation.ChatLog[genai_types.ContentDict],
chat_log: conversation.ChatLog,
) -> conversation.ConversationResult:
"""Call the API."""
@@ -228,7 +266,7 @@ class GoogleGenerativeAIConversationEntity(
options = self.entry.options
try:
await session.async_update_llm_data(
await chat_log.async_update_llm_data(
DOMAIN,
user_input,
options.get(CONF_LLM_HASS_API),
@@ -238,10 +276,10 @@ class GoogleGenerativeAIConversationEntity(
return err.as_conversation_result()
tools: list[dict[str, Any]] | None = None
if session.llm_api:
if chat_log.llm_api:
tools = [
_format_tool(tool, session.llm_api.custom_serializer)
for tool in session.llm_api.tools
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
@@ -252,9 +290,36 @@
"gemini-1.0" not in model_name and "gemini-pro" not in model_name
)
prompt, *messages = [
_chat_message_convert(message) for message in session.async_get_messages()
]
prompt = chat_log.content[0].content # type: ignore[union-attr]
messages: list[genai_types.ContentDict] = []
# Google groups tool results, we do not. Group them before sending.
tool_results: list[conversation.ToolResultContent] = []
for chat_content in chat_log.content[1:]:
if chat_content.role == "tool_result":
# mypy doesn't like picking a type based on checking shared property 'role'
tool_results.append(cast(conversation.ToolResultContent, chat_content))
continue
if tool_results:
messages.append(_create_google_tool_response_content(tool_results))
tool_results.clear()
messages.append(
_convert_content(
cast(
conversation.UserContent
| conversation.SystemContent
| conversation.AssistantContent,
chat_content,
)
)
)
if tool_results:
messages.append(_create_google_tool_response_content(tool_results))
model = genai.GenerativeModel(
model_name=model_name,
generation_config={
@@ -282,12 +347,12 @@
),
},
tools=tools or None,
system_instruction=prompt["parts"] if supports_system_instruction else None,
system_instruction=prompt if supports_system_instruction else None,
)
if not supports_system_instruction:
messages = [
{"role": "user", "parts": prompt["parts"]},
{"role": "user", "parts": prompt},
{"role": "model", "parts": "Ok"},
*messages,
]
@@ -325,50 +390,40 @@
content = " ".join(
[part.text.strip() for part in chat_response.parts if part.text]
)
if content:
session.async_add_message(
conversation.Content(
role="assistant",
agent_id=user_input.agent_id,
content=content,
)
)
function_calls = [
part.function_call for part in chat_response.parts if part.function_call
]
if not function_calls or not session.llm_api:
break
tool_responses = []
for function_call in function_calls:
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
tool_calls = []
for part in chat_response.parts:
if not part.function_call:
continue
tool_call = MessageToDict(part.function_call._pb) # noqa: SLF001
tool_name = tool_call["name"]
tool_args = _escape_decode(tool_call["args"])
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
function_response = await session.async_call_tool(tool_input)
tool_responses.append(
protos.Part(
function_response=protos.FunctionResponse(
name=tool_name, response=function_response
tool_calls.append(
llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
)
chat_request = _create_google_tool_response_content(
[
tool_response
async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=content,
tool_calls=tool_calls or None,
)
)
)
chat_request = protos.Content(parts=tool_responses)
session.async_add_message(
conversation.NativeContent(
agent_id=user_input.agent_id,
content=chat_request,
)
]
)
if not tool_calls:
break
response = intent.IntentResponse(language=user_input.language)
response.async_set_speech(
" ".join([part.text.strip() for part in chat_response.parts if part.text])
)
return conversation.ConversationResult(
response=response, conversation_id=session.conversation_id
response=response, conversation_id=chat_log.conversation_id
)
async def _async_entry_update_listener(
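
Unlike Home Assistant's flat content list, Google's API expects consecutive tool results to be merged into a single Content message, which is what the buffering loop above does before and after the main conversion. A self-contained illustration of the same grouping rule on plain dicts (the "tool_response" role and dict shapes are hypothetical stand-ins for the protos types):

from typing import Any


def group_tool_results(content: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Collapse each run of consecutive tool_result items into one message."""
    messages: list[dict[str, Any]] = []
    buffer: list[dict[str, Any]] = []
    for item in content:
        if item["role"] == "tool_result":
            buffer.append(item)
            continue
        if buffer:  # any other role flushes the pending group first
            messages.append({"role": "tool_response", "parts": list(buffer)})
            buffer.clear()
        messages.append(item)
    if buffer:  # trailing tool results still need a final flush
        messages.append({"role": "tool_response", "parts": buffer})
    return messages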

View File

@@ -70,7 +70,9 @@ def _format_tool(
return ChatCompletionToolParam(type="function", function=tool_spec)
def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
def _convert_message_to_param(
message: ChatCompletionMessage,
) -> ChatCompletionMessageParam:
"""Convert from class to TypedDict."""
tool_calls: list[ChatCompletionMessageToolCallParam] = []
if message.tool_calls:
@@ -94,20 +96,42 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
return param
def _chat_message_convert(
message: conversation.Content
| conversation.NativeContent[ChatCompletionMessageParam],
def _convert_content_to_param(
content: conversation.Content,
) -> ChatCompletionMessageParam:
"""Convert any native chat message for this agent to the native format."""
role = message.role
if role == "native":
# mypy doesn't understand that checking role ensures content type
return message.content # type: ignore[return-value]
if role == "system":
role = "developer"
return cast(
ChatCompletionMessageParam,
{"role": role, "content": message.content},
if content.role == "tool_result":
assert type(content) is conversation.ToolResultContent
return ChatCompletionToolMessageParam(
role="tool",
tool_call_id=content.tool_call_id,
content=json.dumps(content.tool_result),
)
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
role = content.role
if role == "system":
role = "developer"
return cast(
ChatCompletionMessageParam,
{"role": content.role, "content": content.content}, # type: ignore[union-attr]
)
# Handle the Assistant content including tool calls.
assert type(content) is conversation.AssistantContent
return ChatCompletionAssistantMessageParam(
role="assistant",
content=content.content,
tool_calls=[
ChatCompletionMessageToolCallParam(
id=tool_call.id,
function=Function(
arguments=json.dumps(tool_call.tool_args),
name=tool_call.tool_name,
),
type="function",
)
for tool_call in content.tool_calls
],
)
@@ -171,14 +195,14 @@ class OpenAIConversationEntity(
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
session: conversation.ChatLog[ChatCompletionMessageParam],
chat_log: conversation.ChatLog,
) -> conversation.ConversationResult:
"""Call the API."""
assert user_input.agent_id
options = self.entry.options
try:
await session.async_update_llm_data(
await chat_log.async_update_llm_data(
DOMAIN,
user_input,
options.get(CONF_LLM_HASS_API),
@@ -188,17 +212,14 @@
return err.as_conversation_result()
tools: list[ChatCompletionToolParam] | None = None
if session.llm_api:
if chat_log.llm_api:
tools = [
_format_tool(tool, session.llm_api.custom_serializer)
for tool in session.llm_api.tools
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
messages = [
_chat_message_convert(message) for message in session.async_get_messages()
]
messages = [_convert_content_to_param(content) for content in chat_log.content]
client = self.entry.runtime_data
@@ -213,7 +234,7 @@
),
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"user": session.conversation_id,
"user": chat_log.conversation_id,
}
if model.startswith("o"):
@@ -229,43 +250,39 @@
LOGGER.debug("Response %s", result)
response = result.choices[0].message
messages.append(_message_convert(response))
messages.append(_convert_message_to_param(response))
session.async_add_message(
conversation.Content(
role=response.role,
agent_id=user_input.agent_id,
content=response.content or "",
),
tool_calls: list[llm.ToolInput] | None = None
if response.tool_calls:
tool_calls = [
llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
)
for tool_call in response.tool_calls
]
messages.extend(
[
_convert_content_to_param(tool_response)
async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=response.content or "",
tool_calls=tool_calls,
)
)
]
)
if not response.tool_calls or not session.llm_api:
if not tool_calls:
break
for tool_call in response.tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
)
tool_response = await session.async_call_tool(tool_input)
messages.append(
ChatCompletionToolMessageParam(
role="tool",
tool_call_id=tool_call.id,
content=json.dumps(tool_response),
)
)
session.async_add_message(
conversation.NativeContent(
agent_id=user_input.agent_id,
content=messages[-1],
)
)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content or "")
return conversation.ConversationResult(
response=intent_response, conversation_id=session.conversation_id
response=intent_response, conversation_id=chat_log.conversation_id
)
async def _async_entry_update_listener(
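
On the OpenAI side, a ToolResultContent maps onto the "tool" role keyed by tool_call_id, with the JSON-serialized tool result as content. A hedged round-trip check (the agent id and call id are made up, and _convert_content_to_param is assumed importable from this integration's conversation module):

import json

from homeassistant.components import conversation
from homeassistant.components.openai_conversation.conversation import (
    _convert_content_to_param,
)

tool_message = _convert_content_to_param(
    conversation.ToolResultContent(
        agent_id="conversation.openai",  # illustrative
        tool_call_id="call_123",  # illustrative; echoed back to OpenAI
        tool_name="test_tool",
        tool_result={"ok": True},
    )
)
# ChatCompletionToolMessageParam is a TypedDict, so it compares as a plain dict.
assert tool_message == {
    "role": "tool",
    "tool_call_id": "call_123",
    "content": json.dumps({"ok": True}),
}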

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from dataclasses import dataclass, field as dc_field
from datetime import timedelta
from decimal import Decimal
from enum import Enum
@@ -36,6 +36,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util, yaml as yaml_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
from homeassistant.util.ulid import ulid_now
from . import (
area_registry as ar,
@@ -139,6 +140,8 @@ class ToolInput:
tool_name: str
tool_args: dict[str, Any]
# Using lambda for default to allow patching in tests
id: str = dc_field(default_factory=lambda: ulid_now()) # pylint: disable=unnecessary-lambda
class Tool:
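
ToolInput now carries an id, defaulted through a lambda factory precisely so tests can patch ulid_now for deterministic ids, as the new fixtures in this commit do. A small sketch:

from unittest.mock import patch

from homeassistant.helpers import llm

# By default each ToolInput gets a freshly generated ULID as its id.
assert llm.ToolInput(tool_name="test_tool", tool_args={}).id

# Because the default factory calls ulid_now() at construction time,
# patching ulid_now (as the tests below do) makes the id deterministic.
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
    assert llm.ToolInput(tool_name="test_tool", tool_args={}).id == "mock-tool-call"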

View File

@@ -236,6 +236,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="toolu_0123456789AbCdEfGhIjKlM",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),
@@ -373,6 +374,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="toolu_0123456789AbCdEfGhIjKlM",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),

View File

@@ -9,13 +9,13 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components.conversation import (
Content,
AssistantContent,
ConversationInput,
ConverseError,
NativeContent,
ToolResultContent,
async_get_chat_log,
)
from homeassistant.components.conversation.session import DATA_CHAT_HISTORY
from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, llm
@@ -40,7 +40,7 @@ def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
@pytest.fixture
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now
@@ -56,13 +56,13 @@ async def test_cleanup(
):
conversation_id = session.conversation_id
# Add message so it persists
chat_log.async_add_message(
Content(
role="assistant",
agent_id=mock_conversation_input.agent_id,
content="",
async for _tool_result in chat_log.async_add_assistant_content(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
)
):
pytest.fail("should not reach here")
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
@@ -79,7 +79,7 @@ async def test_cleanup(
assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
async def test_add_message(
async def test_default_content(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
@@ -87,95 +87,11 @@ async def test_add_message(
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
assert len(chat_log.messages) == 2
with pytest.raises(ValueError):
chat_log.async_add_message(
Content(role="system", agent_id=None, content="")
)
# No 2 user messages in a row
assert chat_log.messages[1].role == "user"
with pytest.raises(ValueError):
chat_log.async_add_message(Content(role="user", agent_id=None, content=""))
# No 2 assistant messages in a row
chat_log.async_add_message(Content(role="assistant", agent_id=None, content=""))
assert len(chat_log.messages) == 3
assert chat_log.messages[-1].role == "assistant"
with pytest.raises(ValueError):
chat_log.async_add_message(
Content(role="assistant", agent_id=None, content="")
)
async def test_message_filtering(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
messages = chat_log.async_get_messages(agent_id=None)
assert len(messages) == 2
assert messages[0] == Content(
role="system",
agent_id=None,
content="",
)
assert messages[1] == Content(
role="user",
agent_id="mock-agent-id",
content=mock_conversation_input.text,
)
# Cannot add a second user message in a row
with pytest.raises(ValueError):
chat_log.async_add_message(
Content(
role="user",
agent_id="mock-agent-id",
content="Hey!",
)
)
chat_log.async_add_message(
Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
)
# Different agent, native messages will be filtered out.
chat_log.async_add_message(
NativeContent(agent_id="another-mock-agent-id", content=1)
)
chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1))
# A non-native message from another agent is not filtered out.
chat_log.async_add_message(
Content(
role="assistant",
agent_id="another-mock-agent-id",
content="Hi!",
)
)
assert len(chat_log.messages) == 6
messages = chat_log.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 5
assert messages[2] == Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1)
assert messages[4] == Content(
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
)
assert len(chat_log.content) == 2
assert chat_log.content[0].role == "system"
assert chat_log.content[0].content == ""
assert chat_log.content[1].role == "user"
assert chat_log.content[1].content == mock_conversation_input.text
async def test_llm_api(
@@ -268,12 +184,10 @@ async def test_template_variables(
),
)
assert chat_log.user_name == "Test User"
assert "The instance name is test home." in chat_log.messages[0].content
assert "The user name is Test User." in chat_log.messages[0].content
assert "The user id is 12345." in chat_log.messages[0].content
assert "The calling platform is test." in chat_log.messages[0].content
assert "The instance name is test home." in chat_log.content[0].content
assert "The user name is Test User." in chat_log.content[0].content
assert "The user id is 12345." in chat_log.content[0].content
assert "The calling platform is test." in chat_log.content[0].content
async def test_extra_systen_prompt(
@@ -296,16 +210,16 @@
user_llm_hass_api=None,
user_llm_prompt=None,
)
chat_log.async_add_message(
Content(
role="assistant",
async for _tool_result in chat_log.async_add_assistant_content(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
)
):
pytest.fail("should not reach here")
assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt)
assert chat_log.content[0].content.endswith(extra_system_prompt)
# Verify that follow-up conversations with no system prompt take previous one
conversation_id = chat_log.conversation_id
@@ -323,7 +237,7 @@
)
assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt)
assert chat_log.content[0].content.endswith(extra_system_prompt)
# Verify that we take new system prompts
mock_conversation_input.extra_system_prompt = extra_system_prompt2
@@ -338,17 +252,17 @@
user_llm_hass_api=None,
user_llm_prompt=None,
)
chat_log.async_add_message(
Content(
role="assistant",
async for _tool_result in chat_log.async_add_assistant_content(
AssistantContent(
agent_id="mock-agent-id",
content="Hey!",
)
)
):
pytest.fail("should not reach here")
assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
assert extra_system_prompt not in chat_log.messages[0].content
assert chat_log.content[0].content.endswith(extra_system_prompt2)
assert extra_system_prompt not in chat_log.content[0].content
# Verify that follow-up conversations with no system prompt take previous one
mock_conversation_input.extra_system_prompt = None
@@ -365,7 +279,7 @@
)
assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
assert chat_log.content[0].content.endswith(extra_system_prompt2)
async def test_tool_call(
@@ -383,8 +297,7 @@
mock_tool.async_call.return_value = "Test response"
with patch(
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
return_value=[],
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
) as mock_get_tools:
mock_get_tools.return_value = [mock_tool]
@@ -398,14 +311,29 @@
user_llm_hass_api="assist",
user_llm_prompt=None,
)
result = await chat_log.async_call_tool(
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "Test Param"},
result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
)
):
assert result is None
result = tool_result_content
assert result == "Test response"
assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result="Test response",
tool_name="test_tool",
)
async def test_tool_call_exception(
@@ -423,8 +351,7 @@
mock_tool.async_call.side_effect = HomeAssistantError("Test error")
with patch(
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
return_value=[],
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
) as mock_get_tools:
mock_get_tools.return_value = [mock_tool]
@@ -438,11 +365,26 @@
user_llm_hass_api="assist",
user_llm_prompt=None,
)
result = await chat_log.async_call_tool(
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "Test Param"},
result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
],
)
)
):
assert result is None
result = tool_result_content
assert result == {"error": "HomeAssistantError", "error_text": "Test error"}
assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result={"error": "HomeAssistantError", "error_text": "Test error"},
tool_name="test_tool",
)

View File

@@ -36,6 +36,13 @@ def freeze_the_time():
yield
@pytest.fixture(autouse=True)
def mock_ulid_tools():
"""Mock generated ULIDs for tool calls."""
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
yield
@pytest.mark.parametrize(
"agent_id", [None, "conversation.google_generative_ai_conversation"]
)
@@ -177,6 +184,7 @@
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
@pytest.mark.usefixtures("mock_init_component")
@pytest.mark.usefixtures("mock_ulid_tools")
async def test_function_call(
mock_get_tools,
hass: HomeAssistant,
@@ -256,6 +264,7 @@
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={
"param1": ["test_value", "param1's value"],
@@ -287,9 +296,7 @@
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
assert [
p.function_response.name
for p in detail_event["data"]["messages"][2]["content"].parts
if p.function_response
p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
] == ["test_tool"]
@@ -362,6 +369,7 @@ async def test_function_call_without_parameters(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={},
),
@@ -451,6 +459,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={"param1": 1},
),
@@ -605,6 +614,7 @@ async def test_template_variables(
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.text = "Model response"
mock_part.function_call = None
chat_response.parts = [mock_part]
result = await conversation.async_converse(
hass, "hello", None, context, agent_id=mock_config_entry.entry_id

View File

@@ -18,6 +18,13 @@ from homeassistant.helpers import intent, llm
from tests.common import MockConfigEntry
@pytest.fixture(autouse=True)
def mock_ulid_tools():
"""Mock generated ULIDs for tool calls."""
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
yield
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
async def test_chat(
hass: HomeAssistant,
@@ -205,6 +212,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args=expected_tool_args,
),
@@ -285,6 +293,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),

View File

@@ -195,6 +195,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),
@@ -359,6 +360,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
tool_name="test_tool",
tool_args={"param1": "test_value"},
),