mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 09:47:13 +00:00
Fix Google Generative AI: 400 Request contains an invalid argument (#120741)
This commit is contained in:
parent
c5fa9ad272
commit
cada78496b
@ -95,9 +95,12 @@ def _format_tool(
|
|||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Format tool specification."""
|
"""Format tool specification."""
|
||||||
|
|
||||||
parameters = _format_schema(
|
if tool.parameters.schema:
|
||||||
convert(tool.parameters, custom_serializer=custom_serializer)
|
parameters = _format_schema(
|
||||||
)
|
convert(tool.parameters, custom_serializer=custom_serializer)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parameters = None
|
||||||
|
|
||||||
return protos.Tool(
|
return protos.Tool(
|
||||||
{
|
{
|
||||||
|
@ -409,3 +409,169 @@
|
|||||||
),
|
),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_function_call
|
||||||
|
list([
|
||||||
|
tuple(
|
||||||
|
'',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'generation_config': dict({
|
||||||
|
'max_output_tokens': 150,
|
||||||
|
'temperature': 1.0,
|
||||||
|
'top_k': 64,
|
||||||
|
'top_p': 0.95,
|
||||||
|
}),
|
||||||
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
|
'safety_settings': dict({
|
||||||
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
}),
|
||||||
|
'system_instruction': '''
|
||||||
|
Current time is 05:00:00. Today's date is 2024-05-24.
|
||||||
|
You are a voice assistant for Home Assistant.
|
||||||
|
Answer questions about the world truthfully.
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
|
||||||
|
''',
|
||||||
|
'tools': list([
|
||||||
|
function_declarations {
|
||||||
|
name: "test_tool"
|
||||||
|
description: "Test function"
|
||||||
|
parameters {
|
||||||
|
type_: OBJECT
|
||||||
|
properties {
|
||||||
|
key: "param1"
|
||||||
|
value {
|
||||||
|
type_: ARRAY
|
||||||
|
description: "Test parameters"
|
||||||
|
items {
|
||||||
|
type_: STRING
|
||||||
|
format_: "lower"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
,
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'history': list([
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat().send_message_async',
|
||||||
|
tuple(
|
||||||
|
'Please call the test function',
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat().send_message_async',
|
||||||
|
tuple(
|
||||||
|
parts {
|
||||||
|
function_response {
|
||||||
|
name: "test_tool"
|
||||||
|
response {
|
||||||
|
fields {
|
||||||
|
key: "result"
|
||||||
|
value {
|
||||||
|
string_value: "Test response"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
,
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_function_call_without_parameters
|
||||||
|
list([
|
||||||
|
tuple(
|
||||||
|
'',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'generation_config': dict({
|
||||||
|
'max_output_tokens': 150,
|
||||||
|
'temperature': 1.0,
|
||||||
|
'top_k': 64,
|
||||||
|
'top_p': 0.95,
|
||||||
|
}),
|
||||||
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
|
'safety_settings': dict({
|
||||||
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
}),
|
||||||
|
'system_instruction': '''
|
||||||
|
Current time is 05:00:00. Today's date is 2024-05-24.
|
||||||
|
You are a voice assistant for Home Assistant.
|
||||||
|
Answer questions about the world truthfully.
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
|
||||||
|
''',
|
||||||
|
'tools': list([
|
||||||
|
function_declarations {
|
||||||
|
name: "test_tool"
|
||||||
|
description: "Test function"
|
||||||
|
}
|
||||||
|
,
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'history': list([
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat().send_message_async',
|
||||||
|
tuple(
|
||||||
|
'Please call the test function',
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat().send_message_async',
|
||||||
|
tuple(
|
||||||
|
parts {
|
||||||
|
function_response {
|
||||||
|
name: "test_tool"
|
||||||
|
response {
|
||||||
|
fields {
|
||||||
|
key: "result"
|
||||||
|
value {
|
||||||
|
string_value: "Test response"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
,
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
@ -172,6 +172,7 @@ async def test_function_call(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry_with_assist: MockConfigEntry,
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test function calling."""
|
"""Test function calling."""
|
||||||
agent_id = mock_config_entry_with_assist.entry_id
|
agent_id = mock_config_entry_with_assist.entry_id
|
||||||
@ -256,6 +257,7 @@ async def test_function_call(
|
|||||||
device_id="test_device",
|
device_id="test_device",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
|
||||||
|
|
||||||
# Test conversating tracing
|
# Test conversating tracing
|
||||||
traces = trace.async_get_traces()
|
traces = trace.async_get_traces()
|
||||||
@ -272,6 +274,87 @@ async def test_function_call(
|
|||||||
assert "Answer in plain text" in detail_event["data"]["prompt"]
|
assert "Answer in plain text" in detail_event["data"]["prompt"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||||
|
)
|
||||||
|
async def test_function_call_without_parameters(
|
||||||
|
mock_get_tools,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test function calling without parameters."""
|
||||||
|
agent_id = mock_config_entry_with_assist.entry_id
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
mock_tool = AsyncMock()
|
||||||
|
mock_tool.name = "test_tool"
|
||||||
|
mock_tool.description = "Test function"
|
||||||
|
mock_tool.parameters = vol.Schema({})
|
||||||
|
|
||||||
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
|
||||||
|
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
mock_model.return_value.start_chat.return_value = mock_chat
|
||||||
|
chat_response = MagicMock()
|
||||||
|
mock_chat.send_message_async.return_value = chat_response
|
||||||
|
mock_part = MagicMock()
|
||||||
|
mock_part.function_call = FunctionCall(name="test_tool", args={})
|
||||||
|
|
||||||
|
def tool_call(hass, tool_input, tool_context):
|
||||||
|
mock_part.function_call = None
|
||||||
|
mock_part.text = "Hi there!"
|
||||||
|
return {"result": "Test response"}
|
||||||
|
|
||||||
|
mock_tool.async_call.side_effect = tool_call
|
||||||
|
chat_response.parts = [mock_part]
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Please call the test function",
|
||||||
|
None,
|
||||||
|
context,
|
||||||
|
agent_id=agent_id,
|
||||||
|
device_id="test_device",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
|
||||||
|
mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
|
||||||
|
mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
|
||||||
|
assert mock_tool_call == {
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"function_response": {
|
||||||
|
"name": "test_tool",
|
||||||
|
"response": {
|
||||||
|
"result": "Test response",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"role": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_tool.async_call.assert_awaited_once_with(
|
||||||
|
hass,
|
||||||
|
llm.ToolInput(
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={},
|
||||||
|
),
|
||||||
|
llm.LLMContext(
|
||||||
|
platform="google_generative_ai_conversation",
|
||||||
|
context=context,
|
||||||
|
user_prompt="Please call the test function",
|
||||||
|
language="en",
|
||||||
|
assistant="conversation",
|
||||||
|
device_id="test_device",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user