Enable message streaming in the Gemini integration. (#144937)

* Added streaming implementation

* Indicate the entity supports streaming

* Added tests

* Removed unused snapshots

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Ivan Lopez Hernandez 2025-05-25 18:50:55 -07:00 committed by GitHub
parent e4b519d77a
commit 32eb4af6ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 534 additions and 650 deletions

View File

@ -3,16 +3,17 @@
from __future__ import annotations
import codecs
from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable
from dataclasses import replace
from typing import Any, Literal, cast
from google.genai.errors import APIError
from google.genai.errors import APIError, ClientError
from google.genai.types import (
AutomaticFunctionCallingConfig,
Content,
FunctionDeclaration,
GenerateContentConfig,
GenerateContentResponse,
GoogleSearch,
HarmCategory,
Part,
@ -233,6 +234,81 @@ def _convert_content(
return Content(role="model", parts=parts)
async def _transform_stream(
    result: AsyncGenerator[GenerateContentResponse],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
    """Turn a stream of Gemini response chunks into assistant content deltas.

    One delta dict is yielded per received chunk. The first delta carries
    ``role="assistant"``; every delta carries ``content`` (the concatenated
    text of the chunk's parts, possibly empty) and, when the chunk contains
    function calls, a ``tool_calls`` list of ``llm.ToolInput``.

    Raises:
        HomeAssistantError: if a chunk has prompt feedback or no candidates,
            if the first candidate finished with a reason other than
            ``"STOP"``, or if an ``APIError``/``ValueError`` is raised while
            the stream is being consumed.
    """
    awaiting_first_chunk = True
    try:
        async for response in result:
            LOGGER.debug("Received response chunk: %s", response)
            delta: conversation.AssistantContentDeltaDict = {}
            if awaiting_first_chunk:
                # Only the opening delta announces the role.
                delta["role"] = "assistant"
                awaiting_first_chunk = False

            # Prompt feedback, or a chunk without candidates, is treated as a
            # blocked message: there is nothing usable to forward.
            if response.prompt_feedback or not response.candidates:
                if response.prompt_feedback:
                    reason = response.prompt_feedback.block_reason_message
                else:
                    reason = "unknown"
                raise HomeAssistantError(
                    f"The message got blocked due to content violations, reason: {reason}"
                )

            candidate = response.candidates[0]
            finish_reason = candidate.finish_reason
            if finish_reason is not None and finish_reason != "STOP":
                # Any finish reason other than STOP is handled as a failure, see
                # https://ai.google.dev/api/generate-content#FinishReason
                LOGGER.error(
                    "Error in Google Generative AI response: %s, see: https://ai.google.dev/api/generate-content#FinishReason",
                    finish_reason,
                )
                raise HomeAssistantError(
                    f"{ERROR_GETTING_RESPONSE} Reason: {finish_reason}"
                )

            # A candidate may come without content or without parts.
            parts = []
            if candidate.content is not None and candidate.content.parts is not None:
                parts = candidate.content.parts

            text = "".join(part.text for part in parts if part.text)
            requested_tools = [
                llm.ToolInput(
                    tool_name=part.function_call.name or "",
                    tool_args=_escape_decode(part.function_call.args),
                )
                for part in parts
                if part.function_call
            ]

            if requested_tools:
                delta["tool_calls"] = requested_tools
            delta["content"] = text
            yield delta
    except (APIError, ValueError) as err:
        LOGGER.error("Error sending message: %s %s", type(err), err)
        # APIError exposes a message attribute; otherwise fall back to the
        # exception class name.
        message = err.message if isinstance(err, APIError) else type(err).__name__
        raise HomeAssistantError(f"{ERROR_GETTING_RESPONSE}: {message}") from err
class GoogleGenerativeAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -240,6 +316,7 @@ class GoogleGenerativeAIConversationEntity(
_attr_has_entity_name = True
_attr_name = None
_attr_supports_streaming = True
def __init__(self, entry: ConfigEntry) -> None:
"""Initialize the agent."""
@ -426,80 +503,40 @@ class GoogleGenerativeAIConversationEntity(
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
chat_response = await chat.send_message(message=chat_request)
if chat_response.prompt_feedback:
raise HomeAssistantError(
f"The message got blocked due to content violations, reason: {chat_response.prompt_feedback.block_reason_message}"
chat_response_generator = await chat.send_message_stream(
message=chat_request
)
if not chat_response.candidates:
LOGGER.error(
"No candidates found in the response: %s",
chat_response,
)
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
except (
APIError,
ClientError,
ValueError,
) as err:
LOGGER.error("Error sending message: %s %s", type(err), err)
error = f"Sorry, I had a problem talking to Google Generative AI: {err}"
error = ERROR_GETTING_RESPONSE
raise HomeAssistantError(error) from err
if (usage_metadata := chat_response.usage_metadata) is not None:
chat_log.async_trace(
{
"stats": {
"input_tokens": usage_metadata.prompt_token_count,
"cached_input_tokens": usage_metadata.cached_content_token_count
or 0,
"output_tokens": usage_metadata.candidates_token_count,
}
}
)
response_parts = chat_response.candidates[0].content.parts
if not response_parts:
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
content = " ".join(
[part.text.strip() for part in response_parts if part.text]
)
tool_calls = []
for part in response_parts:
if not part.function_call:
continue
tool_call = part.function_call
tool_name = tool_call.name
tool_args = _escape_decode(tool_call.args)
tool_calls.append(
llm.ToolInput(
tool_name=self._fix_tool_name(tool_name),
tool_args=tool_args,
)
)
chat_request = _create_google_tool_response_parts(
[
tool_response
async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=content,
tool_calls=tool_calls or None,
)
content
async for content in chat_log.async_add_delta_content_stream(
user_input.agent_id,
_transform_stream(chat_response_generator),
)
if isinstance(content, conversation.ToolResultContent)
]
)
if not tool_calls:
if not chat_log.unresponded_tool_results:
break
response = intent.IntentResponse(language=user_input.language)
response.async_set_speech(
" ".join([part.text.strip() for part in response_parts if part.text])
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error(
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
chat_log.content[-1],
)
raise HomeAssistantError(f"{ERROR_GETTING_RESPONSE}")
response.async_set_speech(chat_log.content[-1].content or "")
return conversation.ConversationResult(
response=response,
conversation_id=chat_log.conversation_id,

View File

@ -2,10 +2,10 @@
from unittest.mock import Mock
from google.genai.errors import ClientError
from google.genai.errors import APIError, ClientError
import httpx
CLIENT_ERROR_500 = ClientError(
API_ERROR_500 = APIError(
500,
Mock(
__class__=httpx.Response,
@ -17,6 +17,18 @@ CLIENT_ERROR_500 = ClientError(
),
),
)
CLIENT_ERROR_BAD_REQUEST = ClientError(
400,
Mock(
__class__=httpx.Response,
json=Mock(
return_value={
"message": "Bad Request",
"status": "invalid-argument",
}
),
),
)
CLIENT_ERROR_API_KEY_INVALID = ClientError(
400,
Mock(

View File

@ -1,100 +0,0 @@
# serializer version: 1
# name: test_function_call
list([
tuple(
'',
tuple(
),
dict({
'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'param1': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, 
min_properties=None, max_properties=None, any_of=None, description='Test parameters', enum=None, format=None, items=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>), max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.ARRAY: 'ARRAY'>), 'param2': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=None), 'param3': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'json': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, 
media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
'history': list([
]),
'model': 'models/gemini-2.0-flash',
}),
),
tuple(
'().send_message',
tuple(
),
dict({
'message': 'Please call the test function',
}),
),
tuple(
'().send_message',
tuple(
),
dict({
'message': list([
Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None),
]),
}),
),
])
# ---
# name: test_function_call_without_parameters
list([
tuple(
'',
tuple(
),
dict({
'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=None)], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
'history': list([
]),
'model': 'models/gemini-2.0-flash',
}),
),
tuple(
'().send_message',
tuple(
),
dict({
'message': 'Please call the test function',
}),
),
tuple(
'().send_message',
tuple(
),
dict({
'message': list([
Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None),
]),
}),
),
])
# ---
# name: test_use_google_search
list([
tuple(
'',
tuple(
),
dict({
'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'param1': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, 
min_properties=None, max_properties=None, any_of=None, description='Test parameters', enum=None, format=None, items=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>), max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.ARRAY: 'ARRAY'>), 'param2': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=None), 'param3': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'json': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None), Tool(function_declarations=None, retrieval=None, google_search=GoogleSearch(), 
google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
'history': list([
]),
'model': 'models/gemini-2.0-flash',
}),
),
tuple(
'().send_message',
tuple(
),
dict({
'message': 'Please call the test function',
}),
),
tuple(
'().send_message',
tuple(
),
dict({
'message': list([
Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None),
]),
}),
),
])
# ---

View File

@ -34,7 +34,7 @@ from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
from . import API_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
from tests.common import MockConfigEntry
@ -339,7 +339,7 @@ async def test_options_switching(
("side_effect", "error"),
[
(
CLIENT_ERROR_500,
API_ERROR_500,
"cannot_connect",
),
(

View File

@ -1,16 +1,14 @@
"""Tests for the Google Generative AI Conversation integration conversation platform."""
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
from collections.abc import Generator
from unittest.mock import AsyncMock, patch
from freezegun import freeze_time
from google.genai.types import FunctionCall
from google.genai.types import GenerateContentResponse
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import UserContent, async_get_chat_log, trace
from homeassistant.components.conversation import UserContent
from homeassistant.components.google_generative_ai_conversation.conversation import (
ERROR_GETTING_RESPONSE,
_escape_decode,
@ -18,12 +16,15 @@ from homeassistant.components.google_generative_ai_conversation.conversation imp
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, intent, llm
from homeassistant.helpers import intent
from . import CLIENT_ERROR_500
from . import API_ERROR_500, CLIENT_ERROR_BAD_REQUEST
from tests.common import MockConfigEntry
from tests.components.conversation import (
MockChatLog,
mock_chat_log, # noqa: F401
)
@pytest.fixture(autouse=True)
@ -40,396 +41,44 @@ def mock_ulid_tools():
yield
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
@pytest.mark.usefixtures("mock_init_component")
@pytest.mark.usefixtures("mock_ulid_tools")
async def test_function_call(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
snapshot: SnapshotAssertion,
) -> None:
"""Test function calling."""
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
@pytest.fixture
def mock_send_message_stream() -> Generator[AsyncMock]:
"""Mock stream response."""
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{
vol.Optional("param1", description="Test parameters"): [
vol.All(str, vol.Lower)
async def mock_generator(stream):
for value in stream:
yield value
with patch(
"google.genai.chats.AsyncChat.send_message_stream",
AsyncMock(),
) as mock_send_message_stream:
mock_send_message_stream.side_effect = lambda **kwargs: mock_generator(
mock_send_message_stream.return_value.pop(0)
)
yield mock_send_message_stream
@pytest.mark.parametrize(
("error"),
[
(API_ERROR_500,),
(CLIENT_ERROR_BAD_REQUEST,),
],
vol.Optional("param2"): vol.Any(float, int),
vol.Optional("param3"): dict,
}
)
mock_get_tools.return_value = [mock_tool]
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=None)
mock_chat.return_value = chat_response
mock_part = Mock()
mock_part.text = ""
mock_part.function_call = FunctionCall(
name="test_tool",
args={
"param1": ["test_value", "param1\\'s value"],
"param2": 2.7,
},
)
def tool_call(
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
) -> dict[str, Any]:
mock_part.function_call = None
mock_part.text = "Hi there!"
return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
device_id="test_device",
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_response_parts = mock_create.mock_calls[2][2]["message"]
assert len(mock_tool_response_parts) == 1
assert mock_tool_response_parts[0].model_dump() == {
"code_execution_result": None,
"executable_code": None,
"file_data": None,
"function_call": None,
"function_response": {
"id": None,
"name": "test_tool",
"response": {
"result": "Test response",
},
},
"inline_data": None,
"text": None,
"thought": None,
"video_metadata": None,
}
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={
"param1": ["test_value", "param1's value"],
"param2": 2.7,
},
),
llm.LLMContext(
platform="google_generative_ai_conversation",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
device_id="test_device",
),
)
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
# Test conversating tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL, # prompt and tools
trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response
trace.ConversationTraceEventType.TOOL_CALL,
trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
assert [
p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
] == ["test_tool"]
detail_event = trace_events[2]
assert set(detail_event["data"]["stats"].keys()) == {
"input_tokens",
"cached_input_tokens",
"output_tokens",
}
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
@pytest.mark.usefixtures("mock_init_component")
@pytest.mark.usefixtures("mock_ulid_tools")
async def test_use_google_search(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_google_search: MockConfigEntry,
snapshot: SnapshotAssertion,
) -> None:
"""Test function calling."""
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{
vol.Optional("param1", description="Test parameters"): [
vol.All(str, vol.Lower)
],
vol.Optional("param2"): vol.Any(float, int),
vol.Optional("param3"): dict,
}
)
mock_get_tools.return_value = [mock_tool]
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=None)
mock_chat.return_value = chat_response
mock_part = Mock()
mock_part.text = ""
mock_part.function_call = FunctionCall(
name="test_tool",
args={
"param1": ["test_value", "param1\\'s value"],
"param2": 2.7,
},
)
def tool_call(
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
) -> dict[str, Any]:
mock_part.function_call = None
mock_part.text = "Hi there!"
return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
device_id="test_device",
)
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
@pytest.mark.usefixtures("mock_init_component")
async def test_function_call_without_parameters(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
snapshot: SnapshotAssertion,
) -> None:
"""Test function calling without parameters."""
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema({})
mock_get_tools.return_value = [mock_tool]
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=None)
mock_chat.return_value = chat_response
mock_part = Mock()
mock_part.text = ""
mock_part.function_call = FunctionCall(name="test_tool", args={})
def tool_call(
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
) -> dict[str, Any]:
mock_part.function_call = None
mock_part.text = "Hi there!"
return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
device_id="test_device",
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_response_parts = mock_create.mock_calls[2][2]["message"]
assert len(mock_tool_response_parts) == 1
assert mock_tool_response_parts[0].model_dump() == {
"code_execution_result": None,
"executable_code": None,
"file_data": None,
"function_call": None,
"function_response": {
"id": None,
"name": "test_tool",
"response": {
"result": "Test response",
},
},
"inline_data": None,
"text": None,
"thought": None,
"video_metadata": None,
}
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={},
),
llm.LLMContext(
platform="google_generative_ai_conversation",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
device_id="test_device",
),
)
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
@pytest.mark.usefixtures("mock_init_component")
async def test_function_exception(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
) -> None:
"""Test exception in function calling."""
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{
vol.Optional("param1", description="Test parameters"): vol.All(
vol.Coerce(int), vol.Range(0, 100)
)
}
)
mock_get_tools.return_value = [mock_tool]
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=None)
mock_chat.return_value = chat_response
mock_part = Mock()
mock_part.text = ""
mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
def tool_call(
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
) -> dict[str, Any]:
mock_part.function_call = None
mock_part.text = "Hi there!"
raise HomeAssistantError("Test tool exception")
mock_tool.async_call.side_effect = tool_call
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
device_id="test_device",
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_response_parts = mock_create.mock_calls[2][2]["message"]
assert len(mock_tool_response_parts) == 1
assert mock_tool_response_parts[0].model_dump() == {
"code_execution_result": None,
"executable_code": None,
"file_data": None,
"function_call": None,
"function_response": {
"id": None,
"name": "test_tool",
"response": {
"error": "HomeAssistantError",
"error_text": "Test tool exception",
},
},
"inline_data": None,
"text": None,
"thought": None,
"video_metadata": None,
}
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
id="mock-tool-call",
tool_name="test_tool",
tool_args={"param1": 1},
),
llm.LLMContext(
platform="google_generative_ai_conversation",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
device_id="test_device",
),
)
@pytest.mark.usefixtures("mock_init_component")
async def test_error_handling(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
error,
) -> None:
"""Test that client errors are caught."""
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_create.return_value.send_message = mock_chat
mock_chat.side_effect = CLIENT_ERROR_500
with patch(
"google.genai.chats.AsyncChat.send_message_stream",
new_callable=AsyncMock,
side_effect=error,
):
result = await conversation.async_converse(
hass,
"hello",
@ -437,31 +86,250 @@ async def test_error_handling(
Context(),
agent_id="conversation.google_generative_ai_conversation",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Sorry, I had a problem talking to Google Generative AI: 500 internal-error. {'message': 'Internal Server Error', 'status': 'internal-error'}"
assert (
result.response.as_dict()["speech"]["plain"]["speech"] == ERROR_GETTING_RESPONSE
)
@pytest.mark.usefixtures("mock_init_component")
@pytest.mark.usefixtures("mock_ulid_tools")
async def test_function_call(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_chat_log: MockChatLog,  # noqa: F811
    mock_send_message_stream: AsyncMock,
) -> None:
    """Check that a streamed function call round-trips through the agent.

    The first mocked stream carries a text chunk followed by a ``test_tool``
    function call; the second carries the model's two-chunk textual reply.
    The test asserts the final speech and the single function-response part
    passed as ``message`` on the second ``send_message_stream`` call.
    """

    def _chunk(part: dict, *, final: bool = False) -> GenerateContentResponse:
        """Wrap one part in a single-candidate model response chunk."""
        candidate: dict = {"content": {"parts": [part], "role": "model"}}
        if final:
            # Only the last chunk of a stream carries a finish reason.
            candidate["finish_reason"] = "STOP"
        return GenerateContentResponse(candidates=[candidate])

    # Stream answering the user prompt: some text, then the tool call.
    tool_call_stream = [
        _chunk({"text": "Hi there!"}),
        _chunk(
            {
                "function_call": {
                    "name": "test_tool",
                    "args": {
                        "param1": ["test_value", "param1\\'s value"],
                        "param2": 2.7,
                    },
                },
            },
            final=True,
        ),
    ]
    # Stream answering the function response.
    reply_stream = [
        _chunk({"text": "I've called the "}),
        _chunk({"text": "test function with the provided parameters."}, final=True),
    ]
    mock_send_message_stream.return_value = [tool_call_stream, reply_stream]
    mock_chat_log.mock_tool_results({"mock-tool-call": {"result": "Test response"}})

    ctx = Context()
    result = await conversation.async_converse(
        hass,
        "Please call the test function",
        mock_chat_log.conversation_id,
        ctx,
        agent_id="conversation.google_generative_ai_conversation",
        device_id="test_device",
    )

    response = result.response
    assert response.response_type == intent.IntentResponseType.ACTION_DONE
    spoken = response.as_dict()["speech"]["plain"]["speech"]
    assert spoken == "I've called the test function with the provided parameters."

    # The second streaming call is the one that returns the tool result.
    tool_reply_parts = mock_send_message_stream.mock_calls[1].kwargs["message"]
    assert len(tool_reply_parts) == 1
    expected_part = {
        "code_execution_result": None,
        "executable_code": None,
        "file_data": None,
        "function_call": None,
        "function_response": {
            "id": None,
            "name": "test_tool",
            "response": {"result": "Test response"},
        },
        "inline_data": None,
        "text": None,
        "thought": None,
        "video_metadata": None,
    }
    assert tool_reply_parts[0].model_dump() == expected_part
@pytest.mark.usefixtures("mock_init_component")
@pytest.mark.usefixtures("mock_ulid_tools")
async def test_google_search_tool_is_sent(
    hass: HomeAssistant,
    mock_config_entry_with_google_search: MockConfigEntry,
    mock_chat_log: MockChatLog,  # noqa: F811
    mock_send_message_stream: AsyncMock,
) -> None:
    """Verify the Google Search tool is part of the chat configuration.

    The model's answer is mocked as a two-chunk stream (the search itself
    runs server side), and the config passed to ``AsyncChats.create`` is
    inspected for a ``google_search`` entry in its last tool.
    """
    opening_chunk = GenerateContentResponse(
        candidates=[
            {
                "content": {
                    "parts": [{"text": "The last winner "}],
                    "role": "model",
                },
            }
        ],
    )
    closing_chunk = GenerateContentResponse(
        candidates=[
            {
                "content": {
                    "parts": [{"text": "of the 2024 FIFA World Cup was Argentina."}],
                    "role": "model",
                },
                "finish_reason": "STOP",
            }
        ],
    )
    # One stream only: no tool round trip happens on the client.
    mock_send_message_stream.return_value = [[opening_chunk, closing_chunk]]

    ctx = Context()
    with patch(
        "google.genai.chats.AsyncChats.create", return_value=AsyncMock()
    ) as mock_create:
        mock_create.return_value.send_message_stream = mock_send_message_stream
        result = await conversation.async_converse(
            hass,
            "Who won the 2024 FIFA World Cup?",
            mock_chat_log.conversation_id,
            ctx,
            agent_id="conversation.google_generative_ai_conversation",
            device_id="test_device",
        )

    response = result.response
    assert response.response_type == intent.IntentResponseType.ACTION_DONE
    spoken = response.as_dict()["speech"]["plain"]["speech"]
    assert spoken == "The last winner of the 2024 FIFA World Cup was Argentina."

    create_config = mock_create.mock_calls[0].kwargs["config"]
    assert create_config.tools[-1].google_search is not None
@pytest.mark.usefixtures("mock_init_component")
async def test_blocked_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_chat_log: MockChatLog, # noqa: F811
mock_send_message_stream: AsyncMock,
) -> None:
"""Test blocked response."""
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=Mock(block_reason_message="SAFETY"))
mock_chat.return_value = chat_response
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
messages = [
[
GenerateContentResponse(
candidates=[
{
"content": {
"parts": [
{
"text": "I've called the ",
}
],
"role": "model",
},
}
],
),
GenerateContentResponse(prompt_feedback={"block_reason_message": "SAFETY"}),
],
]
mock_send_message_stream.return_value = messages
result = await conversation.async_converse(
hass,
"hello",
None,
Context(),
agent_id="conversation.google_generative_ai_conversation",
"Please call the test function",
mock_chat_log.conversation_id,
context,
agent_id=agent_id,
device_id="test_device",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
@ -473,23 +341,41 @@ async def test_blocked_response(
@pytest.mark.usefixtures("mock_init_component")
async def test_empty_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_chat_log: MockChatLog, # noqa: F811
mock_send_message_stream: AsyncMock,
) -> None:
"""Test empty response."""
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=None)
mock_chat.return_value = chat_response
chat_response.candidates = [Mock(content=Mock(parts=[]))]
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
messages = [
[
GenerateContentResponse(
candidates=[
{
"content": {
"parts": [],
"role": "model",
},
}
],
),
],
]
mock_send_message_stream.return_value = messages
result = await conversation.async_converse(
hass,
"hello",
None,
Context(),
agent_id="conversation.google_generative_ai_conversation",
"Hello",
mock_chat_log.conversation_id,
context,
agent_id=agent_id,
device_id="test_device",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
@ -499,27 +385,36 @@ async def test_empty_response(
@pytest.mark.usefixtures("mock_init_component")
async def test_none_response(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_chat_log: MockChatLog,  # noqa: F811
    mock_send_message_stream: AsyncMock,
) -> None:
    """Test None response.

    The mocked stream yields one ``GenerateContentResponse`` built with no
    arguments, i.e. without candidates or prompt feedback. The conversation
    result must be an error with code ``unknown`` and the "blocked" message
    carrying the fallback reason ``unknown``.
    """
    agent_id = "conversation.google_generative_ai_conversation"
    context = Context()
    # A single stream holding a single, completely empty response chunk.
    messages = [
        [
            GenerateContentResponse(),
        ],
    ]
    mock_send_message_stream.return_value = messages

    result = await conversation.async_converse(
        hass,
        "Hello",
        mock_chat_log.conversation_id,
        context,
        agent_id=agent_id,
        device_id="test_device",
    )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.error_code == "unknown", result
    assert result.response.as_dict()["speech"]["plain"]["speech"] == (
        "The message got blocked due to content violations, reason: unknown"
    )
@ -712,29 +607,49 @@ async def test_format_schema(openapi, genai_schema) -> None:
@pytest.mark.usefixtures("mock_init_component")
async def test_empty_content_in_chat_history(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_chat_log: MockChatLog, # noqa: F811
mock_send_message_stream: AsyncMock,
) -> None:
"""Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead."""
with (
patch("google.genai.chats.AsyncChats.create") as mock_create,
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session) as chat_log,
):
mock_chat = AsyncMock()
mock_create.return_value.send_message = mock_chat
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
messages = [
[
GenerateContentResponse(
candidates=[
{
"content": {
"parts": [{"text": "Hi there!"}],
"role": "model",
},
}
],
),
],
]
mock_send_message_stream.return_value = messages
# Chat preparation with two inputs, one being an empty string
first_input = "First request"
second_input = ""
chat_log.async_add_user_content(UserContent(first_input))
chat_log.async_add_user_content(UserContent(second_input))
mock_chat_log.async_add_user_content(UserContent(first_input))
mock_chat_log.async_add_user_content(UserContent(second_input))
with patch(
"google.genai.chats.AsyncChats.create", return_value=AsyncMock()
) as mock_create:
mock_create.return_value.send_message_stream = mock_send_message_stream
await conversation.async_converse(
hass,
"Second request",
session.conversation_id,
Context(),
agent_id="conversation.google_generative_ai_conversation",
"Hello",
mock_chat_log.conversation_id,
context,
agent_id=agent_id,
device_id="test_device",
)
_, kwargs = mock_create.call_args
@ -748,33 +663,53 @@ async def test_empty_content_in_chat_history(
async def test_history_always_user_first_turn(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
snapshot: SnapshotAssertion,
mock_chat_log: MockChatLog, # noqa: F811
mock_send_message_stream: AsyncMock,
) -> None:
"""Test that the user is always first in the chat history."""
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session) as chat_log,
):
chat_log.async_add_assistant_content_without_tools(
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
messages = [
[
GenerateContentResponse(
candidates=[
{
"content": {
"parts": [
{
"text": " Yes, I can help with that. ",
}
],
"role": "model",
},
}
],
),
],
]
mock_send_message_stream.return_value = messages
mock_chat_log.async_add_assistant_content_without_tools(
conversation.AssistantContent(
agent_id="conversation.google_generative_ai_conversation",
content="Garage door left open, do you want to close it?",
)
)
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=None)
mock_chat.return_value = chat_response
chat_response.candidates = [Mock(content=Mock(parts=[]))]
with patch(
"google.genai.chats.AsyncChats.create", return_value=AsyncMock()
) as mock_create:
mock_create.return_value.send_message_stream = mock_send_message_stream
await conversation.async_converse(
hass,
"hello",
chat_log.conversation_id,
Context(),
agent_id="conversation.google_generative_ai_conversation",
"Hello",
mock_chat_log.conversation_id,
context,
agent_id=agent_id,
device_id="test_device",
)
_, kwargs = mock_create.call_args

View File

@ -11,7 +11,7 @@ from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
from . import API_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
from tests.common import MockConfigEntry
@ -212,7 +212,7 @@ async def test_generate_content_service_error(
with (
patch(
"google.genai.models.AsyncModels.generate_content",
side_effect=CLIENT_ERROR_500,
side_effect=API_ERROR_500,
),
pytest.raises(
HomeAssistantError,
@ -311,7 +311,7 @@ async def test_generate_content_service_with_image_not_exists(
("side_effect", "state", "reauth"),
[
(
CLIENT_ERROR_500,
API_ERROR_500,
ConfigEntryState.SETUP_ERROR,
False,
),