mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 13:47:35 +00:00
Chat session rev2 (#137209)
* Chat Session rev 2 * Rename session to chat_log * Simplify typing * Typing * Address comments * Fix anthropic and ollama
This commit is contained in:
parent
ce93cb9467
commit
9679fc7878
@ -272,6 +272,7 @@ class AnthropicConversationEntity(
|
||||
continue
|
||||
|
||||
tool_input = llm.ToolInput(
|
||||
id=tool_call.id,
|
||||
tool_name=tool_call.name,
|
||||
tool_args=cast(dict[str, Any], tool_call.input),
|
||||
)
|
||||
|
@ -1063,11 +1063,11 @@ class PipelineRun:
|
||||
agent_id=self.intent_agent,
|
||||
extra_system_prompt=conversation_extra_system_prompt,
|
||||
)
|
||||
processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT
|
||||
|
||||
agent_id = user_input.agent_id
|
||||
agent_id = self.intent_agent
|
||||
processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
|
||||
intent_response: intent.IntentResponse | None = None
|
||||
if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT:
|
||||
if not processed_locally:
|
||||
# Sentence triggers override conversation agent
|
||||
if (
|
||||
trigger_response_text
|
||||
@ -1105,13 +1105,13 @@ class PipelineRun:
|
||||
speech: str = intent_response.speech.get("plain", {}).get(
|
||||
"speech", ""
|
||||
)
|
||||
chat_log.async_add_message(
|
||||
conversation.Content(
|
||||
role="assistant",
|
||||
async for _ in chat_log.async_add_assistant_content(
|
||||
conversation.AssistantContent(
|
||||
agent_id=agent_id,
|
||||
content=speech,
|
||||
)
|
||||
)
|
||||
):
|
||||
pass
|
||||
conversation_result = conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=session.conversation_id,
|
||||
|
@ -30,6 +30,16 @@ from .agent_manager import (
|
||||
async_get_agent,
|
||||
get_agent_manager,
|
||||
)
|
||||
from .chat_log import (
|
||||
AssistantContent,
|
||||
ChatLog,
|
||||
Content,
|
||||
ConverseError,
|
||||
SystemContent,
|
||||
ToolResultContent,
|
||||
UserContent,
|
||||
async_get_chat_log,
|
||||
)
|
||||
from .const import (
|
||||
ATTR_AGENT_ID,
|
||||
ATTR_CONVERSATION_ID,
|
||||
@ -48,13 +58,13 @@ from .default_agent import DefaultAgent, async_setup_default_agent
|
||||
from .entity import ConversationEntity
|
||||
from .http import async_setup as async_setup_conversation_http
|
||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log
|
||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"HOME_ASSISTANT_AGENT",
|
||||
"OLD_HOME_ASSISTANT_AGENT",
|
||||
"AssistantContent",
|
||||
"ChatLog",
|
||||
"Content",
|
||||
"ConversationEntity",
|
||||
@ -63,7 +73,9 @@ __all__ = [
|
||||
"ConversationResult",
|
||||
"ConversationTraceEventType",
|
||||
"ConverseError",
|
||||
"NativeContent",
|
||||
"SystemContent",
|
||||
"ToolResultContent",
|
||||
"UserContent",
|
||||
"async_conversation_trace_append",
|
||||
"async_converse",
|
||||
"async_get_agent_info",
|
||||
|
@ -2,19 +2,16 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field, replace
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.helpers import chat_session, intent, llm, template
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
@ -31,7 +28,7 @@ LOGGER = logging.getLogger(__name__)
|
||||
def async_get_chat_log(
|
||||
hass: HomeAssistant,
|
||||
session: chat_session.ChatSession,
|
||||
user_input: ConversationInput,
|
||||
user_input: ConversationInput | None = None,
|
||||
) -> Generator[ChatLog]:
|
||||
"""Return chat log for a specific chat session."""
|
||||
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
||||
@ -42,9 +39,9 @@ def async_get_chat_log(
|
||||
history = all_history.get(session.conversation_id)
|
||||
|
||||
if history:
|
||||
history = replace(history, messages=history.messages.copy())
|
||||
history = replace(history, content=history.content.copy())
|
||||
else:
|
||||
history = ChatLog(hass, session.conversation_id, user_input.agent_id)
|
||||
history = ChatLog(hass, session.conversation_id)
|
||||
|
||||
@callback
|
||||
def do_cleanup() -> None:
|
||||
@ -53,22 +50,19 @@ def async_get_chat_log(
|
||||
|
||||
session.async_on_cleanup(do_cleanup)
|
||||
|
||||
message: Content = Content(
|
||||
role="user",
|
||||
agent_id=user_input.agent_id,
|
||||
content=user_input.text,
|
||||
)
|
||||
history.async_add_message(message)
|
||||
if user_input is not None:
|
||||
history.async_add_user_content(UserContent(content=user_input.text))
|
||||
|
||||
last_message = history.content[-1]
|
||||
|
||||
yield history
|
||||
|
||||
if history.messages[-1] is message:
|
||||
if history.content[-1] is last_message:
|
||||
LOGGER.debug(
|
||||
"History opened but no assistant message was added, ignoring update"
|
||||
)
|
||||
return
|
||||
|
||||
history.last_updated = dt_util.utcnow()
|
||||
all_history[session.conversation_id] = history
|
||||
|
||||
|
||||
@ -94,63 +88,94 @@ class ConverseError(HomeAssistantError):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Content:
|
||||
@dataclass(frozen=True)
|
||||
class SystemContent:
|
||||
"""Base class for chat messages."""
|
||||
|
||||
role: Literal["system", "assistant", "user"]
|
||||
agent_id: str | None
|
||||
role: str = field(init=False, default="system")
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NativeContent[_NativeT]:
|
||||
"""Native content."""
|
||||
class UserContent:
|
||||
"""Assistant content."""
|
||||
|
||||
role: str = field(init=False, default="native")
|
||||
role: str = field(init=False, default="user")
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssistantContent:
|
||||
"""Assistant content."""
|
||||
|
||||
role: str = field(init=False, default="assistant")
|
||||
agent_id: str
|
||||
content: _NativeT
|
||||
content: str
|
||||
tool_calls: list[llm.ToolInput] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolResultContent:
|
||||
"""Tool result content."""
|
||||
|
||||
role: str = field(init=False, default="tool_result")
|
||||
agent_id: str
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
tool_result: JsonObjectType
|
||||
|
||||
|
||||
Content = SystemContent | UserContent | AssistantContent | ToolResultContent
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatLog[_NativeT]:
|
||||
class ChatLog:
|
||||
"""Class holding the chat history of a specific conversation."""
|
||||
|
||||
hass: HomeAssistant
|
||||
conversation_id: str
|
||||
agent_id: str | None
|
||||
user_name: str | None = None
|
||||
messages: list[Content | NativeContent[_NativeT]] = field(
|
||||
default_factory=lambda: [Content(role="system", agent_id=None, content="")]
|
||||
)
|
||||
content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
|
||||
extra_system_prompt: str | None = None
|
||||
llm_api: llm.APIInstance | None = None
|
||||
last_updated: datetime = field(default_factory=dt_util.utcnow)
|
||||
|
||||
@callback
|
||||
def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None:
|
||||
"""Process intent."""
|
||||
if message.role == "system":
|
||||
raise ValueError("Cannot add system messages to history")
|
||||
if message.role != "native" and self.messages[-1].role == message.role:
|
||||
raise ValueError("Cannot add two assistant or user messages in a row")
|
||||
def async_add_user_content(self, content: UserContent) -> None:
|
||||
"""Add user content to the log."""
|
||||
self.content.append(content)
|
||||
|
||||
self.messages.append(message)
|
||||
async def async_add_assistant_content(
|
||||
self, content: AssistantContent
|
||||
) -> AsyncGenerator[ToolResultContent]:
|
||||
"""Add assistant content."""
|
||||
self.content.append(content)
|
||||
|
||||
@callback
|
||||
def async_get_messages(
|
||||
self, agent_id: str | None = None
|
||||
) -> list[Content | NativeContent[_NativeT]]:
|
||||
"""Get messages for a specific agent ID.
|
||||
if content.tool_calls is None:
|
||||
return
|
||||
|
||||
This will filter out any native message tied to other agent IDs.
|
||||
It can still include assistant/user messages generated by other agents.
|
||||
"""
|
||||
return [
|
||||
message
|
||||
for message in self.messages
|
||||
if message.role != "native" or message.agent_id == agent_id
|
||||
]
|
||||
if self.llm_api is None:
|
||||
raise ValueError("No LLM API configured")
|
||||
|
||||
for tool_input in content.tool_calls:
|
||||
LOGGER.debug(
|
||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||
)
|
||||
|
||||
try:
|
||||
tool_result = await self.llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
tool_result = {"error": type(e).__name__}
|
||||
if str(e):
|
||||
tool_result["error_text"] = str(e)
|
||||
LOGGER.debug("Tool response: %s", tool_result)
|
||||
|
||||
response_content = ToolResultContent(
|
||||
agent_id=content.agent_id,
|
||||
tool_call_id=tool_input.id,
|
||||
tool_name=tool_input.tool_name,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
self.content.append(response_content)
|
||||
yield response_content
|
||||
|
||||
async def async_update_llm_data(
|
||||
self,
|
||||
@ -250,36 +275,16 @@ class ChatLog[_NativeT]:
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
||||
self.llm_api = llm_api
|
||||
self.user_name = user_name
|
||||
self.extra_system_prompt = extra_system_prompt
|
||||
self.messages[0] = Content(
|
||||
role="system",
|
||||
agent_id=user_input.agent_id,
|
||||
content=prompt,
|
||||
)
|
||||
self.content[0] = SystemContent(content=prompt)
|
||||
|
||||
LOGGER.debug("Prompt: %s", self.messages)
|
||||
LOGGER.debug("Prompt: %s", self.content)
|
||||
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)
|
||||
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
{
|
||||
"messages": self.messages,
|
||||
"messages": self.content,
|
||||
"tools": self.llm_api.tools if self.llm_api else None,
|
||||
},
|
||||
)
|
||||
|
||||
async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType:
|
||||
"""Invoke LLM tool for the configured LLM API."""
|
||||
if not self.llm_api:
|
||||
raise ValueError("No LLM API configured")
|
||||
LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
|
||||
|
||||
try:
|
||||
tool_response = await self.llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
tool_response = {"error": type(e).__name__}
|
||||
if str(e):
|
||||
tool_response["error_text"] = str(e)
|
||||
LOGGER.debug("Tool response: %s", tool_response)
|
||||
return tool_response
|
@ -55,6 +55,7 @@ from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.event import async_track_state_added_domain
|
||||
from homeassistant.util.json import JsonObjectType, json_loads_object
|
||||
|
||||
from .chat_log import AssistantContent, async_get_chat_log
|
||||
from .const import (
|
||||
DATA_DEFAULT_ENTITY,
|
||||
DEFAULT_EXPOSED_ATTRIBUTES,
|
||||
@ -63,7 +64,6 @@ from .const import (
|
||||
)
|
||||
from .entity import ConversationEntity
|
||||
from .models import ConversationInput, ConversationResult
|
||||
from .session import Content, async_get_chat_log
|
||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -379,13 +379,13 @@ class DefaultAgent(ConversationEntity):
|
||||
)
|
||||
|
||||
speech: str = response.speech.get("plain", {}).get("speech", "")
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
agent_id=user_input.agent_id,
|
||||
async for _tool_result in chat_log.async_add_assistant_content(
|
||||
AssistantContent(
|
||||
agent_id=user_input.agent_id, # type: ignore[arg-type]
|
||||
content=speech,
|
||||
)
|
||||
)
|
||||
):
|
||||
pass
|
||||
|
||||
return ConversationResult(
|
||||
response=response, conversation_id=session.conversation_id
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from google.api_core.exceptions import GoogleAPIError
|
||||
import google.generativeai as genai
|
||||
@ -149,15 +149,53 @@ def _escape_decode(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _chat_message_convert(
|
||||
message: conversation.Content | conversation.NativeContent[genai_types.ContentDict],
|
||||
) -> genai_types.ContentDict:
|
||||
"""Convert any native chat message for this agent to the native format."""
|
||||
if message.role == "native":
|
||||
return message.content
|
||||
def _create_google_tool_response_content(
|
||||
content: list[conversation.ToolResultContent],
|
||||
) -> protos.Content:
|
||||
"""Create a Google tool response content."""
|
||||
return protos.Content(
|
||||
parts=[
|
||||
protos.Part(
|
||||
function_response=protos.FunctionResponse(
|
||||
name=tool_result.tool_name, response=tool_result.tool_result
|
||||
)
|
||||
)
|
||||
for tool_result in content
|
||||
]
|
||||
)
|
||||
|
||||
role = "model" if message.role == "assistant" else message.role
|
||||
return {"role": role, "parts": message.content}
|
||||
|
||||
def _convert_content(
|
||||
content: conversation.UserContent
|
||||
| conversation.AssistantContent
|
||||
| conversation.SystemContent,
|
||||
) -> genai_types.ContentDict:
|
||||
"""Convert HA content to Google content."""
|
||||
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
|
||||
role = "model" if content.role == "assistant" else content.role
|
||||
return {"role": role, "parts": content.content}
|
||||
|
||||
# Handle the Assistant content with tool calls.
|
||||
assert type(content) is conversation.AssistantContent
|
||||
parts = []
|
||||
|
||||
if content.content:
|
||||
parts.append(protos.Part(text=content.content))
|
||||
|
||||
if content.tool_calls:
|
||||
parts.extend(
|
||||
[
|
||||
protos.Part(
|
||||
function_call=protos.FunctionCall(
|
||||
name=tool_call.tool_name,
|
||||
args=_escape_decode(tool_call.tool_args),
|
||||
)
|
||||
)
|
||||
for tool_call in content.tool_calls
|
||||
]
|
||||
)
|
||||
|
||||
return protos.Content({"role": "model", "parts": parts})
|
||||
|
||||
|
||||
class GoogleGenerativeAIConversationEntity(
|
||||
@ -220,7 +258,7 @@ class GoogleGenerativeAIConversationEntity(
|
||||
async def _async_handle_message(
|
||||
self,
|
||||
user_input: conversation.ConversationInput,
|
||||
session: conversation.ChatLog[genai_types.ContentDict],
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> conversation.ConversationResult:
|
||||
"""Call the API."""
|
||||
|
||||
@ -228,7 +266,7 @@ class GoogleGenerativeAIConversationEntity(
|
||||
options = self.entry.options
|
||||
|
||||
try:
|
||||
await session.async_update_llm_data(
|
||||
await chat_log.async_update_llm_data(
|
||||
DOMAIN,
|
||||
user_input,
|
||||
options.get(CONF_LLM_HASS_API),
|
||||
@ -238,10 +276,10 @@ class GoogleGenerativeAIConversationEntity(
|
||||
return err.as_conversation_result()
|
||||
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
if session.llm_api:
|
||||
if chat_log.llm_api:
|
||||
tools = [
|
||||
_format_tool(tool, session.llm_api.custom_serializer)
|
||||
for tool in session.llm_api.tools
|
||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||
for tool in chat_log.llm_api.tools
|
||||
]
|
||||
|
||||
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||
@ -252,9 +290,36 @@ class GoogleGenerativeAIConversationEntity(
|
||||
"gemini-1.0" not in model_name and "gemini-pro" not in model_name
|
||||
)
|
||||
|
||||
prompt, *messages = [
|
||||
_chat_message_convert(message) for message in session.async_get_messages()
|
||||
]
|
||||
prompt = chat_log.content[0].content # type: ignore[union-attr]
|
||||
messages: list[genai_types.ContentDict] = []
|
||||
|
||||
# Google groups tool results, we do not. Group them before sending.
|
||||
tool_results: list[conversation.ToolResultContent] = []
|
||||
|
||||
for chat_content in chat_log.content[1:]:
|
||||
if chat_content.role == "tool_result":
|
||||
# mypy doesn't like picking a type based on checking shared property 'role'
|
||||
tool_results.append(cast(conversation.ToolResultContent, chat_content))
|
||||
continue
|
||||
|
||||
if tool_results:
|
||||
messages.append(_create_google_tool_response_content(tool_results))
|
||||
tool_results.clear()
|
||||
|
||||
messages.append(
|
||||
_convert_content(
|
||||
cast(
|
||||
conversation.UserContent
|
||||
| conversation.SystemContent
|
||||
| conversation.AssistantContent,
|
||||
chat_content,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if tool_results:
|
||||
messages.append(_create_google_tool_response_content(tool_results))
|
||||
|
||||
model = genai.GenerativeModel(
|
||||
model_name=model_name,
|
||||
generation_config={
|
||||
@ -282,12 +347,12 @@ class GoogleGenerativeAIConversationEntity(
|
||||
),
|
||||
},
|
||||
tools=tools or None,
|
||||
system_instruction=prompt["parts"] if supports_system_instruction else None,
|
||||
system_instruction=prompt if supports_system_instruction else None,
|
||||
)
|
||||
|
||||
if not supports_system_instruction:
|
||||
messages = [
|
||||
{"role": "user", "parts": prompt["parts"]},
|
||||
{"role": "user", "parts": prompt},
|
||||
{"role": "model", "parts": "Ok"},
|
||||
*messages,
|
||||
]
|
||||
@ -325,50 +390,40 @@ class GoogleGenerativeAIConversationEntity(
|
||||
content = " ".join(
|
||||
[part.text.strip() for part in chat_response.parts if part.text]
|
||||
)
|
||||
if content:
|
||||
session.async_add_message(
|
||||
conversation.Content(
|
||||
role="assistant",
|
||||
agent_id=user_input.agent_id,
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
|
||||
function_calls = [
|
||||
part.function_call for part in chat_response.parts if part.function_call
|
||||
]
|
||||
|
||||
if not function_calls or not session.llm_api:
|
||||
break
|
||||
|
||||
tool_responses = []
|
||||
for function_call in function_calls:
|
||||
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
|
||||
tool_calls = []
|
||||
for part in chat_response.parts:
|
||||
if not part.function_call:
|
||||
continue
|
||||
tool_call = MessageToDict(part.function_call._pb) # noqa: SLF001
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = _escape_decode(tool_call["args"])
|
||||
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
|
||||
function_response = await session.async_call_tool(tool_input)
|
||||
tool_responses.append(
|
||||
protos.Part(
|
||||
function_response=protos.FunctionResponse(
|
||||
name=tool_name, response=function_response
|
||||
tool_calls.append(
|
||||
llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
|
||||
)
|
||||
|
||||
chat_request = _create_google_tool_response_content(
|
||||
[
|
||||
tool_response
|
||||
async for tool_response in chat_log.async_add_assistant_content(
|
||||
conversation.AssistantContent(
|
||||
agent_id=user_input.agent_id,
|
||||
content=content,
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
)
|
||||
)
|
||||
chat_request = protos.Content(parts=tool_responses)
|
||||
session.async_add_message(
|
||||
conversation.NativeContent(
|
||||
agent_id=user_input.agent_id,
|
||||
content=chat_request,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
response = intent.IntentResponse(language=user_input.language)
|
||||
response.async_set_speech(
|
||||
" ".join([part.text.strip() for part in chat_response.parts if part.text])
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=response, conversation_id=session.conversation_id
|
||||
response=response, conversation_id=chat_log.conversation_id
|
||||
)
|
||||
|
||||
async def _async_entry_update_listener(
|
||||
|
@ -70,7 +70,9 @@ def _format_tool(
|
||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||
|
||||
|
||||
def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
def _convert_message_to_param(
|
||||
message: ChatCompletionMessage,
|
||||
) -> ChatCompletionMessageParam:
|
||||
"""Convert from class to TypedDict."""
|
||||
tool_calls: list[ChatCompletionMessageToolCallParam] = []
|
||||
if message.tool_calls:
|
||||
@ -94,20 +96,42 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
|
||||
return param
|
||||
|
||||
|
||||
def _chat_message_convert(
|
||||
message: conversation.Content
|
||||
| conversation.NativeContent[ChatCompletionMessageParam],
|
||||
def _convert_content_to_param(
|
||||
content: conversation.Content,
|
||||
) -> ChatCompletionMessageParam:
|
||||
"""Convert any native chat message for this agent to the native format."""
|
||||
role = message.role
|
||||
if role == "native":
|
||||
# mypy doesn't understand that checking role ensures content type
|
||||
return message.content # type: ignore[return-value]
|
||||
if role == "system":
|
||||
role = "developer"
|
||||
return cast(
|
||||
ChatCompletionMessageParam,
|
||||
{"role": role, "content": message.content},
|
||||
if content.role == "tool_result":
|
||||
assert type(content) is conversation.ToolResultContent
|
||||
return ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
tool_call_id=content.tool_call_id,
|
||||
content=json.dumps(content.tool_result),
|
||||
)
|
||||
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
|
||||
role = content.role
|
||||
if role == "system":
|
||||
role = "developer"
|
||||
return cast(
|
||||
ChatCompletionMessageParam,
|
||||
{"role": content.role, "content": content.content}, # type: ignore[union-attr]
|
||||
)
|
||||
|
||||
# Handle the Assistant content including tool calls.
|
||||
assert type(content) is conversation.AssistantContent
|
||||
return ChatCompletionAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content.content,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=tool_call.id,
|
||||
function=Function(
|
||||
arguments=json.dumps(tool_call.tool_args),
|
||||
name=tool_call.tool_name,
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
for tool_call in content.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -171,14 +195,14 @@ class OpenAIConversationEntity(
|
||||
async def _async_handle_message(
|
||||
self,
|
||||
user_input: conversation.ConversationInput,
|
||||
session: conversation.ChatLog[ChatCompletionMessageParam],
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> conversation.ConversationResult:
|
||||
"""Call the API."""
|
||||
assert user_input.agent_id
|
||||
options = self.entry.options
|
||||
|
||||
try:
|
||||
await session.async_update_llm_data(
|
||||
await chat_log.async_update_llm_data(
|
||||
DOMAIN,
|
||||
user_input,
|
||||
options.get(CONF_LLM_HASS_API),
|
||||
@ -188,17 +212,14 @@ class OpenAIConversationEntity(
|
||||
return err.as_conversation_result()
|
||||
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
if session.llm_api:
|
||||
if chat_log.llm_api:
|
||||
tools = [
|
||||
_format_tool(tool, session.llm_api.custom_serializer)
|
||||
for tool in session.llm_api.tools
|
||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||
for tool in chat_log.llm_api.tools
|
||||
]
|
||||
|
||||
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||
|
||||
messages = [
|
||||
_chat_message_convert(message) for message in session.async_get_messages()
|
||||
]
|
||||
messages = [_convert_content_to_param(content) for content in chat_log.content]
|
||||
|
||||
client = self.entry.runtime_data
|
||||
|
||||
@ -213,7 +234,7 @@ class OpenAIConversationEntity(
|
||||
),
|
||||
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||
"user": session.conversation_id,
|
||||
"user": chat_log.conversation_id,
|
||||
}
|
||||
|
||||
if model.startswith("o"):
|
||||
@ -229,43 +250,39 @@ class OpenAIConversationEntity(
|
||||
|
||||
LOGGER.debug("Response %s", result)
|
||||
response = result.choices[0].message
|
||||
messages.append(_message_convert(response))
|
||||
messages.append(_convert_message_to_param(response))
|
||||
|
||||
session.async_add_message(
|
||||
conversation.Content(
|
||||
role=response.role,
|
||||
agent_id=user_input.agent_id,
|
||||
content=response.content or "",
|
||||
),
|
||||
tool_calls: list[llm.ToolInput] | None = None
|
||||
if response.tool_calls:
|
||||
tool_calls = [
|
||||
llm.ToolInput(
|
||||
id=tool_call.id,
|
||||
tool_name=tool_call.function.name,
|
||||
tool_args=json.loads(tool_call.function.arguments),
|
||||
)
|
||||
for tool_call in response.tool_calls
|
||||
]
|
||||
|
||||
messages.extend(
|
||||
[
|
||||
_convert_content_to_param(tool_response)
|
||||
async for tool_response in chat_log.async_add_assistant_content(
|
||||
conversation.AssistantContent(
|
||||
agent_id=user_input.agent_id,
|
||||
content=response.content or "",
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if not response.tool_calls or not session.llm_api:
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=tool_call.function.name,
|
||||
tool_args=json.loads(tool_call.function.arguments),
|
||||
)
|
||||
tool_response = await session.async_call_tool(tool_input)
|
||||
messages.append(
|
||||
ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
tool_call_id=tool_call.id,
|
||||
content=json.dumps(tool_response),
|
||||
)
|
||||
)
|
||||
session.async_add_message(
|
||||
conversation.NativeContent(
|
||||
agent_id=user_input.agent_id,
|
||||
content=messages[-1],
|
||||
)
|
||||
)
|
||||
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_speech(response.content or "")
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=session.conversation_id
|
||||
response=intent_response, conversation_id=chat_log.conversation_id
|
||||
)
|
||||
|
||||
async def _async_entry_update_listener(
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field as dc_field
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
@ -36,6 +36,7 @@ from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import dt as dt_util, yaml as yaml_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
from homeassistant.util.ulid import ulid_now
|
||||
|
||||
from . import (
|
||||
area_registry as ar,
|
||||
@ -139,6 +140,8 @@ class ToolInput:
|
||||
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
# Using lambda for default to allow patching in tests
|
||||
id: str = dc_field(default_factory=lambda: ulid_now()) # pylint: disable=unnecessary-lambda
|
||||
|
||||
|
||||
class Tool:
|
||||
|
@ -236,6 +236,7 @@ async def test_function_call(
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="toolu_0123456789AbCdEfGhIjKlM",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "test_value"},
|
||||
),
|
||||
@ -373,6 +374,7 @@ async def test_function_exception(
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="toolu_0123456789AbCdEfGhIjKlM",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "test_value"},
|
||||
),
|
||||
|
@ -9,13 +9,13 @@ from syrupy.assertion import SnapshotAssertion
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.conversation import (
|
||||
Content,
|
||||
AssistantContent,
|
||||
ConversationInput,
|
||||
ConverseError,
|
||||
NativeContent,
|
||||
ToolResultContent,
|
||||
async_get_chat_log,
|
||||
)
|
||||
from homeassistant.components.conversation.session import DATA_CHAT_HISTORY
|
||||
from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import chat_session, llm
|
||||
@ -40,7 +40,7 @@ def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
|
||||
@pytest.fixture
|
||||
def mock_ulid() -> Generator[Mock]:
|
||||
"""Mock the ulid library."""
|
||||
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
|
||||
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||
mock_ulid_now.return_value = "mock-ulid"
|
||||
yield mock_ulid_now
|
||||
|
||||
@ -56,13 +56,13 @@ async def test_cleanup(
|
||||
):
|
||||
conversation_id = session.conversation_id
|
||||
# Add message so it persists
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
content="",
|
||||
async for _tool_result in chat_log.async_add_assistant_content(
|
||||
AssistantContent(
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
)
|
||||
):
|
||||
pytest.fail("should not reach here")
|
||||
|
||||
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
|
||||
|
||||
@ -79,7 +79,7 @@ async def test_cleanup(
|
||||
assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
|
||||
|
||||
|
||||
async def test_add_message(
|
||||
async def test_default_content(
|
||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||
) -> None:
|
||||
"""Test filtering of messages."""
|
||||
@ -87,95 +87,11 @@ async def test_add_message(
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
assert len(chat_log.messages) == 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chat_log.async_add_message(
|
||||
Content(role="system", agent_id=None, content="")
|
||||
)
|
||||
|
||||
# No 2 user messages in a row
|
||||
assert chat_log.messages[1].role == "user"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chat_log.async_add_message(Content(role="user", agent_id=None, content=""))
|
||||
|
||||
# No 2 assistant messages in a row
|
||||
chat_log.async_add_message(Content(role="assistant", agent_id=None, content=""))
|
||||
assert len(chat_log.messages) == 3
|
||||
assert chat_log.messages[-1].role == "assistant"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chat_log.async_add_message(
|
||||
Content(role="assistant", agent_id=None, content="")
|
||||
)
|
||||
|
||||
|
||||
async def test_message_filtering(
|
||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||
) -> None:
|
||||
"""Test filtering of messages."""
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
messages = chat_log.async_get_messages(agent_id=None)
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == Content(
|
||||
role="system",
|
||||
agent_id=None,
|
||||
content="",
|
||||
)
|
||||
assert messages[1] == Content(
|
||||
role="user",
|
||||
agent_id="mock-agent-id",
|
||||
content=mock_conversation_input.text,
|
||||
)
|
||||
# Cannot add a second user message in a row
|
||||
with pytest.raises(ValueError):
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="user",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
)
|
||||
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
)
|
||||
# Different agent, native messages will be filtered out.
|
||||
chat_log.async_add_message(
|
||||
NativeContent(agent_id="another-mock-agent-id", content=1)
|
||||
)
|
||||
chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1))
|
||||
# A non-native message from another agent is not filtered out.
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
agent_id="another-mock-agent-id",
|
||||
content="Hi!",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(chat_log.messages) == 6
|
||||
|
||||
messages = chat_log.async_get_messages(agent_id="mock-agent-id")
|
||||
assert len(messages) == 5
|
||||
|
||||
assert messages[2] == Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1)
|
||||
assert messages[4] == Content(
|
||||
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
|
||||
)
|
||||
assert len(chat_log.content) == 2
|
||||
assert chat_log.content[0].role == "system"
|
||||
assert chat_log.content[0].content == ""
|
||||
assert chat_log.content[1].role == "user"
|
||||
assert chat_log.content[1].content == mock_conversation_input.text
|
||||
|
||||
|
||||
async def test_llm_api(
|
||||
@ -268,12 +184,10 @@ async def test_template_variables(
|
||||
),
|
||||
)
|
||||
|
||||
assert chat_log.user_name == "Test User"
|
||||
|
||||
assert "The instance name is test home." in chat_log.messages[0].content
|
||||
assert "The user name is Test User." in chat_log.messages[0].content
|
||||
assert "The user id is 12345." in chat_log.messages[0].content
|
||||
assert "The calling platform is test." in chat_log.messages[0].content
|
||||
assert "The instance name is test home." in chat_log.content[0].content
|
||||
assert "The user name is Test User." in chat_log.content[0].content
|
||||
assert "The user id is 12345." in chat_log.content[0].content
|
||||
assert "The calling platform is test." in chat_log.content[0].content
|
||||
|
||||
|
||||
async def test_extra_systen_prompt(
|
||||
@ -296,16 +210,16 @@ async def test_extra_systen_prompt(
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
async for _tool_result in chat_log.async_add_assistant_content(
|
||||
AssistantContent(
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
)
|
||||
):
|
||||
pytest.fail("should not reach here")
|
||||
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt
|
||||
assert chat_log.messages[0].content.endswith(extra_system_prompt)
|
||||
assert chat_log.content[0].content.endswith(extra_system_prompt)
|
||||
|
||||
# Verify that follow-up conversations with no system prompt take previous one
|
||||
conversation_id = chat_log.conversation_id
|
||||
@ -323,7 +237,7 @@ async def test_extra_systen_prompt(
|
||||
)
|
||||
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt
|
||||
assert chat_log.messages[0].content.endswith(extra_system_prompt)
|
||||
assert chat_log.content[0].content.endswith(extra_system_prompt)
|
||||
|
||||
# Verify that we take new system prompts
|
||||
mock_conversation_input.extra_system_prompt = extra_system_prompt2
|
||||
@ -338,17 +252,17 @@ async def test_extra_systen_prompt(
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
async for _tool_result in chat_log.async_add_assistant_content(
|
||||
AssistantContent(
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
)
|
||||
):
|
||||
pytest.fail("should not reach here")
|
||||
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt2
|
||||
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
|
||||
assert extra_system_prompt not in chat_log.messages[0].content
|
||||
assert chat_log.content[0].content.endswith(extra_system_prompt2)
|
||||
assert extra_system_prompt not in chat_log.content[0].content
|
||||
|
||||
# Verify that follow-up conversations with no system prompt take previous one
|
||||
mock_conversation_input.extra_system_prompt = None
|
||||
@ -365,7 +279,7 @@ async def test_extra_systen_prompt(
|
||||
)
|
||||
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt2
|
||||
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
|
||||
assert chat_log.content[0].content.endswith(extra_system_prompt2)
|
||||
|
||||
|
||||
async def test_tool_call(
|
||||
@ -383,8 +297,7 @@ async def test_tool_call(
|
||||
mock_tool.async_call.return_value = "Test response"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
|
||||
return_value=[],
|
||||
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
|
||||
) as mock_get_tools:
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
@ -398,14 +311,29 @@ async def test_tool_call(
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
result = await chat_log.async_call_tool(
|
||||
llm.ToolInput(
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
result = None
|
||||
async for tool_result_content in chat_log.async_add_assistant_content(
|
||||
AssistantContent(
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
content="",
|
||||
tool_calls=[
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call-id",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
):
|
||||
assert result is None
|
||||
result = tool_result_content
|
||||
|
||||
assert result == "Test response"
|
||||
assert result == ToolResultContent(
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
tool_call_id="mock-tool-call-id",
|
||||
tool_result="Test response",
|
||||
tool_name="test_tool",
|
||||
)
|
||||
|
||||
|
||||
async def test_tool_call_exception(
|
||||
@ -423,8 +351,7 @@ async def test_tool_call_exception(
|
||||
mock_tool.async_call.side_effect = HomeAssistantError("Test error")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
|
||||
return_value=[],
|
||||
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
|
||||
) as mock_get_tools:
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
@ -438,11 +365,26 @@ async def test_tool_call_exception(
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
result = await chat_log.async_call_tool(
|
||||
llm.ToolInput(
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
result = None
|
||||
async for tool_result_content in chat_log.async_add_assistant_content(
|
||||
AssistantContent(
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
content="",
|
||||
tool_calls=[
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call-id",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
):
|
||||
assert result is None
|
||||
result = tool_result_content
|
||||
|
||||
assert result == {"error": "HomeAssistantError", "error_text": "Test error"}
|
||||
assert result == ToolResultContent(
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
tool_call_id="mock-tool-call-id",
|
||||
tool_result={"error": "HomeAssistantError", "error_text": "Test error"},
|
||||
tool_name="test_tool",
|
||||
)
|
@ -36,6 +36,13 @@ def freeze_the_time():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_ulid_tools():
|
||||
"""Mock generated ULIDs for tool calls."""
|
||||
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_id", [None, "conversation.google_generative_ai_conversation"]
|
||||
)
|
||||
@ -177,6 +184,7 @@ async def test_chat_history(
|
||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||
)
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
@pytest.mark.usefixtures("mock_ulid_tools")
|
||||
async def test_function_call(
|
||||
mock_get_tools,
|
||||
hass: HomeAssistant,
|
||||
@ -256,6 +264,7 @@ async def test_function_call(
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call",
|
||||
tool_name="test_tool",
|
||||
tool_args={
|
||||
"param1": ["test_value", "param1's value"],
|
||||
@ -287,9 +296,7 @@ async def test_function_call(
|
||||
detail_event = trace_events[1]
|
||||
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
|
||||
assert [
|
||||
p.function_response.name
|
||||
for p in detail_event["data"]["messages"][2]["content"].parts
|
||||
if p.function_response
|
||||
p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
|
||||
] == ["test_tool"]
|
||||
|
||||
|
||||
@ -362,6 +369,7 @@ async def test_function_call_without_parameters(
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call",
|
||||
tool_name="test_tool",
|
||||
tool_args={},
|
||||
),
|
||||
@ -451,6 +459,7 @@ async def test_function_exception(
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": 1},
|
||||
),
|
||||
@ -605,6 +614,7 @@ async def test_template_variables(
|
||||
mock_chat.send_message_async.return_value = chat_response
|
||||
mock_part = MagicMock()
|
||||
mock_part.text = "Model response"
|
||||
mock_part.function_call = None
|
||||
chat_response.parts = [mock_part]
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, context, agent_id=mock_config_entry.entry_id
|
||||
|
@ -18,6 +18,13 @@ from homeassistant.helpers import intent, llm
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_ulid_tools():
|
||||
"""Mock generated ULIDs for tool calls."""
|
||||
with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
|
||||
async def test_chat(
|
||||
hass: HomeAssistant,
|
||||
@ -205,6 +212,7 @@ async def test_function_call(
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call",
|
||||
tool_name="test_tool",
|
||||
tool_args=expected_tool_args,
|
||||
),
|
||||
@ -285,6 +293,7 @@ async def test_function_exception(
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "test_value"},
|
||||
),
|
||||
|
@ -195,6 +195,7 @@ async def test_function_call(
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "test_value"},
|
||||
),
|
||||
@ -359,6 +360,7 @@ async def test_function_exception(
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "test_value"},
|
||||
),
|
||||
|
Loading…
x
Reference in New Issue
Block a user