Fix Google Generative AI: 400 Request contains an invalid argument (#120741)

This commit is contained in:
tronikos 2024-06-28 04:25:55 -07:00 committed by GitHub
parent c385deb6a3
commit d2a457c24f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 255 additions and 3 deletions

View File

@ -95,9 +95,12 @@ def _format_tool(
) -> dict[str, Any]:
"""Format tool specification."""
parameters = _format_schema(
convert(tool.parameters, custom_serializer=custom_serializer)
)
if tool.parameters.schema:
parameters = _format_schema(
convert(tool.parameters, custom_serializer=custom_serializer)
)
else:
parameters = None
return protos.Tool(
{

View File

@ -409,3 +409,169 @@
),
])
# ---
# name: test_function_call
list([
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
''',
'tools': list([
function_declarations {
name: "test_tool"
description: "Test function"
parameters {
type_: OBJECT
properties {
key: "param1"
value {
type_: ARRAY
description: "Test parameters"
items {
type_: STRING
format_: "lower"
}
}
}
}
}
,
]),
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'Please call the test function',
),
dict({
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
parts {
function_response {
name: "test_tool"
response {
fields {
key: "result"
value {
string_value: "Test response"
}
}
}
}
}
,
),
dict({
}),
),
])
# ---
# name: test_function_call_without_parameters
list([
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
''',
'tools': list([
function_declarations {
name: "test_tool"
description: "Test function"
}
,
]),
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'Please call the test function',
),
dict({
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
parts {
function_response {
name: "test_tool"
response {
fields {
key: "result"
value {
string_value: "Test response"
}
}
}
}
}
,
),
dict({
}),
),
])
# ---

View File

@ -172,6 +172,7 @@ async def test_function_call(
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
snapshot: SnapshotAssertion,
) -> None:
"""Test function calling."""
agent_id = mock_config_entry_with_assist.entry_id
@ -256,6 +257,7 @@ async def test_function_call(
device_id="test_device",
),
)
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
# Test conversating tracing
traces = trace.async_get_traces()
@ -272,6 +274,87 @@ async def test_function_call(
assert "Answer in plain text" in detail_event["data"]["prompt"]
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
async def test_function_call_without_parameters(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
snapshot: SnapshotAssertion,
) -> None:
"""Test function calling without parameters."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema({})
mock_get_tools.return_value = [mock_tool]
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.function_call = FunctionCall(name="test_tool", args={})
def tool_call(hass, tool_input, tool_context):
mock_part.function_call = None
mock_part.text = "Hi there!"
return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call
chat_response.parts = [mock_part]
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
device_id="test_device",
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
assert mock_tool_call == {
"parts": [
{
"function_response": {
"name": "test_tool",
"response": {
"result": "Test response",
},
},
},
],
"role": "",
}
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
tool_name="test_tool",
tool_args={},
),
llm.LLMContext(
platform="google_generative_ai_conversation",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
device_id="test_device",
),
)
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)