diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py
index c642bfd94e6..c466101e7e4 100644
--- a/homeassistant/components/google_generative_ai_conversation/conversation.py
+++ b/homeassistant/components/google_generative_ai_conversation/conversation.py
@@ -3,16 +3,17 @@
 from __future__ import annotations
 
 import codecs
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable
 from dataclasses import replace
 from typing import Any, Literal, cast
 
-from google.genai.errors import APIError
+from google.genai.errors import APIError, ClientError
 from google.genai.types import (
     AutomaticFunctionCallingConfig,
     Content,
     FunctionDeclaration,
     GenerateContentConfig,
+    GenerateContentResponse,
     GoogleSearch,
     HarmCategory,
     Part,
@@ -233,6 +234,81 @@ def _convert_content(
     return Content(role="model", parts=parts)
 
 
+async def _transform_stream(
+    result: AsyncGenerator[GenerateContentResponse],
+) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+    """Convert streamed Google responses into assistant content deltas.
+
+    One delta is yielded per response chunk; only the first carries the
+    "assistant" role. A delta holds the joined text of the chunk's parts and,
+    when the chunk has function calls, the matching tool calls.
+
+    Raises HomeAssistantError when a chunk has prompt feedback or no
+    candidates, when its first candidate finishes with a reason other than
+    "STOP", or when an APIError or ValueError surfaces while streaming; the
+    latter is chained as the cause of the raised error.
+    """
+    # The role is only announced on the first delta of the message.
+    is_first_chunk = True
+    try:
+        async for response in result:
+            LOGGER.debug("Received response chunk: %s", response)
+            delta: conversation.AssistantContentDeltaDict = {}
+
+            if is_first_chunk:
+                is_first_chunk = False
+                delta["role"] = "assistant"
+
+            # Prompt feedback or an empty candidate list leaves nothing to relay.
+            feedback = response.prompt_feedback
+            if feedback or not response.candidates:
+                reason = feedback.block_reason_message if feedback else "unknown"
+                raise HomeAssistantError(
+                    f"The message got blocked due to content violations, reason: {reason}"
+                )
+
+            candidate = response.candidates[0]
+            finish_reason = candidate.finish_reason
+            if finish_reason is not None and finish_reason != "STOP":
+                # Every other finish reason is reported as an error, see:
+                # https://ai.google.dev/api/generate-content#FinishReason
+                LOGGER.error(
+                    "Error in Google Generative AI response: %s, see: https://ai.google.dev/api/generate-content#FinishReason",
+                    finish_reason,
+                )
+                raise HomeAssistantError(
+                    f"{ERROR_GETTING_RESPONSE} Reason: {finish_reason}"
+                )
+
+            # A candidate without content or parts contributes no text or tools.
+            model_content = candidate.content
+            parts = (model_content.parts if model_content is not None else None) or []
+
+            # Each function call found in the chunk becomes one tool call.
+            tool_calls = [
+                llm.ToolInput(
+                    tool_name=call.name or "",
+                    tool_args=_escape_decode(call.args),
+                )
+                for part in parts
+                if (call := part.function_call)
+            ]
+            if tool_calls:
+                delta["tool_calls"] = tool_calls
+
+            # Text is always set, even when this chunk carries none.
+            delta["content"] = "".join(part.text for part in parts if part.text)
+            yield delta
+    except (APIError, ValueError) as err:
+        LOGGER.error("Error sending message: %s %s", type(err), err)
+        # APIError exposes a message; for a ValueError only its type name is used.
+        if isinstance(err, APIError):
+            failure = err.message
+        else:
+            failure = type(err).__name__
+        raise HomeAssistantError(f"{ERROR_GETTING_RESPONSE}: {failure}") from err
+
+
 class GoogleGenerativeAIConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
@@ -240,6 +316,7 @@ class GoogleGenerativeAIConversationEntity(
 
     _attr_has_entity_name = True
     _attr_name = None
+    _attr_supports_streaming = True
 
     def __init__(self, entry: ConfigEntry)
-> None: """Initialize the agent.""" @@ -426,80 +503,40 @@ class GoogleGenerativeAIConversationEntity( # To prevent infinite loops, we limit the number of iterations for _iteration in range(MAX_TOOL_ITERATIONS): try: - chat_response = await chat.send_message(message=chat_request) - - if chat_response.prompt_feedback: - raise HomeAssistantError( - f"The message got blocked due to content violations, reason: {chat_response.prompt_feedback.block_reason_message}" - ) - if not chat_response.candidates: - LOGGER.error( - "No candidates found in the response: %s", - chat_response, - ) - raise HomeAssistantError(ERROR_GETTING_RESPONSE) - + chat_response_generator = await chat.send_message_stream( + message=chat_request + ) except ( APIError, + ClientError, ValueError, ) as err: LOGGER.error("Error sending message: %s %s", type(err), err) - error = f"Sorry, I had a problem talking to Google Generative AI: {err}" + error = ERROR_GETTING_RESPONSE raise HomeAssistantError(error) from err - if (usage_metadata := chat_response.usage_metadata) is not None: - chat_log.async_trace( - { - "stats": { - "input_tokens": usage_metadata.prompt_token_count, - "cached_input_tokens": usage_metadata.cached_content_token_count - or 0, - "output_tokens": usage_metadata.candidates_token_count, - } - } - ) - - response_parts = chat_response.candidates[0].content.parts - if not response_parts: - raise HomeAssistantError(ERROR_GETTING_RESPONSE) - content = " ".join( - [part.text.strip() for part in response_parts if part.text] - ) - - tool_calls = [] - for part in response_parts: - if not part.function_call: - continue - tool_call = part.function_call - tool_name = tool_call.name - tool_args = _escape_decode(tool_call.args) - tool_calls.append( - llm.ToolInput( - tool_name=self._fix_tool_name(tool_name), - tool_args=tool_args, - ) - ) - chat_request = _create_google_tool_response_parts( [ - tool_response - async for tool_response in chat_log.async_add_assistant_content( - 
conversation.AssistantContent( - agent_id=user_input.agent_id, - content=content, - tool_calls=tool_calls or None, - ) + content + async for content in chat_log.async_add_delta_content_stream( + user_input.agent_id, + _transform_stream(chat_response_generator), ) + if isinstance(content, conversation.ToolResultContent) ] ) - if not tool_calls: + if not chat_log.unresponded_tool_results: break response = intent.IntentResponse(language=user_input.language) - response.async_set_speech( - " ".join([part.text.strip() for part in response_parts if part.text]) - ) + if not isinstance(chat_log.content[-1], conversation.AssistantContent): + LOGGER.error( + "Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response", + chat_log.content[-1], + ) + raise HomeAssistantError(f"{ERROR_GETTING_RESPONSE}") + response.async_set_speech(chat_log.content[-1].content or "") return conversation.ConversationResult( response=response, conversation_id=chat_log.conversation_id, diff --git a/tests/components/google_generative_ai_conversation/__init__.py b/tests/components/google_generative_ai_conversation/__init__.py index fbf9ee545db..18b3c8e07f0 100644 --- a/tests/components/google_generative_ai_conversation/__init__.py +++ b/tests/components/google_generative_ai_conversation/__init__.py @@ -2,10 +2,10 @@ from unittest.mock import Mock -from google.genai.errors import ClientError +from google.genai.errors import APIError, ClientError import httpx -CLIENT_ERROR_500 = ClientError( +API_ERROR_500 = APIError( 500, Mock( __class__=httpx.Response, @@ -17,6 +17,18 @@ CLIENT_ERROR_500 = ClientError( ), ), ) +CLIENT_ERROR_BAD_REQUEST = ClientError( + 400, + Mock( + __class__=httpx.Response, + json=Mock( + return_value={ + "message": "Bad Request", + "status": "invalid-argument", + } + ), + ), +) CLIENT_ERROR_API_KEY_INVALID = ClientError( 400, Mock( diff --git 
a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr deleted file mode 100644 index ce257e61d53..00000000000 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ /dev/null @@ -1,100 +0,0 @@ -# serializer version: 1 -# name: test_function_call - list([ - tuple( - '', - tuple( - ), - dict({ - 'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'param1': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description='Test parameters', enum=None, format=None, 
items=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=), max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=), 'param2': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=None), 'param3': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'json': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=)}, property_ordering=None, required=[], type=)}, property_ordering=None, required=[], type=))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None), - 'history': list([ - ]), - 
'model': 'models/gemini-2.0-flash', - }), - ), - tuple( - '().send_message', - tuple( - ), - dict({ - 'message': 'Please call the test function', - }), - ), - tuple( - '().send_message', - tuple( - ), - dict({ - 'message': list([ - Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None), - ]), - }), - ), - ]) -# --- -# name: test_function_call_without_parameters - list([ - tuple( - '', - tuple( - ), - dict({ - 'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=None)], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), 
thinking_config=None), - 'history': list([ - ]), - 'model': 'models/gemini-2.0-flash', - }), - ), - tuple( - '().send_message', - tuple( - ), - dict({ - 'message': 'Please call the test function', - }), - ), - tuple( - '().send_message', - tuple( - ), - dict({ - 'message': list([ - Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None), - ]), - }), - ), - ]) -# --- -# name: test_use_google_search - list([ - tuple( - '', - tuple( - ), - dict({ - 'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. 
Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'param1': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description='Test parameters', enum=None, format=None, items=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=), max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=), 'param2': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, 
type=None), 'param3': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'json': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=)}, property_ordering=None, required=[], type=)}, property_ordering=None, required=[], type=))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None), Tool(function_declarations=None, retrieval=None, google_search=GoogleSearch(), google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None), - 'history': list([ - ]), - 'model': 'models/gemini-2.0-flash', - }), - ), - tuple( - '().send_message', - tuple( - ), - dict({ - 'message': 'Please call the test function', - }), - ), - tuple( - '().send_message', - tuple( - ), - dict({ - 'message': list([ - Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None), - ]), - }), - ), - ]) -# --- diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py index 
13063580c95..4234355cb5b 100644 --- a/tests/components/google_generative_ai_conversation/test_config_flow.py +++ b/tests/components/google_generative_ai_conversation/test_config_flow.py @@ -34,7 +34,7 @@ from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType -from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID +from . import API_ERROR_500, CLIENT_ERROR_API_KEY_INVALID from tests.common import MockConfigEntry @@ -339,7 +339,7 @@ async def test_options_switching( ("side_effect", "error"), [ ( - CLIENT_ERROR_500, + API_ERROR_500, "cannot_connect", ), ( diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index 75cb308d5de..2d1a46393fd 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -1,16 +1,14 @@ """Tests for the Google Generative AI Conversation integration conversation platform.""" -from typing import Any -from unittest.mock import AsyncMock, Mock, patch +from collections.abc import Generator +from unittest.mock import AsyncMock, patch from freezegun import freeze_time -from google.genai.types import FunctionCall +from google.genai.types import GenerateContentResponse import pytest -from syrupy.assertion import SnapshotAssertion -import voluptuous as vol from homeassistant.components import conversation -from homeassistant.components.conversation import UserContent, async_get_chat_log, trace +from homeassistant.components.conversation import UserContent from homeassistant.components.google_generative_ai_conversation.conversation import ( ERROR_GETTING_RESPONSE, _escape_decode, @@ -18,12 +16,15 @@ from homeassistant.components.google_generative_ai_conversation.conversation imp ) from homeassistant.const import CONF_LLM_HASS_API from 
homeassistant.core import Context, HomeAssistant
-from homeassistant.exceptions import HomeAssistantError
-from homeassistant.helpers import chat_session, intent, llm
+from homeassistant.helpers import intent
 
-from . import CLIENT_ERROR_500
+from . import API_ERROR_500, CLIENT_ERROR_BAD_REQUEST
 
 from tests.common import MockConfigEntry
+from tests.components.conversation import (
+    MockChatLog,
+    mock_chat_log,  # noqa: F401
+)
 
 
 @pytest.fixture(autouse=True)
@@ -40,396 +41,44 @@ def mock_ulid_tools():
         yield
 
 
-@patch(
-    "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
+@pytest.fixture
+def mock_send_message_stream() -> Generator[AsyncMock]:
+    """Patch send_message_stream to replay one queued response list per call."""
+
+    async def mock_generator(stream):
+        for value in stream:
+            yield value
+
+    with patch(
+        "google.genai.chats.AsyncChat.send_message_stream",
+        AsyncMock(),
+    ) as mock_send_message_stream:
+        mock_send_message_stream.side_effect = lambda **kwargs: mock_generator(
+            mock_send_message_stream.return_value.pop(0)
+        )
+
+        yield mock_send_message_stream
+
+
+@pytest.mark.parametrize(
+    "error",  # one argname: values are passed as-is, so no 1-tuples
+    [
+        API_ERROR_500,
+        CLIENT_ERROR_BAD_REQUEST,
+    ],
 )
-@pytest.mark.usefixtures("mock_init_component")
-@pytest.mark.usefixtures("mock_ulid_tools")
-async def test_function_call(
-    mock_get_tools,
-    hass: HomeAssistant,
-    mock_config_entry_with_assist: MockConfigEntry,
-    snapshot: SnapshotAssertion,
-) -> None:
-    """Test function calling."""
-    agent_id = "conversation.google_generative_ai_conversation"
-    context = Context()
-
-    mock_tool = AsyncMock()
-    mock_tool.name = "test_tool"
-    mock_tool.description = "Test function"
-    mock_tool.parameters = vol.Schema(
-        {
-            vol.Optional("param1", description="Test parameters"): [
-                vol.All(str, vol.Lower)
-            ],
-            vol.Optional("param2"): vol.Any(float, int),
-            vol.Optional("param3"): dict,
-        }
-    )
-
-    mock_get_tools.return_value = [mock_tool]
-
-    with patch("google.genai.chats.AsyncChats.create") as mock_create:
- mock_chat = AsyncMock() - mock_create.return_value.send_message = mock_chat - chat_response = Mock(prompt_feedback=None) - mock_chat.return_value = chat_response - mock_part = Mock() - mock_part.text = "" - mock_part.function_call = FunctionCall( - name="test_tool", - args={ - "param1": ["test_value", "param1\\'s value"], - "param2": 2.7, - }, - ) - - def tool_call( - hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext - ) -> dict[str, Any]: - mock_part.function_call = None - mock_part.text = "Hi there!" - return {"result": "Test response"} - - mock_tool.async_call.side_effect = tool_call - chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))] - result = await conversation.async_converse( - hass, - "Please call the test function", - None, - context, - agent_id=agent_id, - device_id="test_device", - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" 
- mock_tool_response_parts = mock_create.mock_calls[2][2]["message"] - assert len(mock_tool_response_parts) == 1 - assert mock_tool_response_parts[0].model_dump() == { - "code_execution_result": None, - "executable_code": None, - "file_data": None, - "function_call": None, - "function_response": { - "id": None, - "name": "test_tool", - "response": { - "result": "Test response", - }, - }, - "inline_data": None, - "text": None, - "thought": None, - "video_metadata": None, - } - - mock_tool.async_call.assert_awaited_once_with( - hass, - llm.ToolInput( - id="mock-tool-call", - tool_name="test_tool", - tool_args={ - "param1": ["test_value", "param1's value"], - "param2": 2.7, - }, - ), - llm.LLMContext( - platform="google_generative_ai_conversation", - context=context, - user_prompt="Please call the test function", - language="en", - assistant="conversation", - device_id="test_device", - ), - ) - assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot - - # Test conversating tracing - traces = trace.async_get_traces() - assert traces - last_trace = traces[-1].as_dict() - trace_events = last_trace.get("events", []) - assert [event["event_type"] for event in trace_events] == [ - trace.ConversationTraceEventType.ASYNC_PROCESS, - trace.ConversationTraceEventType.AGENT_DETAIL, # prompt and tools - trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response - trace.ConversationTraceEventType.TOOL_CALL, - trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response - ] - # AGENT_DETAIL event contains the raw prompt passed to the model - detail_event = trace_events[1] - assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"] - assert [ - p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"] - ] == ["test_tool"] - - detail_event = trace_events[2] - assert set(detail_event["data"]["stats"].keys()) == { - "input_tokens", - "cached_input_tokens", - "output_tokens", - } - - -@patch( - 
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools" -) -@pytest.mark.usefixtures("mock_init_component") -@pytest.mark.usefixtures("mock_ulid_tools") -async def test_use_google_search( - mock_get_tools, - hass: HomeAssistant, - mock_config_entry_with_google_search: MockConfigEntry, - snapshot: SnapshotAssertion, -) -> None: - """Test function calling.""" - agent_id = "conversation.google_generative_ai_conversation" - context = Context() - - mock_tool = AsyncMock() - mock_tool.name = "test_tool" - mock_tool.description = "Test function" - mock_tool.parameters = vol.Schema( - { - vol.Optional("param1", description="Test parameters"): [ - vol.All(str, vol.Lower) - ], - vol.Optional("param2"): vol.Any(float, int), - vol.Optional("param3"): dict, - } - ) - - mock_get_tools.return_value = [mock_tool] - - with patch("google.genai.chats.AsyncChats.create") as mock_create: - mock_chat = AsyncMock() - mock_create.return_value.send_message = mock_chat - chat_response = Mock(prompt_feedback=None) - mock_chat.return_value = chat_response - mock_part = Mock() - mock_part.text = "" - mock_part.function_call = FunctionCall( - name="test_tool", - args={ - "param1": ["test_value", "param1\\'s value"], - "param2": 2.7, - }, - ) - - def tool_call( - hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext - ) -> dict[str, Any]: - mock_part.function_call = None - mock_part.text = "Hi there!" 
- return {"result": "Test response"} - - mock_tool.async_call.side_effect = tool_call - chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))] - await conversation.async_converse( - hass, - "Please call the test function", - None, - context, - agent_id=agent_id, - device_id="test_device", - ) - - assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot - - -@patch( - "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools" -) -@pytest.mark.usefixtures("mock_init_component") -async def test_function_call_without_parameters( - mock_get_tools, - hass: HomeAssistant, - mock_config_entry_with_assist: MockConfigEntry, - snapshot: SnapshotAssertion, -) -> None: - """Test function calling without parameters.""" - agent_id = "conversation.google_generative_ai_conversation" - context = Context() - - mock_tool = AsyncMock() - mock_tool.name = "test_tool" - mock_tool.description = "Test function" - mock_tool.parameters = vol.Schema({}) - - mock_get_tools.return_value = [mock_tool] - - with patch("google.genai.chats.AsyncChats.create") as mock_create: - mock_chat = AsyncMock() - mock_create.return_value.send_message = mock_chat - chat_response = Mock(prompt_feedback=None) - mock_chat.return_value = chat_response - mock_part = Mock() - mock_part.text = "" - mock_part.function_call = FunctionCall(name="test_tool", args={}) - - def tool_call( - hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext - ) -> dict[str, Any]: - mock_part.function_call = None - mock_part.text = "Hi there!" 
- return {"result": "Test response"} - - mock_tool.async_call.side_effect = tool_call - chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))] - result = await conversation.async_converse( - hass, - "Please call the test function", - None, - context, - agent_id=agent_id, - device_id="test_device", - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" - mock_tool_response_parts = mock_create.mock_calls[2][2]["message"] - assert len(mock_tool_response_parts) == 1 - assert mock_tool_response_parts[0].model_dump() == { - "code_execution_result": None, - "executable_code": None, - "file_data": None, - "function_call": None, - "function_response": { - "id": None, - "name": "test_tool", - "response": { - "result": "Test response", - }, - }, - "inline_data": None, - "text": None, - "thought": None, - "video_metadata": None, - } - - mock_tool.async_call.assert_awaited_once_with( - hass, - llm.ToolInput( - id="mock-tool-call", - tool_name="test_tool", - tool_args={}, - ), - llm.LLMContext( - platform="google_generative_ai_conversation", - context=context, - user_prompt="Please call the test function", - language="en", - assistant="conversation", - device_id="test_device", - ), - ) - assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot - - -@patch( - "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools" -) -@pytest.mark.usefixtures("mock_init_component") -async def test_function_exception( - mock_get_tools, - hass: HomeAssistant, - mock_config_entry_with_assist: MockConfigEntry, -) -> None: - """Test exception in function calling.""" - agent_id = "conversation.google_generative_ai_conversation" - context = Context() - - mock_tool = AsyncMock() - mock_tool.name = "test_tool" - mock_tool.description = "Test function" - mock_tool.parameters = vol.Schema( - { - vol.Optional("param1", 
description="Test parameters"): vol.All( - vol.Coerce(int), vol.Range(0, 100) - ) - } - ) - - mock_get_tools.return_value = [mock_tool] - - with patch("google.genai.chats.AsyncChats.create") as mock_create: - mock_chat = AsyncMock() - mock_create.return_value.send_message = mock_chat - chat_response = Mock(prompt_feedback=None) - mock_chat.return_value = chat_response - mock_part = Mock() - mock_part.text = "" - mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1}) - - def tool_call( - hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext - ) -> dict[str, Any]: - mock_part.function_call = None - mock_part.text = "Hi there!" - raise HomeAssistantError("Test tool exception") - - mock_tool.async_call.side_effect = tool_call - chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))] - result = await conversation.async_converse( - hass, - "Please call the test function", - None, - context, - agent_id=agent_id, - device_id="test_device", - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" 
- mock_tool_response_parts = mock_create.mock_calls[2][2]["message"] - assert len(mock_tool_response_parts) == 1 - assert mock_tool_response_parts[0].model_dump() == { - "code_execution_result": None, - "executable_code": None, - "file_data": None, - "function_call": None, - "function_response": { - "id": None, - "name": "test_tool", - "response": { - "error": "HomeAssistantError", - "error_text": "Test tool exception", - }, - }, - "inline_data": None, - "text": None, - "thought": None, - "video_metadata": None, - } - mock_tool.async_call.assert_awaited_once_with( - hass, - llm.ToolInput( - id="mock-tool-call", - tool_name="test_tool", - tool_args={"param1": 1}, - ), - llm.LLMContext( - platform="google_generative_ai_conversation", - context=context, - user_prompt="Please call the test function", - language="en", - assistant="conversation", - device_id="test_device", - ), - ) - - -@pytest.mark.usefixtures("mock_init_component") async def test_error_handling( - hass: HomeAssistant, mock_config_entry: MockConfigEntry + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + error, ) -> None: """Test that client errors are caught.""" - with patch("google.genai.chats.AsyncChats.create") as mock_create: - mock_chat = AsyncMock() - mock_create.return_value.send_message = mock_chat - mock_chat.side_effect = CLIENT_ERROR_500 + with patch( + "google.genai.chats.AsyncChat.send_message_stream", + new_callable=AsyncMock, + side_effect=error, + ): result = await conversation.async_converse( hass, "hello", @@ -437,32 +86,251 @@ async def test_error_handling( Context(), agent_id="conversation.google_generative_ai_conversation", ) - assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result - assert result.response.as_dict()["speech"]["plain"]["speech"] == ( - "Sorry, I had a problem talking to Google Generative AI: 500 internal-error. 
{'message': 'Internal Server Error', 'status': 'internal-error'}" + assert ( + result.response.as_dict()["speech"]["plain"]["speech"] == ERROR_GETTING_RESPONSE ) +@pytest.mark.usefixtures("mock_init_component") +@pytest.mark.usefixtures("mock_ulid_tools") +async def test_function_call( + hass: HomeAssistant, + mock_config_entry_with_assist: MockConfigEntry, + mock_chat_log: MockChatLog, # noqa: F811 + mock_send_message_stream: AsyncMock, +) -> None: + """Test function calling.""" + agent_id = "conversation.google_generative_ai_conversation" + context = Context() + + messages = [ + # Function call stream + [ + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [ + { + "text": "Hi there!", + } + ], + "role": "model", + } + } + ] + ), + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [ + { + "function_call": { + "name": "test_tool", + "args": { + "param1": [ + "test_value", + "param1\\'s value", + ], + "param2": 2.7, + }, + }, + } + ], + "role": "model", + }, + "finish_reason": "STOP", + } + ] + ), + ], + # Messages after function response is sent + [ + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [ + { + "text": "I've called the ", + } + ], + "role": "model", + }, + } + ], + ), + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [ + { + "text": "test function with the provided parameters.", + } + ], + "role": "model", + }, + "finish_reason": "STOP", + } + ], + ), + ], + ] + + mock_send_message_stream.return_value = messages + + mock_chat_log.mock_tool_results( + { + "mock-tool-call": {"result": "Test response"}, + } + ) + + result = await conversation.async_converse( + hass, + "Please call the test function", + mock_chat_log.conversation_id, + context, + agent_id=agent_id, + device_id="test_device", + ) + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.as_dict()["speech"]["plain"]["speech"] + == "I've called the test 
function with the provided parameters." + ) + mock_tool_response_parts = mock_send_message_stream.mock_calls[1][2]["message"] + assert len(mock_tool_response_parts) == 1 + assert mock_tool_response_parts[0].model_dump() == { + "code_execution_result": None, + "executable_code": None, + "file_data": None, + "function_call": None, + "function_response": { + "id": None, + "name": "test_tool", + "response": { + "result": "Test response", + }, + }, + "inline_data": None, + "text": None, + "thought": None, + "video_metadata": None, + } + + +@pytest.mark.usefixtures("mock_init_component") +@pytest.mark.usefixtures("mock_ulid_tools") +async def test_google_search_tool_is_sent( + hass: HomeAssistant, + mock_config_entry_with_google_search: MockConfigEntry, + mock_chat_log: MockChatLog, # noqa: F811 + mock_send_message_stream: AsyncMock, +) -> None: + """Test if the Google Search tool is sent to the model.""" + agent_id = "conversation.google_generative_ai_conversation" + context = Context() + + messages = [ + # Messages from the model which contain the google search answer (the usage of the Google Search tool is server side) + [ + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [ + { + "text": "The last winner ", + } + ], + "role": "model", + }, + } + ], + ), + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [ + {"text": "of the 2024 FIFA World Cup was Argentina."} + ], + "role": "model", + }, + "finish_reason": "STOP", + } + ], + ), + ], + ] + + mock_send_message_stream.return_value = messages + + with patch( + "google.genai.chats.AsyncChats.create", return_value=AsyncMock() + ) as mock_create: + mock_create.return_value.send_message_stream = mock_send_message_stream + result = await conversation.async_converse( + hass, + "Who won the 2024 FIFA World Cup?", + mock_chat_log.conversation_id, + context, + agent_id=agent_id, + device_id="test_device", + ) + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE 
+ assert ( + result.response.as_dict()["speech"]["plain"]["speech"] + == "The last winner of the 2024 FIFA World Cup was Argentina." + ) + assert mock_create.mock_calls[0][2]["config"].tools[-1].google_search is not None + + @pytest.mark.usefixtures("mock_init_component") async def test_blocked_response( - hass: HomeAssistant, mock_config_entry: MockConfigEntry + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_chat_log: MockChatLog, # noqa: F811 + mock_send_message_stream: AsyncMock, ) -> None: """Test blocked response.""" - with patch("google.genai.chats.AsyncChats.create") as mock_create: - mock_chat = AsyncMock() - mock_create.return_value.send_message = mock_chat - chat_response = Mock(prompt_feedback=Mock(block_reason_message="SAFETY")) - mock_chat.return_value = chat_response + agent_id = "conversation.google_generative_ai_conversation" + context = Context() - result = await conversation.async_converse( - hass, - "hello", - None, - Context(), - agent_id="conversation.google_generative_ai_conversation", - ) + messages = [ + [ + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [ + { + "text": "I've called the ", + } + ], + "role": "model", + }, + } + ], + ), + GenerateContentResponse(prompt_feedback={"block_reason_message": "SAFETY"}), + ], + ] + + mock_send_message_stream.return_value = messages + + result = await conversation.async_converse( + hass, + "Please call the test function", + mock_chat_log.conversation_id, + context, + agent_id=agent_id, + device_id="test_device", + ) assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result @@ -473,23 +341,41 @@ async def test_blocked_response( @pytest.mark.usefixtures("mock_init_component") async def test_empty_response( - hass: HomeAssistant, mock_config_entry: MockConfigEntry + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_chat_log: MockChatLog, # noqa: F811 + 
mock_send_message_stream: AsyncMock, ) -> None: """Test empty response.""" - with patch("google.genai.chats.AsyncChats.create") as mock_create: - mock_chat = AsyncMock() - mock_create.return_value.send_message = mock_chat - chat_response = Mock(prompt_feedback=None) - mock_chat.return_value = chat_response - chat_response.candidates = [Mock(content=Mock(parts=[]))] - result = await conversation.async_converse( - hass, - "hello", - None, - Context(), - agent_id="conversation.google_generative_ai_conversation", - ) + agent_id = "conversation.google_generative_ai_conversation" + context = Context() + + messages = [ + [ + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [], + "role": "model", + }, + } + ], + ), + ], + ] + + mock_send_message_stream.return_value = messages + + result = await conversation.async_converse( + hass, + "Hello", + mock_chat_log.conversation_id, + context, + agent_id=agent_id, + device_id="test_device", + ) assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result assert result.response.as_dict()["speech"]["plain"]["speech"] == ( @@ -499,27 +385,36 @@ async def test_empty_response( @pytest.mark.usefixtures("mock_init_component") async def test_none_response( - hass: HomeAssistant, mock_config_entry: MockConfigEntry + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_chat_log: MockChatLog, # noqa: F811 + mock_send_message_stream: AsyncMock, ) -> None: - """Test empty response.""" - with patch("google.genai.chats.AsyncChats.create") as mock_create: - mock_chat = AsyncMock() - mock_create.return_value.send_message = mock_chat - chat_response = Mock(prompt_feedback=None) - mock_chat.return_value = chat_response - chat_response.candidates = None - result = await conversation.async_converse( - hass, - "hello", - None, - Context(), - agent_id="conversation.google_generative_ai_conversation", - ) + """Test None response.""" + agent_id = 
"conversation.google_generative_ai_conversation" + context = Context() + + messages = [ + [ + GenerateContentResponse(), + ], + ] + + mock_send_message_stream.return_value = messages + + result = await conversation.async_converse( + hass, + "Hello", + mock_chat_log.conversation_id, + context, + agent_id=agent_id, + device_id="test_device", + ) assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result assert result.response.as_dict()["speech"]["plain"]["speech"] == ( - ERROR_GETTING_RESPONSE + "The message got blocked due to content violations, reason: unknown" ) @@ -712,69 +607,109 @@ async def test_format_schema(openapi, genai_schema) -> None: @pytest.mark.usefixtures("mock_init_component") async def test_empty_content_in_chat_history( - hass: HomeAssistant, mock_config_entry: MockConfigEntry + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_chat_log: MockChatLog, # noqa: F811 + mock_send_message_stream: AsyncMock, ) -> None: """Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead.""" - with ( - patch("google.genai.chats.AsyncChats.create") as mock_create, - chat_session.async_get_chat_session(hass) as session, - async_get_chat_log(hass, session) as chat_log, - ): - mock_chat = AsyncMock() - mock_create.return_value.send_message = mock_chat + agent_id = "conversation.google_generative_ai_conversation" + context = Context() - # Chat preparation with two inputs, one being an empty string - first_input = "First request" - second_input = "" - chat_log.async_add_user_content(UserContent(first_input)) - chat_log.async_add_user_content(UserContent(second_input)) + messages = [ + [ + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [{"text": "Hi there!"}], + "role": "model", + }, + } + ], + ), + ], + ] + mock_send_message_stream.return_value = messages + + # Chat preparation with two inputs, one 
being an empty string + first_input = "First request" + second_input = "" + mock_chat_log.async_add_user_content(UserContent(first_input)) + mock_chat_log.async_add_user_content(UserContent(second_input)) + + with patch( + "google.genai.chats.AsyncChats.create", return_value=AsyncMock() + ) as mock_create: + mock_create.return_value.send_message_stream = mock_send_message_stream await conversation.async_converse( hass, - "Second request", - session.conversation_id, - Context(), - agent_id="conversation.google_generative_ai_conversation", + "Hello", + mock_chat_log.conversation_id, + context, + agent_id=agent_id, + device_id="test_device", ) - _, kwargs = mock_create.call_args - actual_history = kwargs.get("history") + _, kwargs = mock_create.call_args + actual_history = kwargs.get("history") - assert actual_history[0].parts[0].text == first_input - assert actual_history[1].parts[0].text == " " + assert actual_history[0].parts[0].text == first_input + assert actual_history[1].parts[0].text == " " @pytest.mark.usefixtures("mock_init_component") async def test_history_always_user_first_turn( hass: HomeAssistant, mock_config_entry: MockConfigEntry, - snapshot: SnapshotAssertion, + mock_chat_log: MockChatLog, # noqa: F811 + mock_send_message_stream: AsyncMock, ) -> None: """Test that the user is always first in the chat history.""" - with ( - chat_session.async_get_chat_session(hass) as session, - async_get_chat_log(hass, session) as chat_log, - ): - chat_log.async_add_assistant_content_without_tools( - conversation.AssistantContent( - agent_id="conversation.google_generative_ai_conversation", - content="Garage door left open, do you want to close it?", - ) + + agent_id = "conversation.google_generative_ai_conversation" + context = Context() + + messages = [ + [ + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [ + { + "text": " Yes, I can help with that. 
", + } + ], + "role": "model", + }, + } + ], + ), + ], + ] + + mock_send_message_stream.return_value = messages + + mock_chat_log.async_add_assistant_content_without_tools( + conversation.AssistantContent( + agent_id="conversation.google_generative_ai_conversation", + content="Garage door left open, do you want to close it?", ) + ) - with patch("google.genai.chats.AsyncChats.create") as mock_create: - mock_chat = AsyncMock() - mock_create.return_value.send_message = mock_chat - chat_response = Mock(prompt_feedback=None) - mock_chat.return_value = chat_response - chat_response.candidates = [Mock(content=Mock(parts=[]))] - + with patch( + "google.genai.chats.AsyncChats.create", return_value=AsyncMock() + ) as mock_create: + mock_create.return_value.send_message_stream = mock_send_message_stream await conversation.async_converse( hass, - "hello", - chat_log.conversation_id, - Context(), - agent_id="conversation.google_generative_ai_conversation", + "Hello", + mock_chat_log.conversation_id, + context, + agent_id=agent_id, + device_id="test_device", ) _, kwargs = mock_create.call_args diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index 94308260f74..6cc0bdd5f44 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -11,7 +11,7 @@ from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID +from . 
import API_ERROR_500, CLIENT_ERROR_API_KEY_INVALID from tests.common import MockConfigEntry @@ -212,7 +212,7 @@ async def test_generate_content_service_error( with ( patch( "google.genai.models.AsyncModels.generate_content", - side_effect=CLIENT_ERROR_500, + side_effect=API_ERROR_500, ), pytest.raises( HomeAssistantError, @@ -311,7 +311,7 @@ async def test_generate_content_service_with_image_not_exists( ("side_effect", "state", "reauth"), [ ( - CLIENT_ERROR_500, + API_ERROR_500, ConfigEntryState.SETUP_ERROR, False, ),