diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index fb7f5c3b21c..8052ee66f40 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -95,9 +95,12 @@ def _format_tool( ) -> dict[str, Any]: """Format tool specification.""" - parameters = _format_schema( - convert(tool.parameters, custom_serializer=custom_serializer) - ) + if tool.parameters.schema: + parameters = _format_schema( + convert(tool.parameters, custom_serializer=custom_serializer) + ) + else: + parameters = None return protos.Tool( { diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr index b0a0ce967de..7f28c172970 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -409,3 +409,169 @@ ), ]) # --- +# name: test_function_call + list([ + tuple( + '', + tuple( + ), + dict({ + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 1.0, + 'top_k': 64, + 'top_p': 0.95, + }), + 'model_name': 'models/gemini-1.5-flash-latest', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', + }), + 'system_instruction': ''' + Current time is 05:00:00. Today's date is 2024-05-24. + You are a voice assistant for Home Assistant. + Answer questions about the world truthfully. + Answer in plain text. Keep it simple and to the point. + Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant. + ''', + 'tools': list([ + function_declarations { + name: "test_tool" + description: "Test function" + parameters { + type_: OBJECT + properties { + key: "param1" + value { + type_: ARRAY + description: "Test parameters" + items { + type_: STRING + format_: "lower" + } + } + } + } + } + , + ]), + }), + ), + tuple( + '().start_chat', + tuple( + ), + dict({ + 'history': list([ + ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + 'Please call the test function', + ), + dict({ + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + parts { + function_response { + name: "test_tool" + response { + fields { + key: "result" + value { + string_value: "Test response" + } + } + } + } + } + , + ), + dict({ + }), + ), + ]) +# --- +# name: test_function_call_without_parameters + list([ + tuple( + '', + tuple( + ), + dict({ + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 1.0, + 'top_k': 64, + 'top_p': 0.95, + }), + 'model_name': 'models/gemini-1.5-flash-latest', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', + }), + 'system_instruction': ''' + Current time is 05:00:00. Today's date is 2024-05-24. + You are a voice assistant for Home Assistant. + Answer questions about the world truthfully. + Answer in plain text. Keep it simple and to the point. + Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant. + ''', + 'tools': list([ + function_declarations { + name: "test_tool" + description: "Test function" + } + , + ]), + }), + ), + tuple( + '().start_chat', + tuple( + ), + dict({ + 'history': list([ + ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + 'Please call the test function', + ), + dict({ + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + parts { + function_response { + name: "test_tool" + response { + fields { + key: "result" + value { + string_value: "Test response" + } + } + } + } + } + , + ), + dict({ + }), + ), + ]) +# --- diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index 990058aa89d..30016335f3b 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -172,6 +172,7 @@ async def test_function_call( hass: HomeAssistant, mock_config_entry_with_assist: MockConfigEntry, mock_init_component, + snapshot: SnapshotAssertion, ) -> None: """Test function calling.""" agent_id = mock_config_entry_with_assist.entry_id @@ -256,6 +257,7 @@ async def test_function_call( device_id="test_device", ), ) + assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot # Test conversating tracing traces = trace.async_get_traces() @@ -272,6 +274,87 @@ async def test_function_call( assert "Answer in plain text" in detail_event["data"]["prompt"] +@patch( + "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools" +) +async def test_function_call_without_parameters( + mock_get_tools, + hass: HomeAssistant, + mock_config_entry_with_assist: MockConfigEntry, + mock_init_component, + snapshot: SnapshotAssertion, +) -> None: + """Test function calling without parameters.""" + agent_id = mock_config_entry_with_assist.entry_id + context = Context() + + mock_tool = AsyncMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test function" + mock_tool.parameters = vol.Schema({}) + + mock_get_tools.return_value = [mock_tool] + + with patch("google.generativeai.GenerativeModel") as mock_model: + mock_chat = AsyncMock() + mock_model.return_value.start_chat.return_value = mock_chat + chat_response = MagicMock() + mock_chat.send_message_async.return_value = chat_response + mock_part = MagicMock() + mock_part.function_call = FunctionCall(name="test_tool", args={}) + + def tool_call(hass, tool_input, tool_context): + mock_part.function_call = None + mock_part.text = "Hi there!" + return {"result": "Test response"} + + mock_tool.async_call.side_effect = tool_call + chat_response.parts = [mock_part] + result = await conversation.async_converse( + hass, + "Please call the test function", + None, + context, + agent_id=agent_id, + device_id="test_device", + ) + + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" + mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0] + mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call) + assert mock_tool_call == { + "parts": [ + { + "function_response": { + "name": "test_tool", + "response": { + "result": "Test response", + }, + }, + }, + ], + "role": "", + } + + mock_tool.async_call.assert_awaited_once_with( + hass, + llm.ToolInput( + tool_name="test_tool", + tool_args={}, + ), + llm.LLMContext( + platform="google_generative_ai_conversation", + context=context, + user_prompt="Please call the test function", + language="en", + assistant="conversation", + device_id="test_device", + ), + ) + assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot + + @patch( "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools" )