diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index d6f7981fc8c..33dade8bf29 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -298,43 +298,47 @@ class GoogleGenerativeAIConversationEntity( response=intent_response, conversation_id=conversation_id ) self.history[conversation_id] = chat.history - tool_call = chat_response.parts[0].function_call - - if not tool_call or not llm_api: + tool_calls = [ + part.function_call for part in chat_response.parts if part.function_call + ] + if not tool_calls or not llm_api: break - tool_input = llm.ToolInput( - tool_name=tool_call.name, - tool_args=dict(tool_call.args), - platform=DOMAIN, - context=user_input.context, - user_prompt=user_input.text, - language=user_input.language, - assistant=conversation.DOMAIN, - device_id=user_input.device_id, - ) - LOGGER.debug( - "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args - ) - try: - function_response = await llm_api.async_call_tool(tool_input) - except (HomeAssistantError, vol.Invalid) as e: - function_response = {"error": type(e).__name__} - if str(e): - function_response["error_text"] = str(e) + tool_responses = [] + for tool_call in tool_calls: + tool_input = llm.ToolInput( + tool_name=tool_call.name, + tool_args=dict(tool_call.args), + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=conversation.DOMAIN, + device_id=user_input.device_id, + ) + LOGGER.debug( + "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args + ) + try: + function_response = await llm_api.async_call_tool(tool_input) + except (HomeAssistantError, vol.Invalid) as e: + function_response = {"error": type(e).__name__} + if str(e): + function_response["error_text"] = str(e) - LOGGER.debug("Tool response: %s", function_response) - chat_request = glm.Content( - parts=[ + LOGGER.debug("Tool response: %s", function_response) + tool_responses.append( glm.Part( function_response=glm.FunctionResponse( name=tool_call.name, response=function_response ) ) - ] - ) + ) + chat_request = glm.Content(parts=tool_responses) - intent_response.async_set_speech(chat_response.text) + intent_response.async_set_speech( + " ".join([part.text for part in chat_response.parts if part.text]) + ) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index ad169d9ae0d..284bd904b44 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -191,8 +191,8 @@ async def test_default_prompt( mock_chat.send_message_async.return_value = chat_response mock_part = MagicMock() mock_part.function_call = None + mock_part.text = "Hi there!" chat_response.parts = [mock_part] - chat_response.text = "Hi there!" result = await conversation.async_converse( hass, "hello", @@ -221,8 +221,8 @@ async def test_chat_history( mock_chat.send_message_async.return_value = chat_response mock_part = MagicMock() mock_part.function_call = None + mock_part.text = "1st model response" chat_response.parts = [mock_part] - chat_response.text = "1st model response" mock_chat.history = [ {"role": "user", "parts": "prompt"}, {"role": "model", "parts": "Ok"}, @@ -241,7 +241,8 @@ async def test_chat_history( result.response.as_dict()["speech"]["plain"]["speech"] == "1st model response" ) - chat_response.text = "2nd model response" + mock_part.text = "2nd model response" + chat_response.parts = [mock_part] result = await conversation.async_converse( hass, "2nd user request", @@ -294,8 +295,8 @@ async def test_function_call( mock_part.function_call.args = {"param1": ["test_value"]} def tool_call(hass, tool_input): - mock_part.function_call = False - chat_response.text = "Hi there!" + mock_part.function_call = None + mock_part.text = "Hi there!" return {"result": "Test response"} mock_tool.async_call.side_effect = tool_call @@ -392,8 +393,8 @@ async def test_function_exception( mock_part.function_call.args = {"param1": 1} def tool_call(hass, tool_input): - mock_part.function_call = False - chat_response.text = "Hi there!" + mock_part.function_call = None + mock_part.text = "Hi there!" raise HomeAssistantError("Test tool exception") mock_tool.async_call.side_effect = tool_call