mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 17:27:52 +00:00
Fix unnecessary single quotes escaping in Google AI (#118522)
This commit is contained in:
parent
0d6c7d0973
commit
272c51fb38
@ -8,6 +8,7 @@ import google.ai.generativelanguage as glm
|
||||
from google.api_core.exceptions import GoogleAPICallError
|
||||
import google.generativeai as genai
|
||||
import google.generativeai.types as genai_types
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
@ -105,6 +106,17 @@ def _format_tool(tool: llm.Tool) -> dict[str, Any]:
|
||||
)
|
||||
|
||||
|
||||
def _adjust_value(value: Any) -> Any:
|
||||
"""Reverse unnecessary single quotes escaping."""
|
||||
if isinstance(value, str):
|
||||
return value.replace("\\'", "'")
|
||||
if isinstance(value, list):
|
||||
return [_adjust_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {k: _adjust_value(v) for k, v in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
class GoogleGenerativeAIConversationEntity(
|
||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||
):
|
||||
@ -295,21 +307,22 @@ class GoogleGenerativeAIConversationEntity(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
self.history[conversation_id] = chat.history
|
||||
tool_calls = [
|
||||
function_calls = [
|
||||
part.function_call for part in chat_response.parts if part.function_call
|
||||
]
|
||||
if not tool_calls or not llm_api:
|
||||
if not function_calls or not llm_api:
|
||||
break
|
||||
|
||||
tool_responses = []
|
||||
for tool_call in tool_calls:
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=tool_call.name,
|
||||
tool_args=dict(tool_call.args),
|
||||
)
|
||||
LOGGER.debug(
|
||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||
)
|
||||
for function_call in function_calls:
|
||||
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = {
|
||||
key: _adjust_value(value)
|
||||
for key, value in tool_call["args"].items()
|
||||
}
|
||||
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
|
||||
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
|
||||
try:
|
||||
function_response = await llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
@ -321,7 +334,7 @@ class GoogleGenerativeAIConversationEntity(
|
||||
tool_responses.append(
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
name=tool_call.name, response=function_response
|
||||
name=tool_name, response=function_response
|
||||
)
|
||||
)
|
||||
)
|
||||
|
@ -140,7 +140,7 @@ class APIInstance:
|
||||
"""Call a LLM tool, validate args and return the response."""
|
||||
async_conversation_trace_append(
|
||||
ConversationTraceEventType.LLM_TOOL_CALL,
|
||||
{"tool_name": tool_input.tool_name, "tool_args": str(tool_input.tool_args)},
|
||||
{"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
|
||||
)
|
||||
|
||||
for tool in self.tools:
|
||||
|
@ -3,6 +3,7 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
from google.ai.generativelanguage_v1beta.types.content import FunctionCall
|
||||
from google.api_core.exceptions import GoogleAPICallError
|
||||
import google.generativeai.types as genai_types
|
||||
import pytest
|
||||
@ -179,8 +180,13 @@ async def test_function_call(
|
||||
chat_response = MagicMock()
|
||||
mock_chat.send_message_async.return_value = chat_response
|
||||
mock_part = MagicMock()
|
||||
mock_part.function_call.name = "test_tool"
|
||||
mock_part.function_call.args = {"param1": ["test_value"]}
|
||||
mock_part.function_call = FunctionCall(
|
||||
name="test_tool",
|
||||
args={
|
||||
"param1": ["test_value", "param1\\'s value"],
|
||||
"param2": "param2\\'s value",
|
||||
},
|
||||
)
|
||||
|
||||
def tool_call(hass, tool_input, tool_context):
|
||||
mock_part.function_call = None
|
||||
@ -220,7 +226,10 @@ async def test_function_call(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": ["test_value"]},
|
||||
tool_args={
|
||||
"param1": ["test_value", "param1's value"],
|
||||
"param2": "param2's value",
|
||||
},
|
||||
),
|
||||
llm.ToolContext(
|
||||
platform="google_generative_ai_conversation",
|
||||
@ -279,8 +288,7 @@ async def test_function_exception(
|
||||
chat_response = MagicMock()
|
||||
mock_chat.send_message_async.return_value = chat_response
|
||||
mock_part = MagicMock()
|
||||
mock_part.function_call.name = "test_tool"
|
||||
mock_part.function_call.args = {"param1": 1}
|
||||
mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
|
||||
|
||||
def tool_call(hass, tool_input, tool_context):
|
||||
mock_part.function_call = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user