Chat session rev2 (#137209)

* Chat Session rev 2

* Rename session to chat_log

* Simplify typing

* Typing

* Address comments

* Fix anthropic and ollama
This commit is contained in:
Paulus Schoutsen 2025-02-03 00:05:20 -05:00 committed by GitHub
parent ce93cb9467
commit 9679fc7878
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 388 additions and 330 deletions

View File

@ -272,6 +272,7 @@ class AnthropicConversationEntity(
continue continue
tool_input = llm.ToolInput( tool_input = llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.name, tool_name=tool_call.name,
tool_args=cast(dict[str, Any], tool_call.input), tool_args=cast(dict[str, Any], tool_call.input),
) )

View File

@ -1063,11 +1063,11 @@ class PipelineRun:
agent_id=self.intent_agent, agent_id=self.intent_agent,
extra_system_prompt=conversation_extra_system_prompt, extra_system_prompt=conversation_extra_system_prompt,
) )
processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT
agent_id = user_input.agent_id agent_id = self.intent_agent
processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
intent_response: intent.IntentResponse | None = None intent_response: intent.IntentResponse | None = None
if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT: if not processed_locally:
# Sentence triggers override conversation agent # Sentence triggers override conversation agent
if ( if (
trigger_response_text trigger_response_text
@ -1105,13 +1105,13 @@ class PipelineRun:
speech: str = intent_response.speech.get("plain", {}).get( speech: str = intent_response.speech.get("plain", {}).get(
"speech", "" "speech", ""
) )
chat_log.async_add_message( async for _ in chat_log.async_add_assistant_content(
conversation.Content( conversation.AssistantContent(
role="assistant",
agent_id=agent_id, agent_id=agent_id,
content=speech, content=speech,
) )
) ):
pass
conversation_result = conversation.ConversationResult( conversation_result = conversation.ConversationResult(
response=intent_response, response=intent_response,
conversation_id=session.conversation_id, conversation_id=session.conversation_id,

View File

@ -30,6 +30,16 @@ from .agent_manager import (
async_get_agent, async_get_agent,
get_agent_manager, get_agent_manager,
) )
from .chat_log import (
AssistantContent,
ChatLog,
Content,
ConverseError,
SystemContent,
ToolResultContent,
UserContent,
async_get_chat_log,
)
from .const import ( from .const import (
ATTR_AGENT_ID, ATTR_AGENT_ID,
ATTR_CONVERSATION_ID, ATTR_CONVERSATION_ID,
@ -48,13 +58,13 @@ from .default_agent import DefaultAgent, async_setup_default_agent
from .entity import ConversationEntity from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult from .models import AbstractConversationAgent, ConversationInput, ConversationResult
from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log
from .trace import ConversationTraceEventType, async_conversation_trace_append from .trace import ConversationTraceEventType, async_conversation_trace_append
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
"HOME_ASSISTANT_AGENT", "HOME_ASSISTANT_AGENT",
"OLD_HOME_ASSISTANT_AGENT", "OLD_HOME_ASSISTANT_AGENT",
"AssistantContent",
"ChatLog", "ChatLog",
"Content", "Content",
"ConversationEntity", "ConversationEntity",
@ -63,7 +73,9 @@ __all__ = [
"ConversationResult", "ConversationResult",
"ConversationTraceEventType", "ConversationTraceEventType",
"ConverseError", "ConverseError",
"NativeContent", "SystemContent",
"ToolResultContent",
"UserContent",
"async_conversation_trace_append", "async_conversation_trace_append",
"async_converse", "async_converse",
"async_get_agent_info", "async_get_agent_info",

View File

@ -2,19 +2,16 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Generator from collections.abc import AsyncGenerator, Generator
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from datetime import datetime
import logging import logging
from typing import Literal
import voluptuous as vol import voluptuous as vol
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import chat_session, intent, llm, template from homeassistant.helpers import chat_session, intent, llm, template
from homeassistant.util import dt as dt_util
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType from homeassistant.util.json import JsonObjectType
@ -31,7 +28,7 @@ LOGGER = logging.getLogger(__name__)
def async_get_chat_log( def async_get_chat_log(
hass: HomeAssistant, hass: HomeAssistant,
session: chat_session.ChatSession, session: chat_session.ChatSession,
user_input: ConversationInput, user_input: ConversationInput | None = None,
) -> Generator[ChatLog]: ) -> Generator[ChatLog]:
"""Return chat log for a specific chat session.""" """Return chat log for a specific chat session."""
all_history = hass.data.get(DATA_CHAT_HISTORY) all_history = hass.data.get(DATA_CHAT_HISTORY)
@ -42,9 +39,9 @@ def async_get_chat_log(
history = all_history.get(session.conversation_id) history = all_history.get(session.conversation_id)
if history: if history:
history = replace(history, messages=history.messages.copy()) history = replace(history, content=history.content.copy())
else: else:
history = ChatLog(hass, session.conversation_id, user_input.agent_id) history = ChatLog(hass, session.conversation_id)
@callback @callback
def do_cleanup() -> None: def do_cleanup() -> None:
@ -53,22 +50,19 @@ def async_get_chat_log(
session.async_on_cleanup(do_cleanup) session.async_on_cleanup(do_cleanup)
message: Content = Content( if user_input is not None:
role="user", history.async_add_user_content(UserContent(content=user_input.text))
agent_id=user_input.agent_id,
content=user_input.text, last_message = history.content[-1]
)
history.async_add_message(message)
yield history yield history
if history.messages[-1] is message: if history.content[-1] is last_message:
LOGGER.debug( LOGGER.debug(
"History opened but no assistant message was added, ignoring update" "History opened but no assistant message was added, ignoring update"
) )
return return
history.last_updated = dt_util.utcnow()
all_history[session.conversation_id] = history all_history[session.conversation_id] = history
@ -94,63 +88,94 @@ class ConverseError(HomeAssistantError):
) )
@dataclass(frozen=True)
class SystemContent:
    """System content: the system prompt entry of a chat log."""

    # Fixed role discriminator; excluded from __init__ so it always reads "system".
    role: str = field(init=False, default="system")
    # Text of the system prompt.
    content: str
@dataclass(frozen=True)
class UserContent:
    """User content: a message entered by the user."""

    # Fixed role discriminator; excluded from __init__ so it always reads "user".
    role: str = field(init=False, default="user")
    # Text of the user's message.
    content: str
@dataclass(frozen=True)
class AssistantContent:
"""Assistant content."""
role: str = field(init=False, default="assistant")
agent_id: str agent_id: str
content: _NativeT content: str
tool_calls: list[llm.ToolInput] | None = None
@dataclass(frozen=True)
class ToolResultContent:
"""Tool result content."""
role: str = field(init=False, default="tool_result")
agent_id: str
tool_call_id: str
tool_name: str
tool_result: JsonObjectType
Content = SystemContent | UserContent | AssistantContent | ToolResultContent
@dataclass @dataclass
class ChatLog[_NativeT]: class ChatLog:
"""Class holding the chat history of a specific conversation.""" """Class holding the chat history of a specific conversation."""
hass: HomeAssistant hass: HomeAssistant
conversation_id: str conversation_id: str
agent_id: str | None content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
user_name: str | None = None
messages: list[Content | NativeContent[_NativeT]] = field(
default_factory=lambda: [Content(role="system", agent_id=None, content="")]
)
extra_system_prompt: str | None = None extra_system_prompt: str | None = None
llm_api: llm.APIInstance | None = None llm_api: llm.APIInstance | None = None
last_updated: datetime = field(default_factory=dt_util.utcnow)
@callback @callback
def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None: def async_add_user_content(self, content: UserContent) -> None:
"""Process intent.""" """Add user content to the log."""
if message.role == "system": self.content.append(content)
raise ValueError("Cannot add system messages to history")
if message.role != "native" and self.messages[-1].role == message.role:
raise ValueError("Cannot add two assistant or user messages in a row")
self.messages.append(message) async def async_add_assistant_content(
self, content: AssistantContent
) -> AsyncGenerator[ToolResultContent]:
"""Add assistant content."""
self.content.append(content)
@callback if content.tool_calls is None:
def async_get_messages( return
self, agent_id: str | None = None
) -> list[Content | NativeContent[_NativeT]]:
"""Get messages for a specific agent ID.
This will filter out any native message tied to other agent IDs. if self.llm_api is None:
It can still include assistant/user messages generated by other agents. raise ValueError("No LLM API configured")
"""
return [ for tool_input in content.tool_calls:
message LOGGER.debug(
for message in self.messages "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
if message.role != "native" or message.agent_id == agent_id )
]
try:
tool_result = await self.llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_result = {"error": type(e).__name__}
if str(e):
tool_result["error_text"] = str(e)
LOGGER.debug("Tool response: %s", tool_result)
response_content = ToolResultContent(
agent_id=content.agent_id,
tool_call_id=tool_input.id,
tool_name=tool_input.tool_name,
tool_result=tool_result,
)
self.content.append(response_content)
yield response_content
async def async_update_llm_data( async def async_update_llm_data(
self, self,
@ -250,36 +275,16 @@ class ChatLog[_NativeT]:
prompt = "\n".join(prompt_parts) prompt = "\n".join(prompt_parts)
self.llm_api = llm_api self.llm_api = llm_api
self.user_name = user_name
self.extra_system_prompt = extra_system_prompt self.extra_system_prompt = extra_system_prompt
self.messages[0] = Content( self.content[0] = SystemContent(content=prompt)
role="system",
agent_id=user_input.agent_id,
content=prompt,
)
LOGGER.debug("Prompt: %s", self.messages) LOGGER.debug("Prompt: %s", self.content)
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None) LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)
trace.async_conversation_trace_append( trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, trace.ConversationTraceEventType.AGENT_DETAIL,
{ {
"messages": self.messages, "messages": self.content,
"tools": self.llm_api.tools if self.llm_api else None, "tools": self.llm_api.tools if self.llm_api else None,
}, },
) )
async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType:
"""Invoke LLM tool for the configured LLM API."""
if not self.llm_api:
raise ValueError("No LLM API configured")
LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
try:
tool_response = await self.llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", tool_response)
return tool_response

View File

@ -55,6 +55,7 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_state_added_domain from homeassistant.helpers.event import async_track_state_added_domain
from homeassistant.util.json import JsonObjectType, json_loads_object from homeassistant.util.json import JsonObjectType, json_loads_object
from .chat_log import AssistantContent, async_get_chat_log
from .const import ( from .const import (
DATA_DEFAULT_ENTITY, DATA_DEFAULT_ENTITY,
DEFAULT_EXPOSED_ATTRIBUTES, DEFAULT_EXPOSED_ATTRIBUTES,
@ -63,7 +64,6 @@ from .const import (
) )
from .entity import ConversationEntity from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult from .models import ConversationInput, ConversationResult
from .session import Content, async_get_chat_log
from .trace import ConversationTraceEventType, async_conversation_trace_append from .trace import ConversationTraceEventType, async_conversation_trace_append
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -379,13 +379,13 @@ class DefaultAgent(ConversationEntity):
) )
speech: str = response.speech.get("plain", {}).get("speech", "") speech: str = response.speech.get("plain", {}).get("speech", "")
chat_log.async_add_message( async for _tool_result in chat_log.async_add_assistant_content(
Content( AssistantContent(
role="assistant", agent_id=user_input.agent_id, # type: ignore[arg-type]
agent_id=user_input.agent_id,
content=speech, content=speech,
) )
) ):
pass
return ConversationResult( return ConversationResult(
response=response, conversation_id=session.conversation_id response=response, conversation_id=session.conversation_id

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import codecs import codecs
from collections.abc import Callable from collections.abc import Callable
from typing import Any, Literal from typing import Any, Literal, cast
from google.api_core.exceptions import GoogleAPIError from google.api_core.exceptions import GoogleAPIError
import google.generativeai as genai import google.generativeai as genai
@ -149,15 +149,53 @@ def _escape_decode(value: Any) -> Any:
return value return value
def _chat_message_convert( def _create_google_tool_response_content(
message: conversation.Content | conversation.NativeContent[genai_types.ContentDict], content: list[conversation.ToolResultContent],
) -> genai_types.ContentDict: ) -> protos.Content:
"""Convert any native chat message for this agent to the native format.""" """Create a Google tool response content."""
if message.role == "native": return protos.Content(
return message.content parts=[
protos.Part(
function_response=protos.FunctionResponse(
name=tool_result.tool_name, response=tool_result.tool_result
)
)
for tool_result in content
]
)
role = "model" if message.role == "assistant" else message.role
return {"role": role, "parts": message.content} def _convert_content(
content: conversation.UserContent
| conversation.AssistantContent
| conversation.SystemContent,
) -> genai_types.ContentDict:
"""Convert HA content to Google content."""
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
role = "model" if content.role == "assistant" else content.role
return {"role": role, "parts": content.content}
# Handle the Assistant content with tool calls.
assert type(content) is conversation.AssistantContent
parts = []
if content.content:
parts.append(protos.Part(text=content.content))
if content.tool_calls:
parts.extend(
[
protos.Part(
function_call=protos.FunctionCall(
name=tool_call.tool_name,
args=_escape_decode(tool_call.tool_args),
)
)
for tool_call in content.tool_calls
]
)
return protos.Content({"role": "model", "parts": parts})
class GoogleGenerativeAIConversationEntity( class GoogleGenerativeAIConversationEntity(
@ -220,7 +258,7 @@ class GoogleGenerativeAIConversationEntity(
async def _async_handle_message( async def _async_handle_message(
self, self,
user_input: conversation.ConversationInput, user_input: conversation.ConversationInput,
session: conversation.ChatLog[genai_types.ContentDict], chat_log: conversation.ChatLog,
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Call the API.""" """Call the API."""
@ -228,7 +266,7 @@ class GoogleGenerativeAIConversationEntity(
options = self.entry.options options = self.entry.options
try: try:
await session.async_update_llm_data( await chat_log.async_update_llm_data(
DOMAIN, DOMAIN,
user_input, user_input,
options.get(CONF_LLM_HASS_API), options.get(CONF_LLM_HASS_API),
@ -238,10 +276,10 @@ class GoogleGenerativeAIConversationEntity(
return err.as_conversation_result() return err.as_conversation_result()
tools: list[dict[str, Any]] | None = None tools: list[dict[str, Any]] | None = None
if session.llm_api: if chat_log.llm_api:
tools = [ tools = [
_format_tool(tool, session.llm_api.custom_serializer) _format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in session.llm_api.tools for tool in chat_log.llm_api.tools
] ]
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
@ -252,9 +290,36 @@ class GoogleGenerativeAIConversationEntity(
"gemini-1.0" not in model_name and "gemini-pro" not in model_name "gemini-1.0" not in model_name and "gemini-pro" not in model_name
) )
prompt, *messages = [ prompt = chat_log.content[0].content # type: ignore[union-attr]
_chat_message_convert(message) for message in session.async_get_messages() messages: list[genai_types.ContentDict] = []
]
# Google groups tool results, we do not. Group them before sending.
tool_results: list[conversation.ToolResultContent] = []
for chat_content in chat_log.content[1:]:
if chat_content.role == "tool_result":
# mypy doesn't like picking a type based on checking shared property 'role'
tool_results.append(cast(conversation.ToolResultContent, chat_content))
continue
if tool_results:
messages.append(_create_google_tool_response_content(tool_results))
tool_results.clear()
messages.append(
_convert_content(
cast(
conversation.UserContent
| conversation.SystemContent
| conversation.AssistantContent,
chat_content,
)
)
)
if tool_results:
messages.append(_create_google_tool_response_content(tool_results))
model = genai.GenerativeModel( model = genai.GenerativeModel(
model_name=model_name, model_name=model_name,
generation_config={ generation_config={
@ -282,12 +347,12 @@ class GoogleGenerativeAIConversationEntity(
), ),
}, },
tools=tools or None, tools=tools or None,
system_instruction=prompt["parts"] if supports_system_instruction else None, system_instruction=prompt if supports_system_instruction else None,
) )
if not supports_system_instruction: if not supports_system_instruction:
messages = [ messages = [
{"role": "user", "parts": prompt["parts"]}, {"role": "user", "parts": prompt},
{"role": "model", "parts": "Ok"}, {"role": "model", "parts": "Ok"},
*messages, *messages,
] ]
@ -325,50 +390,40 @@ class GoogleGenerativeAIConversationEntity(
content = " ".join( content = " ".join(
[part.text.strip() for part in chat_response.parts if part.text] [part.text.strip() for part in chat_response.parts if part.text]
) )
if content:
session.async_add_message(
conversation.Content(
role="assistant",
agent_id=user_input.agent_id,
content=content,
)
)
function_calls = [ tool_calls = []
part.function_call for part in chat_response.parts if part.function_call for part in chat_response.parts:
] if not part.function_call:
continue
if not function_calls or not session.llm_api: tool_call = MessageToDict(part.function_call._pb) # noqa: SLF001
break
tool_responses = []
for function_call in function_calls:
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
tool_name = tool_call["name"] tool_name = tool_call["name"]
tool_args = _escape_decode(tool_call["args"]) tool_args = _escape_decode(tool_call["args"])
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args) tool_calls.append(
function_response = await session.async_call_tool(tool_input) llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
tool_responses.append(
protos.Part(
function_response=protos.FunctionResponse(
name=tool_name, response=function_response
) )
)
) chat_request = _create_google_tool_response_content(
chat_request = protos.Content(parts=tool_responses) [
session.async_add_message( tool_response
conversation.NativeContent( async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id, agent_id=user_input.agent_id,
content=chat_request, content=content,
tool_calls=tool_calls or None,
) )
) )
]
)
if not tool_calls:
break
response = intent.IntentResponse(language=user_input.language) response = intent.IntentResponse(language=user_input.language)
response.async_set_speech( response.async_set_speech(
" ".join([part.text.strip() for part in chat_response.parts if part.text]) " ".join([part.text.strip() for part in chat_response.parts if part.text])
) )
return conversation.ConversationResult( return conversation.ConversationResult(
response=response, conversation_id=session.conversation_id response=response, conversation_id=chat_log.conversation_id
) )
async def _async_entry_update_listener( async def _async_entry_update_listener(

View File

@ -70,7 +70,9 @@ def _format_tool(
return ChatCompletionToolParam(type="function", function=tool_spec) return ChatCompletionToolParam(type="function", function=tool_spec)
def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam: def _convert_message_to_param(
message: ChatCompletionMessage,
) -> ChatCompletionMessageParam:
"""Convert from class to TypedDict.""" """Convert from class to TypedDict."""
tool_calls: list[ChatCompletionMessageToolCallParam] = [] tool_calls: list[ChatCompletionMessageToolCallParam] = []
if message.tool_calls: if message.tool_calls:
@ -94,20 +96,42 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
return param return param
def _convert_content_to_param(
    content: conversation.Content,
) -> ChatCompletionMessageParam:
    """Convert a chat log entry to an OpenAI chat-completion message param.

    - Tool results become ``tool`` messages with the result JSON-encoded.
    - Assistant entries carrying tool calls become assistant messages that
      include those calls, with arguments JSON-encoded.
    - Every other entry becomes a plain role/content message; the ``system``
      role is sent as ``developer``.
    """
    if content.role == "tool_result":
        # Narrow the union for mypy; role alone does not narrow the type.
        assert type(content) is conversation.ToolResultContent
        return ChatCompletionToolMessageParam(
            role="tool",
            tool_call_id=content.tool_call_id,
            content=json.dumps(content.tool_result),
        )
    if content.role != "assistant" or not content.tool_calls:  # type: ignore[union-attr]
        role = content.role
        if role == "system":
            role = "developer"
        # Use the remapped role: returning content.role here would discard
        # the system -> developer translation computed just above.
        return cast(
            ChatCompletionMessageParam,
            {"role": role, "content": content.content},  # type: ignore[union-attr]
        )

    # Handle the Assistant content including tool calls.
    assert type(content) is conversation.AssistantContent
    return ChatCompletionAssistantMessageParam(
        role="assistant",
        content=content.content,
        tool_calls=[
            ChatCompletionMessageToolCallParam(
                id=tool_call.id,
                function=Function(
                    arguments=json.dumps(tool_call.tool_args),
                    name=tool_call.tool_name,
                ),
                type="function",
            )
            for tool_call in content.tool_calls
        ],
    )
@ -171,14 +195,14 @@ class OpenAIConversationEntity(
async def _async_handle_message( async def _async_handle_message(
self, self,
user_input: conversation.ConversationInput, user_input: conversation.ConversationInput,
session: conversation.ChatLog[ChatCompletionMessageParam], chat_log: conversation.ChatLog,
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Call the API.""" """Call the API."""
assert user_input.agent_id assert user_input.agent_id
options = self.entry.options options = self.entry.options
try: try:
await session.async_update_llm_data( await chat_log.async_update_llm_data(
DOMAIN, DOMAIN,
user_input, user_input,
options.get(CONF_LLM_HASS_API), options.get(CONF_LLM_HASS_API),
@ -188,17 +212,14 @@ class OpenAIConversationEntity(
return err.as_conversation_result() return err.as_conversation_result()
tools: list[ChatCompletionToolParam] | None = None tools: list[ChatCompletionToolParam] | None = None
if session.llm_api: if chat_log.llm_api:
tools = [ tools = [
_format_tool(tool, session.llm_api.custom_serializer) _format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in session.llm_api.tools for tool in chat_log.llm_api.tools
] ]
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
messages = [_convert_content_to_param(content) for content in chat_log.content]
messages = [
_chat_message_convert(message) for message in session.async_get_messages()
]
client = self.entry.runtime_data client = self.entry.runtime_data
@ -213,7 +234,7 @@ class OpenAIConversationEntity(
), ),
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P), "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"user": session.conversation_id, "user": chat_log.conversation_id,
} }
if model.startswith("o"): if model.startswith("o"):
@ -229,43 +250,39 @@ class OpenAIConversationEntity(
LOGGER.debug("Response %s", result) LOGGER.debug("Response %s", result)
response = result.choices[0].message response = result.choices[0].message
messages.append(_message_convert(response)) messages.append(_convert_message_to_param(response))
session.async_add_message( tool_calls: list[llm.ToolInput] | None = None
conversation.Content( if response.tool_calls:
role=response.role, tool_calls = [
agent_id=user_input.agent_id, llm.ToolInput(
content=response.content or "", id=tool_call.id,
),
)
if not response.tool_calls or not session.llm_api:
break
for tool_call in response.tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call.function.name, tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments), tool_args=json.loads(tool_call.function.arguments),
) )
tool_response = await session.async_call_tool(tool_input) for tool_call in response.tool_calls
messages.append( ]
ChatCompletionToolMessageParam(
role="tool", messages.extend(
tool_call_id=tool_call.id, [
content=json.dumps(tool_response), _convert_content_to_param(tool_response)
) async for tool_response in chat_log.async_add_assistant_content(
) conversation.AssistantContent(
session.async_add_message(
conversation.NativeContent(
agent_id=user_input.agent_id, agent_id=user_input.agent_id,
content=messages[-1], content=response.content or "",
tool_calls=tool_calls,
) )
) )
]
)
if not tool_calls:
break
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content or "") intent_response.async_set_speech(response.content or "")
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=session.conversation_id response=intent_response, conversation_id=chat_log.conversation_id
) )
async def _async_entry_update_listener( async def _async_entry_update_listener(

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass, field as dc_field
from datetime import timedelta from datetime import timedelta
from decimal import Decimal from decimal import Decimal
from enum import Enum from enum import Enum
@ -36,6 +36,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util, yaml as yaml_util from homeassistant.util import dt as dt_util, yaml as yaml_util
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType from homeassistant.util.json import JsonObjectType
from homeassistant.util.ulid import ulid_now
from . import ( from . import (
area_registry as ar, area_registry as ar,
@ -139,6 +140,8 @@ class ToolInput:
tool_name: str tool_name: str
tool_args: dict[str, Any] tool_args: dict[str, Any]
# Using lambda for default to allow patching in tests
id: str = dc_field(default_factory=lambda: ulid_now()) # pylint: disable=unnecessary-lambda
class Tool: class Tool:

View File

@ -236,6 +236,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="toolu_0123456789AbCdEfGhIjKlM",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),
@ -373,6 +374,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="toolu_0123456789AbCdEfGhIjKlM",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),

View File

@ -9,13 +9,13 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol import voluptuous as vol
from homeassistant.components.conversation import ( from homeassistant.components.conversation import (
Content, AssistantContent,
ConversationInput, ConversationInput,
ConverseError, ConverseError,
NativeContent, ToolResultContent,
async_get_chat_log, async_get_chat_log,
) )
from homeassistant.components.conversation.session import DATA_CHAT_HISTORY from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, llm from homeassistant.helpers import chat_session, llm
@ -40,7 +40,7 @@ def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
@pytest.fixture @pytest.fixture
def mock_ulid() -> Generator[Mock]: def mock_ulid() -> Generator[Mock]:
"""Mock the ulid library.""" """Mock the ulid library."""
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now: with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid" mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now yield mock_ulid_now
@ -56,13 +56,13 @@ async def test_cleanup(
): ):
conversation_id = session.conversation_id conversation_id = session.conversation_id
# Add message so it persists # Add message so it persists
chat_log.async_add_message( async for _tool_result in chat_log.async_add_assistant_content(
Content( AssistantContent(
role="assistant", agent_id="mock-agent-id",
agent_id=mock_conversation_input.agent_id, content="Hey!",
content="",
)
) )
):
pytest.fail("should not reach here")
assert conversation_id in hass.data[DATA_CHAT_HISTORY] assert conversation_id in hass.data[DATA_CHAT_HISTORY]
@ -79,7 +79,7 @@ async def test_cleanup(
assert conversation_id not in hass.data[DATA_CHAT_HISTORY] assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
async def test_add_message( async def test_default_content(
hass: HomeAssistant, mock_conversation_input: ConversationInput hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None: ) -> None:
"""Test filtering of messages.""" """Test filtering of messages."""
@ -87,95 +87,11 @@ async def test_add_message(
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
assert len(chat_log.messages) == 2 assert len(chat_log.content) == 2
assert chat_log.content[0].role == "system"
with pytest.raises(ValueError): assert chat_log.content[0].content == ""
chat_log.async_add_message( assert chat_log.content[1].role == "user"
Content(role="system", agent_id=None, content="") assert chat_log.content[1].content == mock_conversation_input.text
)
# No 2 user messages in a row
assert chat_log.messages[1].role == "user"
with pytest.raises(ValueError):
chat_log.async_add_message(Content(role="user", agent_id=None, content=""))
# No 2 assistant messages in a row
chat_log.async_add_message(Content(role="assistant", agent_id=None, content=""))
assert len(chat_log.messages) == 3
assert chat_log.messages[-1].role == "assistant"
with pytest.raises(ValueError):
chat_log.async_add_message(
Content(role="assistant", agent_id=None, content="")
)
async def test_message_filtering(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
messages = chat_log.async_get_messages(agent_id=None)
assert len(messages) == 2
assert messages[0] == Content(
role="system",
agent_id=None,
content="",
)
assert messages[1] == Content(
role="user",
agent_id="mock-agent-id",
content=mock_conversation_input.text,
)
# Cannot add a second user message in a row
with pytest.raises(ValueError):
chat_log.async_add_message(
Content(
role="user",
agent_id="mock-agent-id",
content="Hey!",
)
)
chat_log.async_add_message(
Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
)
# Different agent, native messages will be filtered out.
chat_log.async_add_message(
NativeContent(agent_id="another-mock-agent-id", content=1)
)
chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1))
# A non-native message from another agent is not filtered out.
chat_log.async_add_message(
Content(
role="assistant",
agent_id="another-mock-agent-id",
content="Hi!",
)
)
assert len(chat_log.messages) == 6
messages = chat_log.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 5
assert messages[2] == Content(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1)
assert messages[4] == Content(
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
)
async def test_llm_api( async def test_llm_api(
@ -268,12 +184,10 @@ async def test_template_variables(
), ),
) )
assert chat_log.user_name == "Test User" assert "The instance name is test home." in chat_log.content[0].content
assert "The user name is Test User." in chat_log.content[0].content
assert "The instance name is test home." in chat_log.messages[0].content assert "The user id is 12345." in chat_log.content[0].content
assert "The user name is Test User." in chat_log.messages[0].content assert "The calling platform is test." in chat_log.content[0].content
assert "The user id is 12345." in chat_log.messages[0].content
assert "The calling platform is test." in chat_log.messages[0].content
async def test_extra_systen_prompt( async def test_extra_systen_prompt(
@ -296,16 +210,16 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt=None, user_llm_prompt=None,
) )
chat_log.async_add_message( async for _tool_result in chat_log.async_add_assistant_content(
Content( AssistantContent(
role="assistant",
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
) )
) ):
pytest.fail("should not reach here")
assert chat_log.extra_system_prompt == extra_system_prompt assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt) assert chat_log.content[0].content.endswith(extra_system_prompt)
# Verify that follow-up conversations with no system prompt take previous one # Verify that follow-up conversations with no system prompt take previous one
conversation_id = chat_log.conversation_id conversation_id = chat_log.conversation_id
@ -323,7 +237,7 @@ async def test_extra_systen_prompt(
) )
assert chat_log.extra_system_prompt == extra_system_prompt assert chat_log.extra_system_prompt == extra_system_prompt
assert chat_log.messages[0].content.endswith(extra_system_prompt) assert chat_log.content[0].content.endswith(extra_system_prompt)
# Verify that we take new system prompts # Verify that we take new system prompts
mock_conversation_input.extra_system_prompt = extra_system_prompt2 mock_conversation_input.extra_system_prompt = extra_system_prompt2
@ -338,17 +252,17 @@ async def test_extra_systen_prompt(
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt=None, user_llm_prompt=None,
) )
chat_log.async_add_message( async for _tool_result in chat_log.async_add_assistant_content(
Content( AssistantContent(
role="assistant",
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
) )
) ):
pytest.fail("should not reach here")
assert chat_log.extra_system_prompt == extra_system_prompt2 assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2) assert chat_log.content[0].content.endswith(extra_system_prompt2)
assert extra_system_prompt not in chat_log.messages[0].content assert extra_system_prompt not in chat_log.content[0].content
# Verify that follow-up conversations with no system prompt take previous one # Verify that follow-up conversations with no system prompt take previous one
mock_conversation_input.extra_system_prompt = None mock_conversation_input.extra_system_prompt = None
@ -365,7 +279,7 @@ async def test_extra_systen_prompt(
) )
assert chat_log.extra_system_prompt == extra_system_prompt2 assert chat_log.extra_system_prompt == extra_system_prompt2
assert chat_log.messages[0].content.endswith(extra_system_prompt2) assert chat_log.content[0].content.endswith(extra_system_prompt2)
async def test_tool_call( async def test_tool_call(
@ -383,8 +297,7 @@ async def test_tool_call(
mock_tool.async_call.return_value = "Test response" mock_tool.async_call.return_value = "Test response"
with patch( with patch(
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools", "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
return_value=[],
) as mock_get_tools: ) as mock_get_tools:
mock_get_tools.return_value = [mock_tool] mock_get_tools.return_value = [mock_tool]
@ -398,14 +311,29 @@ async def test_tool_call(
user_llm_hass_api="assist", user_llm_hass_api="assist",
user_llm_prompt=None, user_llm_prompt=None,
) )
result = await chat_log.async_call_tool( result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput( llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "Test Param"}, tool_args={"param1": "Test Param"},
) )
],
) )
):
assert result is None
result = tool_result_content
assert result == "Test response" assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result="Test response",
tool_name="test_tool",
)
async def test_tool_call_exception( async def test_tool_call_exception(
@ -423,8 +351,7 @@ async def test_tool_call_exception(
mock_tool.async_call.side_effect = HomeAssistantError("Test error") mock_tool.async_call.side_effect = HomeAssistantError("Test error")
with patch( with patch(
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools", "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
return_value=[],
) as mock_get_tools: ) as mock_get_tools:
mock_get_tools.return_value = [mock_tool] mock_get_tools.return_value = [mock_tool]
@ -438,11 +365,26 @@ async def test_tool_call_exception(
user_llm_hass_api="assist", user_llm_hass_api="assist",
user_llm_prompt=None, user_llm_prompt=None,
) )
result = await chat_log.async_call_tool( result = None
async for tool_result_content in chat_log.async_add_assistant_content(
AssistantContent(
agent_id=mock_conversation_input.agent_id,
content="",
tool_calls=[
llm.ToolInput( llm.ToolInput(
id="mock-tool-call-id",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "Test Param"}, tool_args={"param1": "Test Param"},
) )
],
) )
):
assert result is None
result = tool_result_content
assert result == {"error": "HomeAssistantError", "error_text": "Test error"} assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result={"error": "HomeAssistantError", "error_text": "Test error"},
tool_name="test_tool",
)

View File

@ -36,6 +36,13 @@ def freeze_the_time():
yield yield
@pytest.fixture(autouse=True)
def mock_ulid_tools():
"""Mock generated ULIDs for tool calls."""
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
yield
@pytest.mark.parametrize( @pytest.mark.parametrize(
"agent_id", [None, "conversation.google_generative_ai_conversation"] "agent_id", [None, "conversation.google_generative_ai_conversation"]
) )
@ -177,6 +184,7 @@ async def test_chat_history(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools" "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
) )
@pytest.mark.usefixtures("mock_init_component") @pytest.mark.usefixtures("mock_init_component")
@pytest.mark.usefixtures("mock_ulid_tools")
async def test_function_call( async def test_function_call(
mock_get_tools, mock_get_tools,
hass: HomeAssistant, hass: HomeAssistant,
@ -256,6 +264,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool", tool_name="test_tool",
tool_args={ tool_args={
"param1": ["test_value", "param1's value"], "param1": ["test_value", "param1's value"],
@ -287,9 +296,7 @@ async def test_function_call(
detail_event = trace_events[1] detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"] assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
assert [ assert [
p.function_response.name p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
for p in detail_event["data"]["messages"][2]["content"].parts
if p.function_response
] == ["test_tool"] ] == ["test_tool"]
@ -362,6 +369,7 @@ async def test_function_call_without_parameters(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool", tool_name="test_tool",
tool_args={}, tool_args={},
), ),
@ -451,6 +459,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": 1}, tool_args={"param1": 1},
), ),
@ -605,6 +614,7 @@ async def test_template_variables(
mock_chat.send_message_async.return_value = chat_response mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock() mock_part = MagicMock()
mock_part.text = "Model response" mock_part.text = "Model response"
mock_part.function_call = None
chat_response.parts = [mock_part] chat_response.parts = [mock_part]
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "hello", None, context, agent_id=mock_config_entry.entry_id hass, "hello", None, context, agent_id=mock_config_entry.entry_id

View File

@ -18,6 +18,13 @@ from homeassistant.helpers import intent, llm
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@pytest.fixture(autouse=True)
def mock_ulid_tools():
"""Mock generated ULIDs for tool calls."""
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
yield
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"]) @pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
async def test_chat( async def test_chat(
hass: HomeAssistant, hass: HomeAssistant,
@ -205,6 +212,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool", tool_name="test_tool",
tool_args=expected_tool_args, tool_args=expected_tool_args,
), ),
@ -285,6 +293,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),

View File

@ -195,6 +195,7 @@ async def test_function_call(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),
@ -359,6 +360,7 @@ async def test_function_exception(
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
llm.ToolInput( llm.ToolInput(
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),