Fix unnecessary single quotes escaping in Google AI (#118522)

This commit is contained in:
tronikos 2024-05-30 16:56:06 -07:00 committed by GitHub
parent 0d6c7d0973
commit 272c51fb38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 17 deletions

View File

@ -8,6 +8,7 @@ import google.ai.generativelanguage as glm
from google.api_core.exceptions import GoogleAPICallError
import google.generativeai as genai
import google.generativeai.types as genai_types
from google.protobuf.json_format import MessageToDict
import voluptuous as vol
from voluptuous_openapi import convert
@ -105,6 +106,17 @@ def _format_tool(tool: llm.Tool) -> dict[str, Any]:
)
def _adjust_value(value: Any) -> Any:
"""Reverse unnecessary single quotes escaping."""
if isinstance(value, str):
return value.replace("\\'", "'")
if isinstance(value, list):
return [_adjust_value(item) for item in value]
if isinstance(value, dict):
return {k: _adjust_value(v) for k, v in value.items()}
return value
class GoogleGenerativeAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -295,21 +307,22 @@ class GoogleGenerativeAIConversationEntity(
response=intent_response, conversation_id=conversation_id
)
self.history[conversation_id] = chat.history
tool_calls = [
function_calls = [
part.function_call for part in chat_response.parts if part.function_call
]
if not tool_calls or not llm_api:
if not function_calls or not llm_api:
break
tool_responses = []
for tool_call in tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call.name,
tool_args=dict(tool_call.args),
)
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
for function_call in function_calls:
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
tool_name = tool_call["name"]
tool_args = {
key: _adjust_value(value)
for key, value in tool_call["args"].items()
}
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
try:
function_response = await llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
@ -321,7 +334,7 @@ class GoogleGenerativeAIConversationEntity(
tool_responses.append(
glm.Part(
function_response=glm.FunctionResponse(
name=tool_call.name, response=function_response
name=tool_name, response=function_response
)
)
)

View File

@ -140,7 +140,7 @@ class APIInstance:
"""Call a LLM tool, validate args and return the response."""
async_conversation_trace_append(
ConversationTraceEventType.LLM_TOOL_CALL,
{"tool_name": tool_input.tool_name, "tool_args": str(tool_input.tool_args)},
{"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
)
for tool in self.tools:

View File

@ -3,6 +3,7 @@
from unittest.mock import AsyncMock, MagicMock, patch
from freezegun import freeze_time
from google.ai.generativelanguage_v1beta.types.content import FunctionCall
from google.api_core.exceptions import GoogleAPICallError
import google.generativeai.types as genai_types
import pytest
@ -179,8 +180,13 @@ async def test_function_call(
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.function_call.name = "test_tool"
mock_part.function_call.args = {"param1": ["test_value"]}
mock_part.function_call = FunctionCall(
name="test_tool",
args={
"param1": ["test_value", "param1\\'s value"],
"param2": "param2\\'s value",
},
)
def tool_call(hass, tool_input, tool_context):
mock_part.function_call = None
@ -220,7 +226,10 @@ async def test_function_call(
hass,
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": ["test_value"]},
tool_args={
"param1": ["test_value", "param1's value"],
"param2": "param2's value",
},
),
llm.ToolContext(
platform="google_generative_ai_conversation",
@ -279,8 +288,7 @@ async def test_function_exception(
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.function_call.name = "test_tool"
mock_part.function_call.args = {"param1": 1}
mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
def tool_call(hass, tool_input, tool_context):
mock_part.function_call = None