diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py
index f85cf2530dc..e7aaabb912d 100644
--- a/homeassistant/components/google_generative_ai_conversation/conversation.py
+++ b/homeassistant/components/google_generative_ai_conversation/conversation.py
@@ -8,6 +8,7 @@ import google.ai.generativelanguage as glm
 from google.api_core.exceptions import GoogleAPICallError
 import google.generativeai as genai
 import google.generativeai.types as genai_types
+from google.protobuf.json_format import MessageToDict
 import voluptuous as vol
 from voluptuous_openapi import convert
 
@@ -105,6 +106,17 @@ def _format_tool(tool: llm.Tool) -> dict[str, Any]:
     )
 
 
+def _adjust_value(value: Any) -> Any:
+    """Reverse unnecessary single quotes escaping."""
+    if isinstance(value, str):
+        return value.replace("\\'", "'")
+    if isinstance(value, list):
+        return [_adjust_value(item) for item in value]
+    if isinstance(value, dict):
+        return {k: _adjust_value(v) for k, v in value.items()}
+    return value
+
+
 class GoogleGenerativeAIConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
@@ -295,21 +307,22 @@ class GoogleGenerativeAIConversationEntity(
                     response=intent_response, conversation_id=conversation_id
                 )
             self.history[conversation_id] = chat.history
-            tool_calls = [
+            function_calls = [
                 part.function_call for part in chat_response.parts if part.function_call
             ]
-            if not tool_calls or not llm_api:
+            if not function_calls or not llm_api:
                 break
 
             tool_responses = []
-            for tool_call in tool_calls:
-                tool_input = llm.ToolInput(
-                    tool_name=tool_call.name,
-                    tool_args=dict(tool_call.args),
-                )
-                LOGGER.debug(
-                    "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
-                )
+            for function_call in function_calls:
+                tool_call = MessageToDict(function_call._pb)  # noqa: SLF001
+                tool_name = tool_call["name"]
+                tool_args = {
+                    key: _adjust_value(value)
+                    for key, value in tool_call["args"].items()
+                }
+                LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
+                tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
                 try:
                     function_response = await llm_api.async_call_tool(tool_input)
                 except (HomeAssistantError, vol.Invalid) as e:
@@ -321,7 +334,7 @@ class GoogleGenerativeAIConversationEntity(
                 tool_responses.append(
                     glm.Part(
                         function_response=glm.FunctionResponse(
-                            name=tool_call.name, response=function_response
+                            name=tool_name, response=function_response
                         )
                     )
                 )
diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py
index 5591c4a8aba..57b72bc9618 100644
--- a/homeassistant/helpers/llm.py
+++ b/homeassistant/helpers/llm.py
@@ -140,7 +140,7 @@ class APIInstance:
         """Call a LLM tool, validate args and return the response."""
         async_conversation_trace_append(
             ConversationTraceEventType.LLM_TOOL_CALL,
-            {"tool_name": tool_input.tool_name, "tool_args": str(tool_input.tool_args)},
+            {"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
         )
 
         for tool in self.tools:
diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py
index 4c7f2de5e2e..b282895baef 100644
--- a/tests/components/google_generative_ai_conversation/test_conversation.py
+++ b/tests/components/google_generative_ai_conversation/test_conversation.py
@@ -3,6 +3,7 @@
 from unittest.mock import AsyncMock, MagicMock, patch
 
 from freezegun import freeze_time
+from google.ai.generativelanguage_v1beta.types.content import FunctionCall
 from google.api_core.exceptions import GoogleAPICallError
 import google.generativeai.types as genai_types
 import pytest
@@ -179,8 +180,13 @@ async def test_function_call(
     chat_response = MagicMock()
     mock_chat.send_message_async.return_value = chat_response
     mock_part = MagicMock()
-    mock_part.function_call.name = "test_tool"
-    mock_part.function_call.args = {"param1": ["test_value"]}
+    mock_part.function_call = FunctionCall(
+        name="test_tool",
+        args={
+            "param1": ["test_value", "param1\\'s value"],
+            "param2": "param2\\'s value",
+        },
+    )
 
     def tool_call(hass, tool_input, tool_context):
         mock_part.function_call = None
@@ -220,7 +226,10 @@
         hass,
         llm.ToolInput(
             tool_name="test_tool",
-            tool_args={"param1": ["test_value"]},
+            tool_args={
+                "param1": ["test_value", "param1's value"],
+                "param2": "param2's value",
+            },
         ),
         llm.ToolContext(
             platform="google_generative_ai_conversation",
@@ -279,8 +288,7 @@ async def test_function_exception(
     chat_response = MagicMock()
     mock_chat.send_message_async.return_value = chat_response
     mock_part = MagicMock()
-    mock_part.function_call.name = "test_tool"
-    mock_part.function_call.args = {"param1": 1}
+    mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
 
     def tool_call(hass, tool_input, tool_context):
         mock_part.function_call = None
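
For reference, a minimal standalone sketch (not part of the patch) of what the new _adjust_value helper does to the dict produced by MessageToDict: string arguments can arrive with single quotes escaped as \', and the helper walks nested lists and dicts to undo that. The raw_args dict below is a hypothetical example that mirrors the values used in the updated test.

# Standalone sketch: _adjust_value is copied verbatim from the patch above;
# raw_args is an invented example of Gemini tool-call args after MessageToDict.
from typing import Any


def _adjust_value(value: Any) -> Any:
    """Reverse unnecessary single quotes escaping."""
    if isinstance(value, str):
        return value.replace("\\'", "'")
    if isinstance(value, list):
        return [_adjust_value(item) for item in value]
    if isinstance(value, dict):
        return {k: _adjust_value(v) for k, v in value.items()}
    return value


raw_args = {
    "param1": ["test_value", "param1\\'s value"],
    "param2": "param2\\'s value",
}
# Mirrors the dict comprehension added in conversation.py.
print({key: _adjust_value(value) for key, value in raw_args.items()})
# -> {'param1': ['test_value', "param1's value"], 'param2': "param2's value"}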