Mirror of https://github.com/home-assistant/core.git (synced 2025-07-28 07:37:34 +00:00)
Chat session rev2 (#137209)

* Chat Session rev 2
* Rename session to chat_log
* Simplify typing
* Typing
* Address comments
* Fix anthropic and ollama

Commit 9679fc7878 (parent ce93cb9467)
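The heart of this change is visible in the conversation component's chat_log hunks below: the old mutable Content/NativeContent message model is replaced by four frozen dataclasses (SystemContent, UserContent, AssistantContent, ToolResultContent) whose role field is fixed per class, plus a Content union alias. A minimal, self-contained sketch of that pattern (simplified stand-ins, not the actual Home Assistant classes):

    from dataclasses import dataclass, field

    @dataclass(frozen=True)
    class SystemContent:
        # role is fixed per class and excluded from __init__
        role: str = field(init=False, default="system")
        content: str

    @dataclass(frozen=True)
    class UserContent:
        role: str = field(init=False, default="user")
        content: str

    # Union alias: code that walks the log can branch on .role
    Content = SystemContent | UserContent

    log: list[Content] = [SystemContent(content=""), UserContent(content="hi")]
    assert [c.role for c in log] == ["system", "user"]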
@@ -272,6 +272,7 @@ class AnthropicConversationEntity(
                     continue

                 tool_input = llm.ToolInput(
+                    id=tool_call.id,
                     tool_name=tool_call.name,
                     tool_args=cast(dict[str, Any], tool_call.input),
                 )
@@ -1063,11 +1063,11 @@ class PipelineRun:
                 agent_id=self.intent_agent,
                 extra_system_prompt=conversation_extra_system_prompt,
             )
-            processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT

-            agent_id = user_input.agent_id
+            agent_id = self.intent_agent
+            processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
             intent_response: intent.IntentResponse | None = None
-            if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT:
+            if not processed_locally:
                 # Sentence triggers override conversation agent
                 if (
                     trigger_response_text

@@ -1105,13 +1105,13 @@ class PipelineRun:
                 speech: str = intent_response.speech.get("plain", {}).get(
                     "speech", ""
                 )
-                chat_log.async_add_message(
-                    conversation.Content(
-                        role="assistant",
+                async for _ in chat_log.async_add_assistant_content(
+                    conversation.AssistantContent(
                         agent_id=agent_id,
                         content=speech,
                     )
-                )
+                ):
+                    pass
             conversation_result = conversation.ConversationResult(
                 response=intent_response,
                 conversation_id=session.conversation_id,
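The `async for _ in chat_log.async_add_assistant_content(...): pass` shape above looks odd at first. It exists because the new API is an async generator, and an async generator's body does not run until it is iterated, so callers that don't need the yielded tool results must still exhaust it for the content to actually be appended. A tiny runnable illustration of that semantics (not Home Assistant code):

    import asyncio
    from collections.abc import AsyncGenerator

    async def add_content(log: list[str], text: str) -> AsyncGenerator[str, None]:
        log.append(text)      # runs on first iteration, not at call time
        yield "tool-result"   # illustrative; the real code yields per tool call

    async def main() -> None:
        log: list[str] = []
        gen = add_content(log, "hello")
        assert log == []      # calling the generator did nothing yet
        async for _ in gen:   # iterating executes the body
            pass
        assert log == ["hello"]

    asyncio.run(main())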
@@ -30,6 +30,16 @@ from .agent_manager import (
     async_get_agent,
     get_agent_manager,
 )
+from .chat_log import (
+    AssistantContent,
+    ChatLog,
+    Content,
+    ConverseError,
+    SystemContent,
+    ToolResultContent,
+    UserContent,
+    async_get_chat_log,
+)
 from .const import (
     ATTR_AGENT_ID,
     ATTR_CONVERSATION_ID,

@@ -48,13 +58,13 @@ from .default_agent import DefaultAgent, async_setup_default_agent
 from .entity import ConversationEntity
 from .http import async_setup as async_setup_conversation_http
 from .models import AbstractConversationAgent, ConversationInput, ConversationResult
-from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log
 from .trace import ConversationTraceEventType, async_conversation_trace_append

 __all__ = [
     "DOMAIN",
     "HOME_ASSISTANT_AGENT",
     "OLD_HOME_ASSISTANT_AGENT",
+    "AssistantContent",
     "ChatLog",
     "Content",
     "ConversationEntity",

@@ -63,7 +73,9 @@ __all__ = [
     "ConversationResult",
     "ConversationTraceEventType",
     "ConverseError",
-    "NativeContent",
+    "SystemContent",
+    "ToolResultContent",
+    "UserContent",
     "async_conversation_trace_append",
     "async_converse",
     "async_get_agent_info",
@@ -2,19 +2,16 @@

 from __future__ import annotations

-from collections.abc import Generator
+from collections.abc import AsyncGenerator, Generator
 from contextlib import contextmanager
 from dataclasses import dataclass, field, replace
-from datetime import datetime
 import logging
-from typing import Literal

 import voluptuous as vol

 from homeassistant.core import HomeAssistant, callback
 from homeassistant.exceptions import HomeAssistantError, TemplateError
 from homeassistant.helpers import chat_session, intent, llm, template
-from homeassistant.util import dt as dt_util
 from homeassistant.util.hass_dict import HassKey
 from homeassistant.util.json import JsonObjectType

@@ -31,7 +28,7 @@ LOGGER = logging.getLogger(__name__)
 def async_get_chat_log(
     hass: HomeAssistant,
     session: chat_session.ChatSession,
-    user_input: ConversationInput,
+    user_input: ConversationInput | None = None,
 ) -> Generator[ChatLog]:
     """Return chat log for a specific chat session."""
     all_history = hass.data.get(DATA_CHAT_HISTORY)

@@ -42,9 +39,9 @@ def async_get_chat_log(
     history = all_history.get(session.conversation_id)

     if history:
-        history = replace(history, messages=history.messages.copy())
+        history = replace(history, content=history.content.copy())
     else:
-        history = ChatLog(hass, session.conversation_id, user_input.agent_id)
+        history = ChatLog(hass, session.conversation_id)

     @callback
     def do_cleanup() -> None:

@@ -53,22 +50,19 @@ def async_get_chat_log(

     session.async_on_cleanup(do_cleanup)

-    message: Content = Content(
-        role="user",
-        agent_id=user_input.agent_id,
-        content=user_input.text,
-    )
-    history.async_add_message(message)
+    if user_input is not None:
+        history.async_add_user_content(UserContent(content=user_input.text))
+
+    last_message = history.content[-1]

     yield history

-    if history.messages[-1] is message:
+    if history.content[-1] is last_message:
         LOGGER.debug(
             "History opened but no assistant message was added, ignoring update"
         )
         return

-    history.last_updated = dt_util.utcnow()
     all_history[session.conversation_id] = history

@@ -94,63 +88,94 @@ class ConverseError(HomeAssistantError):
         )


-@dataclass
-class Content:
+@dataclass(frozen=True)
+class SystemContent:
     """Base class for chat messages."""

-    role: Literal["system", "assistant", "user"]
-    agent_id: str | None
+    role: str = field(init=False, default="system")
     content: str


 @dataclass(frozen=True)
-class NativeContent[_NativeT]:
-    """Native content."""
+class UserContent:
+    """Assistant content."""

-    role: str = field(init=False, default="native")
-    agent_id: str
-    content: _NativeT
+    role: str = field(init=False, default="user")
+    content: str
+
+
+@dataclass(frozen=True)
+class AssistantContent:
+    """Assistant content."""
+
+    role: str = field(init=False, default="assistant")
+    agent_id: str
+    content: str
+    tool_calls: list[llm.ToolInput] | None = None
+
+
+@dataclass(frozen=True)
+class ToolResultContent:
+    """Tool result content."""
+
+    role: str = field(init=False, default="tool_result")
+    agent_id: str
+    tool_call_id: str
+    tool_name: str
+    tool_result: JsonObjectType
+
+
+Content = SystemContent | UserContent | AssistantContent | ToolResultContent


 @dataclass
-class ChatLog[_NativeT]:
+class ChatLog:
     """Class holding the chat history of a specific conversation."""

     hass: HomeAssistant
     conversation_id: str
-    agent_id: str | None
-    user_name: str | None = None
-    messages: list[Content | NativeContent[_NativeT]] = field(
-        default_factory=lambda: [Content(role="system", agent_id=None, content="")]
-    )
+    content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
     extra_system_prompt: str | None = None
     llm_api: llm.APIInstance | None = None
-    last_updated: datetime = field(default_factory=dt_util.utcnow)

     @callback
-    def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None:
-        """Process intent."""
-        if message.role == "system":
-            raise ValueError("Cannot add system messages to history")
-        if message.role != "native" and self.messages[-1].role == message.role:
-            raise ValueError("Cannot add two assistant or user messages in a row")
-
-        self.messages.append(message)
+    def async_add_user_content(self, content: UserContent) -> None:
+        """Add user content to the log."""
+        self.content.append(content)

-    @callback
-    def async_get_messages(
-        self, agent_id: str | None = None
-    ) -> list[Content | NativeContent[_NativeT]]:
-        """Get messages for a specific agent ID.
-
-        This will filter out any native message tied to other agent IDs.
-        It can still include assistant/user messages generated by other agents.
-        """
-        return [
-            message
-            for message in self.messages
-            if message.role != "native" or message.agent_id == agent_id
-        ]
+    async def async_add_assistant_content(
+        self, content: AssistantContent
+    ) -> AsyncGenerator[ToolResultContent]:
+        """Add assistant content."""
+        self.content.append(content)
+
+        if content.tool_calls is None:
+            return
+
+        if self.llm_api is None:
+            raise ValueError("No LLM API configured")
+
+        for tool_input in content.tool_calls:
+            LOGGER.debug(
+                "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
+            )
+
+            try:
+                tool_result = await self.llm_api.async_call_tool(tool_input)
+            except (HomeAssistantError, vol.Invalid) as e:
+                tool_result = {"error": type(e).__name__}
+                if str(e):
+                    tool_result["error_text"] = str(e)
+            LOGGER.debug("Tool response: %s", tool_result)
+
+            response_content = ToolResultContent(
+                agent_id=content.agent_id,
+                tool_call_id=tool_input.id,
+                tool_name=tool_input.tool_name,
+                tool_result=tool_result,
+            )
+            self.content.append(response_content)
+            yield response_content

     async def async_update_llm_data(
         self,

@@ -250,36 +275,16 @@ class ChatLog[_NativeT]:
         prompt = "\n".join(prompt_parts)

         self.llm_api = llm_api
-        self.user_name = user_name
         self.extra_system_prompt = extra_system_prompt
-        self.messages[0] = Content(
-            role="system",
-            agent_id=user_input.agent_id,
-            content=prompt,
-        )
+        self.content[0] = SystemContent(content=prompt)

-        LOGGER.debug("Prompt: %s", self.messages)
+        LOGGER.debug("Prompt: %s", self.content)
         LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)

         trace.async_conversation_trace_append(
             trace.ConversationTraceEventType.AGENT_DETAIL,
             {
-                "messages": self.messages,
+                "messages": self.content,
                 "tools": self.llm_api.tools if self.llm_api else None,
             },
         )
-
-    async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType:
-        """Invoke LLM tool for the configured LLM API."""
-        if not self.llm_api:
-            raise ValueError("No LLM API configured")
-        LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
-
-        try:
-            tool_response = await self.llm_api.async_call_tool(tool_input)
-        except (HomeAssistantError, vol.Invalid) as e:
-            tool_response = {"error": type(e).__name__}
-            if str(e):
-                tool_response["error_text"] = str(e)
-        LOGGER.debug("Tool response: %s", tool_response)
-        return tool_response
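Taken together, async_add_assistant_content now owns the whole tool loop: append the assistant message, run each tool call through the LLM API, and yield a ToolResultContent per call (each of which is also appended to the log). A rough, self-contained sketch of that control flow, with invented MiniChatLog/ToolResult names and a stubbed tool runner standing in for llm_api.async_call_tool:

    import asyncio
    from collections.abc import AsyncGenerator
    from dataclasses import dataclass, field

    @dataclass(frozen=True)
    class ToolResult:
        tool_name: str
        tool_result: str

    @dataclass
    class MiniChatLog:
        content: list = field(default_factory=list)

        async def async_add_assistant_content(
            self, text: str, tool_calls: list[str] | None = None
        ) -> AsyncGenerator[ToolResult, None]:
            self.content.append(text)
            if tool_calls is None:  # plain reply: nothing to yield
                return
            for name in tool_calls:
                result = ToolResult(name, f"ran {name}")  # stub tool execution
                self.content.append(result)  # log it and hand it to the agent
                yield result

    async def main() -> None:
        log = MiniChatLog()
        async for result in log.async_add_assistant_content("on it", ["light.turn_on"]):
            print(result)
        print(log.content)

    asyncio.run(main())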
@@ -55,6 +55,7 @@ from homeassistant.helpers.entity_component import EntityComponent
 from homeassistant.helpers.event import async_track_state_added_domain
 from homeassistant.util.json import JsonObjectType, json_loads_object

+from .chat_log import AssistantContent, async_get_chat_log
 from .const import (
     DATA_DEFAULT_ENTITY,
     DEFAULT_EXPOSED_ATTRIBUTES,

@@ -63,7 +64,6 @@ from .const import (
 )
 from .entity import ConversationEntity
 from .models import ConversationInput, ConversationResult
-from .session import Content, async_get_chat_log
 from .trace import ConversationTraceEventType, async_conversation_trace_append

 _LOGGER = logging.getLogger(__name__)

@@ -379,13 +379,13 @@ class DefaultAgent(ConversationEntity):
         )

         speech: str = response.speech.get("plain", {}).get("speech", "")
-        chat_log.async_add_message(
-            Content(
-                role="assistant",
-                agent_id=user_input.agent_id,
+        async for _tool_result in chat_log.async_add_assistant_content(
+            AssistantContent(
+                agent_id=user_input.agent_id,  # type: ignore[arg-type]
                 content=speech,
             )
-        )
+        ):
+            pass

         return ConversationResult(
             response=response, conversation_id=session.conversation_id
@@ -4,7 +4,7 @@ from __future__ import annotations

 import codecs
 from collections.abc import Callable
-from typing import Any, Literal
+from typing import Any, Literal, cast

 from google.api_core.exceptions import GoogleAPIError
 import google.generativeai as genai

@@ -149,15 +149,53 @@ def _escape_decode(value: Any) -> Any:
     return value


-def _chat_message_convert(
-    message: conversation.Content | conversation.NativeContent[genai_types.ContentDict],
-) -> genai_types.ContentDict:
-    """Convert any native chat message for this agent to the native format."""
-    if message.role == "native":
-        return message.content
+def _create_google_tool_response_content(
+    content: list[conversation.ToolResultContent],
+) -> protos.Content:
+    """Create a Google tool response content."""
+    return protos.Content(
+        parts=[
+            protos.Part(
+                function_response=protos.FunctionResponse(
+                    name=tool_result.tool_name, response=tool_result.tool_result
+                )
+            )
+            for tool_result in content
+        ]
+    )

-    role = "model" if message.role == "assistant" else message.role
-    return {"role": role, "parts": message.content}
+
+def _convert_content(
+    content: conversation.UserContent
+    | conversation.AssistantContent
+    | conversation.SystemContent,
+) -> genai_types.ContentDict:
+    """Convert HA content to Google content."""
+    if content.role != "assistant" or not content.tool_calls:  # type: ignore[union-attr]
+        role = "model" if content.role == "assistant" else content.role
+        return {"role": role, "parts": content.content}
+
+    # Handle the Assistant content with tool calls.
+    assert type(content) is conversation.AssistantContent
+    parts = []
+
+    if content.content:
+        parts.append(protos.Part(text=content.content))
+
+    if content.tool_calls:
+        parts.extend(
+            [
+                protos.Part(
+                    function_call=protos.FunctionCall(
+                        name=tool_call.tool_name,
+                        args=_escape_decode(tool_call.tool_args),
+                    )
+                )
+                for tool_call in content.tool_calls
+            ]
+        )
+
+    return protos.Content({"role": "model", "parts": parts})


 class GoogleGenerativeAIConversationEntity(

@@ -220,7 +258,7 @@ class GoogleGenerativeAIConversationEntity(
     async def _async_handle_message(
         self,
         user_input: conversation.ConversationInput,
-        session: conversation.ChatLog[genai_types.ContentDict],
+        chat_log: conversation.ChatLog,
     ) -> conversation.ConversationResult:
         """Call the API."""

@@ -228,7 +266,7 @@
         options = self.entry.options

         try:
-            await session.async_update_llm_data(
+            await chat_log.async_update_llm_data(
                 DOMAIN,
                 user_input,
                 options.get(CONF_LLM_HASS_API),

@@ -238,10 +276,10 @@
             return err.as_conversation_result()

         tools: list[dict[str, Any]] | None = None
-        if session.llm_api:
+        if chat_log.llm_api:
             tools = [
-                _format_tool(tool, session.llm_api.custom_serializer)
-                for tool in session.llm_api.tools
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
             ]

         model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)

@@ -252,9 +290,36 @@
             "gemini-1.0" not in model_name and "gemini-pro" not in model_name
         )

-        prompt, *messages = [
-            _chat_message_convert(message) for message in session.async_get_messages()
-        ]
+        prompt = chat_log.content[0].content  # type: ignore[union-attr]
+        messages: list[genai_types.ContentDict] = []
+
+        # Google groups tool results, we do not. Group them before sending.
+        tool_results: list[conversation.ToolResultContent] = []
+
+        for chat_content in chat_log.content[1:]:
+            if chat_content.role == "tool_result":
+                # mypy doesn't like picking a type based on checking shared property 'role'
+                tool_results.append(cast(conversation.ToolResultContent, chat_content))
+                continue
+
+            if tool_results:
+                messages.append(_create_google_tool_response_content(tool_results))
+                tool_results.clear()
+
+            messages.append(
+                _convert_content(
+                    cast(
+                        conversation.UserContent
+                        | conversation.SystemContent
+                        | conversation.AssistantContent,
+                        chat_content,
+                    )
+                )
+            )
+
+        if tool_results:
+            messages.append(_create_google_tool_response_content(tool_results))
+
         model = genai.GenerativeModel(
             model_name=model_name,
             generation_config={

@@ -282,12 +347,12 @@
                 ),
             },
             tools=tools or None,
-            system_instruction=prompt["parts"] if supports_system_instruction else None,
+            system_instruction=prompt if supports_system_instruction else None,
         )

         if not supports_system_instruction:
             messages = [
-                {"role": "user", "parts": prompt["parts"]},
+                {"role": "user", "parts": prompt},
                 {"role": "model", "parts": "Ok"},
                 *messages,
             ]

@@ -325,50 +390,40 @@ class GoogleGenerativeAIConversationEntity(
             content = " ".join(
                 [part.text.strip() for part in chat_response.parts if part.text]
             )
-            if content:
-                session.async_add_message(
-                    conversation.Content(
-                        role="assistant",
-                        agent_id=user_input.agent_id,
-                        content=content,
-                    )
-                )

-            function_calls = [
-                part.function_call for part in chat_response.parts if part.function_call
-            ]
-
-            if not function_calls or not session.llm_api:
-                break
-
-            tool_responses = []
-            for function_call in function_calls:
-                tool_call = MessageToDict(function_call._pb)  # noqa: SLF001
+            tool_calls = []
+            for part in chat_response.parts:
+                if not part.function_call:
+                    continue
+                tool_call = MessageToDict(part.function_call._pb)  # noqa: SLF001
                 tool_name = tool_call["name"]
                 tool_args = _escape_decode(tool_call["args"])
-                tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
-                function_response = await session.async_call_tool(tool_input)
-                tool_responses.append(
-                    protos.Part(
-                        function_response=protos.FunctionResponse(
-                            name=tool_name, response=function_response
-                        )
-                    )
+                tool_calls.append(
+                    llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
                 )
-            chat_request = protos.Content(parts=tool_responses)
-            session.async_add_message(
-                conversation.NativeContent(
-                    agent_id=user_input.agent_id,
-                    content=chat_request,
+
+            chat_request = _create_google_tool_response_content(
+                [
+                    tool_response
+                    async for tool_response in chat_log.async_add_assistant_content(
+                        conversation.AssistantContent(
+                            agent_id=user_input.agent_id,
+                            content=content,
+                            tool_calls=tool_calls or None,
+                        )
+                    )
+                ]
+            )
+
+            if not tool_calls:
+                break

         response = intent.IntentResponse(language=user_input.language)
         response.async_set_speech(
             " ".join([part.text.strip() for part in chat_response.parts if part.text])
         )
         return conversation.ConversationResult(
-            response=response, conversation_id=session.conversation_id
+            response=response, conversation_id=chat_log.conversation_id
         )

     async def _async_entry_update_listener(
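The comment "Google groups tool results, we do not" marks the one impedance mismatch this integration bridges: the chat log stores tool results one entry at a time, while Google's API wants consecutive results batched into a single content message. A toy version of that grouping pass, using plain dicts instead of the real content classes:

    content = [
        {"role": "user", "text": "turn on two lights"},
        {"role": "assistant", "text": "", "tool_calls": ["a", "b"]},
        {"role": "tool_result", "name": "a"},
        {"role": "tool_result", "name": "b"},
        {"role": "user", "text": "thanks"},
    ]

    messages: list[dict] = []
    tool_results: list[dict] = []
    for item in content:
        if item["role"] == "tool_result":
            tool_results.append(item)  # buffer consecutive tool results
            continue
        if tool_results:  # flush the buffered batch as one message
            messages.append({"role": "tool_batch", "parts": tool_results[:]})
            tool_results.clear()
        messages.append(item)
    if tool_results:  # don't lose a trailing batch
        messages.append({"role": "tool_batch", "parts": tool_results[:]})

    assert [m["role"] for m in messages] == ["user", "assistant", "tool_batch", "user"]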
@@ -70,7 +70,9 @@ def _format_tool(
     return ChatCompletionToolParam(type="function", function=tool_spec)


-def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
+def _convert_message_to_param(
+    message: ChatCompletionMessage,
+) -> ChatCompletionMessageParam:
     """Convert from class to TypedDict."""
     tool_calls: list[ChatCompletionMessageToolCallParam] = []
     if message.tool_calls:

@@ -94,20 +96,42 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
     return param


-def _chat_message_convert(
-    message: conversation.Content
-    | conversation.NativeContent[ChatCompletionMessageParam],
+def _convert_content_to_param(
+    content: conversation.Content,
 ) -> ChatCompletionMessageParam:
     """Convert any native chat message for this agent to the native format."""
-    role = message.role
-    if role == "native":
-        # mypy doesn't understand that checking role ensures content type
-        return message.content  # type: ignore[return-value]
-    if role == "system":
-        role = "developer"
-    return cast(
-        ChatCompletionMessageParam,
-        {"role": role, "content": message.content},
-    )
+    if content.role == "tool_result":
+        assert type(content) is conversation.ToolResultContent
+        return ChatCompletionToolMessageParam(
+            role="tool",
+            tool_call_id=content.tool_call_id,
+            content=json.dumps(content.tool_result),
+        )
+    if content.role != "assistant" or not content.tool_calls:  # type: ignore[union-attr]
+        role = content.role
+        if role == "system":
+            role = "developer"
+        return cast(
+            ChatCompletionMessageParam,
+            {"role": content.role, "content": content.content},  # type: ignore[union-attr]
+        )
+
+    # Handle the Assistant content including tool calls.
+    assert type(content) is conversation.AssistantContent
+    return ChatCompletionAssistantMessageParam(
+        role="assistant",
+        content=content.content,
+        tool_calls=[
+            ChatCompletionMessageToolCallParam(
+                id=tool_call.id,
+                function=Function(
+                    arguments=json.dumps(tool_call.tool_args),
+                    name=tool_call.tool_name,
+                ),
+                type="function",
+            )
+            for tool_call in content.tool_calls
+        ],
+    )

@@ -171,14 +195,14 @@ class OpenAIConversationEntity(
     async def _async_handle_message(
         self,
         user_input: conversation.ConversationInput,
-        session: conversation.ChatLog[ChatCompletionMessageParam],
+        chat_log: conversation.ChatLog,
     ) -> conversation.ConversationResult:
         """Call the API."""
         assert user_input.agent_id
         options = self.entry.options

         try:
-            await session.async_update_llm_data(
+            await chat_log.async_update_llm_data(
                 DOMAIN,
                 user_input,
                 options.get(CONF_LLM_HASS_API),

@@ -188,17 +212,14 @@
             return err.as_conversation_result()

         tools: list[ChatCompletionToolParam] | None = None
-        if session.llm_api:
+        if chat_log.llm_api:
             tools = [
-                _format_tool(tool, session.llm_api.custom_serializer)
-                for tool in session.llm_api.tools
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
             ]

         model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
-        messages = [
-            _chat_message_convert(message) for message in session.async_get_messages()
-        ]
+        messages = [_convert_content_to_param(content) for content in chat_log.content]

         client = self.entry.runtime_data

@@ -213,7 +234,7 @@
             ),
             "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
             "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
-            "user": session.conversation_id,
+            "user": chat_log.conversation_id,
         }

         if model.startswith("o"):

@@ -229,43 +250,39 @@

             LOGGER.debug("Response %s", result)
             response = result.choices[0].message
-            messages.append(_message_convert(response))
+            messages.append(_convert_message_to_param(response))

-            session.async_add_message(
-                conversation.Content(
-                    role=response.role,
-                    agent_id=user_input.agent_id,
-                    content=response.content or "",
-                ),
-            )
-
-            if not response.tool_calls or not session.llm_api:
-                break
-
-            for tool_call in response.tool_calls:
-                tool_input = llm.ToolInput(
-                    tool_name=tool_call.function.name,
-                    tool_args=json.loads(tool_call.function.arguments),
-                )
-                tool_response = await session.async_call_tool(tool_input)
-                messages.append(
-                    ChatCompletionToolMessageParam(
-                        role="tool",
-                        tool_call_id=tool_call.id,
-                        content=json.dumps(tool_response),
-                    )
-                )
-            session.async_add_message(
-                conversation.NativeContent(
-                    agent_id=user_input.agent_id,
-                    content=messages[-1],
-                )
-            )
+            tool_calls: list[llm.ToolInput] | None = None
+            if response.tool_calls:
+                tool_calls = [
+                    llm.ToolInput(
+                        id=tool_call.id,
+                        tool_name=tool_call.function.name,
+                        tool_args=json.loads(tool_call.function.arguments),
+                    )
+                    for tool_call in response.tool_calls
+                ]
+
+            messages.extend(
+                [
+                    _convert_content_to_param(tool_response)
+                    async for tool_response in chat_log.async_add_assistant_content(
+                        conversation.AssistantContent(
+                            agent_id=user_input.agent_id,
+                            content=response.content or "",
+                            tool_calls=tool_calls,
+                        )
+                    )
+                ]
+            )
+
+            if not tool_calls:
+                break

         intent_response = intent.IntentResponse(language=user_input.language)
         intent_response.async_set_speech(response.content or "")
         return conversation.ConversationResult(
-            response=intent_response, conversation_id=session.conversation_id
+            response=intent_response, conversation_id=chat_log.conversation_id
         )

     async def _async_entry_update_listener(
@@ -4,7 +4,7 @@ from __future__ import annotations

 from abc import ABC, abstractmethod
 from collections.abc import Callable
-from dataclasses import dataclass
+from dataclasses import dataclass, field as dc_field
 from datetime import timedelta
 from decimal import Decimal
 from enum import Enum

@@ -36,6 +36,7 @@ from homeassistant.exceptions import HomeAssistantError
 from homeassistant.util import dt as dt_util, yaml as yaml_util
 from homeassistant.util.hass_dict import HassKey
 from homeassistant.util.json import JsonObjectType
+from homeassistant.util.ulid import ulid_now

 from . import (
     area_registry as ar,

@@ -139,6 +140,8 @@ class ToolInput:

     tool_name: str
     tool_args: dict[str, Any]
+    # Using lambda for default to allow patching in tests
+    id: str = dc_field(default_factory=lambda: ulid_now())  # pylint: disable=unnecessary-lambda


 class Tool:
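The new ToolInput.id default explains its own trick in the diff comment: wrapping ulid_now in a lambda makes the dataclass look the function up at call time, so tests can patch the module attribute and get deterministic IDs. A self-contained demonstration of the difference (ulid_now here is a local stand-in for homeassistant.util.ulid.ulid_now):

    from dataclasses import dataclass, field
    from unittest.mock import patch

    def ulid_now() -> str:  # stand-in for the real ULID generator
        return "real-ulid"

    @dataclass
    class WithLambda:
        # resolved at instantiation time, so patching ulid_now is visible
        id: str = field(default_factory=lambda: ulid_now())

    @dataclass
    class WithDirectRef:
        # the function object is captured once, at class-definition time
        id: str = field(default_factory=ulid_now)

    with patch(f"{__name__}.ulid_now", return_value="mock-ulid"):
        assert WithLambda().id == "mock-ulid"      # patch takes effect
        assert WithDirectRef().id == "real-ulid"   # patch is bypassed

    assert WithLambda().id == "real-ulid"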
@@ -236,6 +236,7 @@ async def test_function_call(
     mock_tool.async_call.assert_awaited_once_with(
         hass,
         llm.ToolInput(
+            id="toolu_0123456789AbCdEfGhIjKlM",
             tool_name="test_tool",
             tool_args={"param1": "test_value"},
         ),

@@ -373,6 +374,7 @@ async def test_function_exception(
     mock_tool.async_call.assert_awaited_once_with(
         hass,
         llm.ToolInput(
+            id="toolu_0123456789AbCdEfGhIjKlM",
             tool_name="test_tool",
             tool_args={"param1": "test_value"},
         ),
@@ -9,13 +9,13 @@ from syrupy.assertion import SnapshotAssertion
 import voluptuous as vol

 from homeassistant.components.conversation import (
-    Content,
+    AssistantContent,
     ConversationInput,
     ConverseError,
-    NativeContent,
+    ToolResultContent,
     async_get_chat_log,
 )
-from homeassistant.components.conversation.session import DATA_CHAT_HISTORY
+from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY
 from homeassistant.core import Context, HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import chat_session, llm

@@ -40,7 +40,7 @@ def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
 @pytest.fixture
 def mock_ulid() -> Generator[Mock]:
     """Mock the ulid library."""
-    with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
+    with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
         mock_ulid_now.return_value = "mock-ulid"
         yield mock_ulid_now

@@ -56,13 +56,13 @@ async def test_cleanup(
     ):
         conversation_id = session.conversation_id
         # Add message so it persists
-        chat_log.async_add_message(
-            Content(
-                role="assistant",
-                agent_id=mock_conversation_input.agent_id,
-                content="",
+        async for _tool_result in chat_log.async_add_assistant_content(
+            AssistantContent(
+                agent_id="mock-agent-id",
+                content="Hey!",
             )
-        )
+        ):
+            pytest.fail("should not reach here")

     assert conversation_id in hass.data[DATA_CHAT_HISTORY]

@@ -79,7 +79,7 @@ async def test_cleanup(
     assert conversation_id not in hass.data[DATA_CHAT_HISTORY]


-async def test_add_message(
+async def test_default_content(
     hass: HomeAssistant, mock_conversation_input: ConversationInput
 ) -> None:
     """Test filtering of messages."""

@@ -87,95 +87,11 @@ async def test_add_message(
         chat_session.async_get_chat_session(hass) as session,
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
     ):
-        assert len(chat_log.messages) == 2
-
-        with pytest.raises(ValueError):
-            chat_log.async_add_message(
-                Content(role="system", agent_id=None, content="")
-            )
-
-        # No 2 user messages in a row
-        assert chat_log.messages[1].role == "user"
-
-        with pytest.raises(ValueError):
-            chat_log.async_add_message(Content(role="user", agent_id=None, content=""))
-
-        # No 2 assistant messages in a row
-        chat_log.async_add_message(Content(role="assistant", agent_id=None, content=""))
-        assert len(chat_log.messages) == 3
-        assert chat_log.messages[-1].role == "assistant"
-
-        with pytest.raises(ValueError):
-            chat_log.async_add_message(
-                Content(role="assistant", agent_id=None, content="")
-            )
-
-
-async def test_message_filtering(
-    hass: HomeAssistant, mock_conversation_input: ConversationInput
-) -> None:
-    """Test filtering of messages."""
-    with (
-        chat_session.async_get_chat_session(hass) as session,
-        async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
-    ):
-        messages = chat_log.async_get_messages(agent_id=None)
-        assert len(messages) == 2
-        assert messages[0] == Content(
-            role="system",
-            agent_id=None,
-            content="",
-        )
-        assert messages[1] == Content(
-            role="user",
-            agent_id="mock-agent-id",
-            content=mock_conversation_input.text,
-        )
-        # Cannot add a second user message in a row
-        with pytest.raises(ValueError):
-            chat_log.async_add_message(
-                Content(
-                    role="user",
-                    agent_id="mock-agent-id",
-                    content="Hey!",
-                )
-            )
-
-        chat_log.async_add_message(
-            Content(
-                role="assistant",
-                agent_id="mock-agent-id",
-                content="Hey!",
-            )
-        )
-        # Different agent, native messages will be filtered out.
-        chat_log.async_add_message(
-            NativeContent(agent_id="another-mock-agent-id", content=1)
-        )
-        chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1))
-        # A non-native message from another agent is not filtered out.
-        chat_log.async_add_message(
-            Content(
-                role="assistant",
-                agent_id="another-mock-agent-id",
-                content="Hi!",
-            )
-        )
-
-        assert len(chat_log.messages) == 6
-
-        messages = chat_log.async_get_messages(agent_id="mock-agent-id")
-        assert len(messages) == 5
-
-        assert messages[2] == Content(
-            role="assistant",
-            agent_id="mock-agent-id",
-            content="Hey!",
-        )
-        assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1)
-        assert messages[4] == Content(
-            role="assistant", agent_id="another-mock-agent-id", content="Hi!"
-        )
+        assert len(chat_log.content) == 2
+        assert chat_log.content[0].role == "system"
+        assert chat_log.content[0].content == ""
+        assert chat_log.content[1].role == "user"
+        assert chat_log.content[1].content == mock_conversation_input.text


 async def test_llm_api(

@@ -268,12 +184,10 @@ async def test_template_variables(
         ),
     )

-    assert chat_log.user_name == "Test User"
-
-    assert "The instance name is test home." in chat_log.messages[0].content
-    assert "The user name is Test User." in chat_log.messages[0].content
-    assert "The user id is 12345." in chat_log.messages[0].content
-    assert "The calling platform is test." in chat_log.messages[0].content
+    assert "The instance name is test home." in chat_log.content[0].content
+    assert "The user name is Test User." in chat_log.content[0].content
+    assert "The user id is 12345." in chat_log.content[0].content
+    assert "The calling platform is test." in chat_log.content[0].content


 async def test_extra_systen_prompt(

@@ -296,16 +210,16 @@ async def test_extra_systen_prompt(
             user_llm_hass_api=None,
             user_llm_prompt=None,
         )
-        chat_log.async_add_message(
-            Content(
-                role="assistant",
+        async for _tool_result in chat_log.async_add_assistant_content(
+            AssistantContent(
                 agent_id="mock-agent-id",
                 content="Hey!",
             )
-        )
+        ):
+            pytest.fail("should not reach here")

     assert chat_log.extra_system_prompt == extra_system_prompt
-    assert chat_log.messages[0].content.endswith(extra_system_prompt)
+    assert chat_log.content[0].content.endswith(extra_system_prompt)

     # Verify that follow-up conversations with no system prompt take previous one
     conversation_id = chat_log.conversation_id

@@ -323,7 +237,7 @@ async def test_extra_systen_prompt(
     )

     assert chat_log.extra_system_prompt == extra_system_prompt
-    assert chat_log.messages[0].content.endswith(extra_system_prompt)
+    assert chat_log.content[0].content.endswith(extra_system_prompt)

     # Verify that we take new system prompts
     mock_conversation_input.extra_system_prompt = extra_system_prompt2

@@ -338,17 +252,17 @@ async def test_extra_systen_prompt(
             user_llm_hass_api=None,
             user_llm_prompt=None,
         )
-        chat_log.async_add_message(
-            Content(
-                role="assistant",
+        async for _tool_result in chat_log.async_add_assistant_content(
+            AssistantContent(
                 agent_id="mock-agent-id",
                 content="Hey!",
             )
-        )
+        ):
+            pytest.fail("should not reach here")

     assert chat_log.extra_system_prompt == extra_system_prompt2
-    assert chat_log.messages[0].content.endswith(extra_system_prompt2)
-    assert extra_system_prompt not in chat_log.messages[0].content
+    assert chat_log.content[0].content.endswith(extra_system_prompt2)
+    assert extra_system_prompt not in chat_log.content[0].content

     # Verify that follow-up conversations with no system prompt take previous one
     mock_conversation_input.extra_system_prompt = None

@@ -365,7 +279,7 @@ async def test_extra_systen_prompt(
     )

     assert chat_log.extra_system_prompt == extra_system_prompt2
-    assert chat_log.messages[0].content.endswith(extra_system_prompt2)
+    assert chat_log.content[0].content.endswith(extra_system_prompt2)


 async def test_tool_call(

@@ -383,8 +297,7 @@ async def test_tool_call(
     mock_tool.async_call.return_value = "Test response"

     with patch(
-        "homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
-        return_value=[],
+        "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
     ) as mock_get_tools:
         mock_get_tools.return_value = [mock_tool]

@@ -398,14 +311,29 @@ async def test_tool_call(
             user_llm_hass_api="assist",
             user_llm_prompt=None,
         )
-        result = await chat_log.async_call_tool(
-            llm.ToolInput(
-                tool_name="test_tool",
-                tool_args={"param1": "Test Param"},
-            )
-        )
+        result = None
+        async for tool_result_content in chat_log.async_add_assistant_content(
+            AssistantContent(
+                agent_id=mock_conversation_input.agent_id,
+                content="",
+                tool_calls=[
+                    llm.ToolInput(
+                        id="mock-tool-call-id",
+                        tool_name="test_tool",
+                        tool_args={"param1": "Test Param"},
+                    )
+                ],
+            )
+        ):
+            assert result is None
+            result = tool_result_content

-    assert result == "Test response"
+    assert result == ToolResultContent(
+        agent_id=mock_conversation_input.agent_id,
+        tool_call_id="mock-tool-call-id",
+        tool_result="Test response",
+        tool_name="test_tool",
+    )


 async def test_tool_call_exception(

@@ -423,8 +351,7 @@ async def test_tool_call_exception(
     mock_tool.async_call.side_effect = HomeAssistantError("Test error")

     with patch(
-        "homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
-        return_value=[],
+        "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
     ) as mock_get_tools:
         mock_get_tools.return_value = [mock_tool]

@@ -438,11 +365,26 @@ async def test_tool_call_exception(
             user_llm_hass_api="assist",
             user_llm_prompt=None,
         )
-        result = await chat_log.async_call_tool(
-            llm.ToolInput(
-                tool_name="test_tool",
-                tool_args={"param1": "Test Param"},
-            )
-        )
+        result = None
+        async for tool_result_content in chat_log.async_add_assistant_content(
+            AssistantContent(
+                agent_id=mock_conversation_input.agent_id,
+                content="",
+                tool_calls=[
+                    llm.ToolInput(
+                        id="mock-tool-call-id",
+                        tool_name="test_tool",
+                        tool_args={"param1": "Test Param"},
+                    )
+                ],
+            )
+        ):
+            assert result is None
+            result = tool_result_content

-    assert result == {"error": "HomeAssistantError", "error_text": "Test error"}
+    assert result == ToolResultContent(
+        agent_id=mock_conversation_input.agent_id,
+        tool_call_id="mock-tool-call-id",
+        tool_result={"error": "HomeAssistantError", "error_text": "Test error"},
+        tool_name="test_tool",
+    )
@@ -36,6 +36,13 @@ def freeze_the_time():
     yield


+@pytest.fixture(autouse=True)
+def mock_ulid_tools():
+    """Mock generated ULIDs for tool calls."""
+    with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
+        yield
+
+
 @pytest.mark.parametrize(
     "agent_id", [None, "conversation.google_generative_ai_conversation"]
 )

@@ -177,6 +184,7 @@ async def test_chat_history(
     "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
 )
 @pytest.mark.usefixtures("mock_init_component")
+@pytest.mark.usefixtures("mock_ulid_tools")
 async def test_function_call(
     mock_get_tools,
     hass: HomeAssistant,

@@ -256,6 +264,7 @@ async def test_function_call(
     mock_tool.async_call.assert_awaited_once_with(
         hass,
         llm.ToolInput(
+            id="mock-tool-call",
             tool_name="test_tool",
             tool_args={
                 "param1": ["test_value", "param1's value"],

@@ -287,9 +296,7 @@ async def test_function_call(
     detail_event = trace_events[1]
     assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
     assert [
-        p.function_response.name
-        for p in detail_event["data"]["messages"][2]["content"].parts
-        if p.function_response
+        p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
     ] == ["test_tool"]


@@ -362,6 +369,7 @@ async def test_function_call_without_parameters(
     mock_tool.async_call.assert_awaited_once_with(
         hass,
         llm.ToolInput(
+            id="mock-tool-call",
             tool_name="test_tool",
             tool_args={},
         ),

@@ -451,6 +459,7 @@ async def test_function_exception(
     mock_tool.async_call.assert_awaited_once_with(
         hass,
         llm.ToolInput(
+            id="mock-tool-call",
             tool_name="test_tool",
             tool_args={"param1": 1},
         ),

@@ -605,6 +614,7 @@ async def test_template_variables(
     mock_chat.send_message_async.return_value = chat_response
     mock_part = MagicMock()
     mock_part.text = "Model response"
+    mock_part.function_call = None
     chat_response.parts = [mock_part]
     result = await conversation.async_converse(
         hass, "hello", None, context, agent_id=mock_config_entry.entry_id
@@ -18,6 +18,13 @@ from homeassistant.helpers import intent, llm
 from tests.common import MockConfigEntry


+@pytest.fixture(autouse=True)
+def mock_ulid_tools():
+    """Mock generated ULIDs for tool calls."""
+    with patch("homeassistant.helpers.llm.ulid_now", return_value="mock-tool-call"):
+        yield
+
+
 @pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
 async def test_chat(
     hass: HomeAssistant,

@@ -205,6 +212,7 @@ async def test_function_call(
     mock_tool.async_call.assert_awaited_once_with(
         hass,
         llm.ToolInput(
+            id="mock-tool-call",
             tool_name="test_tool",
             tool_args=expected_tool_args,
         ),

@@ -285,6 +293,7 @@ async def test_function_exception(
     mock_tool.async_call.assert_awaited_once_with(
         hass,
         llm.ToolInput(
+            id="mock-tool-call",
             tool_name="test_tool",
             tool_args={"param1": "test_value"},
         ),
|
@ -195,6 +195,7 @@ async def test_function_call(
|
|||||||
mock_tool.async_call.assert_awaited_once_with(
|
mock_tool.async_call.assert_awaited_once_with(
|
||||||
hass,
|
hass,
|
||||||
llm.ToolInput(
|
llm.ToolInput(
|
||||||
|
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": "test_value"},
|
tool_args={"param1": "test_value"},
|
||||||
),
|
),
|
||||||
@ -359,6 +360,7 @@ async def test_function_exception(
|
|||||||
mock_tool.async_call.assert_awaited_once_with(
|
mock_tool.async_call.assert_awaited_once_with(
|
||||||
hass,
|
hass,
|
||||||
llm.ToolInput(
|
llm.ToolInput(
|
||||||
|
id="call_AbCdEfGhIjKlMnOpQrStUvWx",
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": "test_value"},
|
tool_args={"param1": "test_value"},
|
||||||
),
|
),
|
||||||