Enable message streaming in the Gemini integration (#144937)
* Added streaming implementation
* Indicate the entity supports streaming
* Added tests
* Removed unused snapshots

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
parent e4b519d77a
commit 32eb4af6ef
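The change swaps the blocking send_message call for send_message_stream plus a _transform_stream helper that turns Gemini response chunks into conversation delta dicts: the first chunk carries a "role", later chunks carry "content" and optional "tool_calls". A minimal sketch of how such deltas accumulate into the final spoken response; fake_deltas and collect are hypothetical stand-ins, not code from this commit:

import asyncio
from collections.abc import AsyncGenerator
from typing import Any


async def fake_deltas() -> AsyncGenerator[dict[str, Any], None]:
    # Stand-in for what _transform_stream yields; the real source is the Gemini SDK stream.
    yield {"role": "assistant", "content": "I've called the "}
    yield {"content": "test function with the provided parameters."}


async def collect(stream: AsyncGenerator[dict[str, Any], None]) -> str:
    # Join the streamed text deltas into one answer, roughly as the chat log
    # does when it assembles the response from a delta stream.
    parts: list[str] = []
    async for chunk in stream:
        parts.append(chunk.get("content") or "")
    return "".join(parts)


print(asyncio.run(collect(fake_deltas())))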
@@ -3,16 +3,17 @@
from __future__ import annotations

import codecs
from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable
from dataclasses import replace
from typing import Any, Literal, cast

from google.genai.errors import APIError
from google.genai.errors import APIError, ClientError
from google.genai.types import (
    AutomaticFunctionCallingConfig,
    Content,
    FunctionDeclaration,
    GenerateContentConfig,
    GenerateContentResponse,
    GoogleSearch,
    HarmCategory,
    Part,
@@ -233,6 +234,81 @@ def _convert_content(
    return Content(role="model", parts=parts)


async def _transform_stream(
    result: AsyncGenerator[GenerateContentResponse],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
    new_message = True
    try:
        async for response in result:
            LOGGER.debug("Received response chunk: %s", response)
            chunk: conversation.AssistantContentDeltaDict = {}

            if new_message:
                chunk["role"] = "assistant"
                new_message = False

            # According to the API docs, this would mean no candidate is returned, so we can safely throw an error here.
            if response.prompt_feedback or not response.candidates:
                reason = (
                    response.prompt_feedback.block_reason_message
                    if response.prompt_feedback
                    else "unknown"
                )
                raise HomeAssistantError(
                    f"The message got blocked due to content violations, reason: {reason}"
                )

            candidate = response.candidates[0]

            if (
                candidate.finish_reason is not None
                and candidate.finish_reason != "STOP"
            ):
                # The message ended due to a content error as explained in: https://ai.google.dev/api/generate-content#FinishReason
                LOGGER.error(
                    "Error in Google Generative AI response: %s, see: https://ai.google.dev/api/generate-content#FinishReason",
                    candidate.finish_reason,
                )
                raise HomeAssistantError(
                    f"{ERROR_GETTING_RESPONSE} Reason: {candidate.finish_reason}"
                )

            response_parts = (
                candidate.content.parts
                if candidate.content is not None and candidate.content.parts is not None
                else []
            )

            content = "".join([part.text for part in response_parts if part.text])
            tool_calls = []
            for part in response_parts:
                if not part.function_call:
                    continue
                tool_call = part.function_call
                tool_name = tool_call.name if tool_call.name else ""
                tool_args = _escape_decode(tool_call.args)
                tool_calls.append(
                    llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
                )

            if tool_calls:
                chunk["tool_calls"] = tool_calls

            chunk["content"] = content
            yield chunk
    except (
        APIError,
        ValueError,
    ) as err:
        LOGGER.error("Error sending message: %s %s", type(err), err)
        if isinstance(err, APIError):
            message = err.message
        else:
            message = type(err).__name__
        error = f"{ERROR_GETTING_RESPONSE}: {message}"
        raise HomeAssistantError(error) from err


class GoogleGenerativeAIConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@@ -240,6 +316,7 @@ class GoogleGenerativeAIConversationEntity(

    _attr_has_entity_name = True
    _attr_name = None
    _attr_supports_streaming = True

    def __init__(self, entry: ConfigEntry) -> None:
        """Initialize the agent."""
@@ -426,80 +503,40 @@ class GoogleGenerativeAIConversationEntity(
        # To prevent infinite loops, we limit the number of iterations
        for _iteration in range(MAX_TOOL_ITERATIONS):
            try:
                chat_response = await chat.send_message(message=chat_request)

                if chat_response.prompt_feedback:
                    raise HomeAssistantError(
                        f"The message got blocked due to content violations, reason: {chat_response.prompt_feedback.block_reason_message}"
                    )
                if not chat_response.candidates:
                    LOGGER.error(
                        "No candidates found in the response: %s",
                        chat_response,
                    )
                    raise HomeAssistantError(ERROR_GETTING_RESPONSE)

                chat_response_generator = await chat.send_message_stream(
                    message=chat_request
                )
            except (
                APIError,
                ClientError,
                ValueError,
            ) as err:
                LOGGER.error("Error sending message: %s %s", type(err), err)
                error = f"Sorry, I had a problem talking to Google Generative AI: {err}"
                error = ERROR_GETTING_RESPONSE
                raise HomeAssistantError(error) from err

            if (usage_metadata := chat_response.usage_metadata) is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": usage_metadata.prompt_token_count,
                            "cached_input_tokens": usage_metadata.cached_content_token_count
                            or 0,
                            "output_tokens": usage_metadata.candidates_token_count,
                        }
                    }
                )

            response_parts = chat_response.candidates[0].content.parts
            if not response_parts:
                raise HomeAssistantError(ERROR_GETTING_RESPONSE)
            content = " ".join(
                [part.text.strip() for part in response_parts if part.text]
            )

            tool_calls = []
            for part in response_parts:
                if not part.function_call:
                    continue
                tool_call = part.function_call
                tool_name = tool_call.name
                tool_args = _escape_decode(tool_call.args)
                tool_calls.append(
                    llm.ToolInput(
                        tool_name=self._fix_tool_name(tool_name),
                        tool_args=tool_args,
                    )
                )

            chat_request = _create_google_tool_response_parts(
                [
                    tool_response
                    async for tool_response in chat_log.async_add_assistant_content(
                        conversation.AssistantContent(
                            agent_id=user_input.agent_id,
                            content=content,
                            tool_calls=tool_calls or None,
                        )
                    content
                    async for content in chat_log.async_add_delta_content_stream(
                        user_input.agent_id,
                        _transform_stream(chat_response_generator),
                    )
                    if isinstance(content, conversation.ToolResultContent)
                ]
            )

            if not tool_calls:
            if not chat_log.unresponded_tool_results:
                break

        response = intent.IntentResponse(language=user_input.language)
        response.async_set_speech(
            " ".join([part.text.strip() for part in response_parts if part.text])
        )
        if not isinstance(chat_log.content[-1], conversation.AssistantContent):
            LOGGER.error(
                "Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
                chat_log.content[-1],
            )
            raise HomeAssistantError(f"{ERROR_GETTING_RESPONSE}")
        response.async_set_speech(chat_log.content[-1].content or "")
        return conversation.ConversationResult(
            response=response,
            conversation_id=chat_log.conversation_id,
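For orientation, a small sketch (not part of the commit) of how a transformer with the _transform_stream signature above can be exercised against a hand-built chunk stream; the GenerateContentResponse keyword form mirrors the one used in the tests below, and _fake_chunks/_collect are hypothetical names:

import asyncio

from google.genai.types import GenerateContentResponse

from homeassistant.components.google_generative_ai_conversation.conversation import (
    _transform_stream,
)


async def _fake_chunks():
    # Two chunks of one model turn: plain text, then a terminating STOP chunk.
    yield GenerateContentResponse(
        candidates=[{"content": {"parts": [{"text": "Hi "}], "role": "model"}}]
    )
    yield GenerateContentResponse(
        candidates=[
            {
                "content": {"parts": [{"text": "there!"}], "role": "model"},
                "finish_reason": "STOP",
            }
        ]
    )


async def _collect() -> list:
    # First delta carries the role; later deltas only add text or tool calls.
    return [delta async for delta in _transform_stream(_fake_chunks())]


# Roughly: [{"role": "assistant", "content": "Hi "}, {"content": "there!"}]
print(asyncio.run(_collect()))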
@@ -2,10 +2,10 @@

from unittest.mock import Mock

from google.genai.errors import ClientError
from google.genai.errors import APIError, ClientError
import httpx

CLIENT_ERROR_500 = ClientError(
API_ERROR_500 = APIError(
    500,
    Mock(
        __class__=httpx.Response,
@@ -17,6 +17,18 @@ CLIENT_ERROR_500 = ClientError(
        ),
    ),
)
CLIENT_ERROR_BAD_REQUEST = ClientError(
    400,
    Mock(
        __class__=httpx.Response,
        json=Mock(
            return_value={
                "message": "Bad Request",
                "status": "invalid-argument",
            }
        ),
    ),
)
CLIENT_ERROR_API_KEY_INVALID = ClientError(
    400,
    Mock(
@ -1,100 +0,0 @@
|
||||
# serializer version: 1
|
||||
# name: test_function_call
|
||||
list([
|
||||
tuple(
|
||||
'',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'param1': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description='Test parameters', enum=None, format=None, items=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>), max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.ARRAY: 'ARRAY'>), 'param2': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=None), 'param3': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'json': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, 
minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
|
||||
'history': list([
|
||||
]),
|
||||
'model': 'models/gemini-2.0-flash',
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().send_message',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'message': 'Please call the test function',
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().send_message',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'message': list([
|
||||
Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
])
|
||||
# ---
|
||||
# name: test_function_call_without_parameters
|
||||
list([
|
||||
tuple(
|
||||
'',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=None)], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
|
||||
'history': list([
|
||||
]),
|
||||
'model': 'models/gemini-2.0-flash',
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().send_message',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'message': 'Please call the test function',
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().send_message',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'message': list([
|
||||
Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
])
|
||||
# ---
|
||||
# name: test_use_google_search
|
||||
list([
|
||||
tuple(
|
||||
'',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'param1': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description='Test parameters', enum=None, format=None, items=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>), max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.ARRAY: 'ARRAY'>), 'param2': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=None), 'param3': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'json': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, 
minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None), Tool(function_declarations=None, retrieval=None, google_search=GoogleSearch(), google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
|
||||
'history': list([
|
||||
]),
|
||||
'model': 'models/gemini-2.0-flash',
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().send_message',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'message': 'Please call the test function',
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().send_message',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'message': list([
|
||||
Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
])
|
||||
# ---
|
@@ -34,7 +34,7 @@ from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType

from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
from . import API_ERROR_500, CLIENT_ERROR_API_KEY_INVALID

from tests.common import MockConfigEntry

@@ -339,7 +339,7 @@ async def test_options_switching(
    ("side_effect", "error"),
    [
        (
            CLIENT_ERROR_500,
            API_ERROR_500,
            "cannot_connect",
        ),
        (
@@ -1,16 +1,14 @@
"""Tests for the Google Generative AI Conversation integration conversation platform."""

from typing import Any
from unittest.mock import AsyncMock, Mock, patch
from collections.abc import Generator
from unittest.mock import AsyncMock, patch

from freezegun import freeze_time
from google.genai.types import FunctionCall
from google.genai.types import GenerateContentResponse
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol

from homeassistant.components import conversation
from homeassistant.components.conversation import UserContent, async_get_chat_log, trace
from homeassistant.components.conversation import UserContent
from homeassistant.components.google_generative_ai_conversation.conversation import (
    ERROR_GETTING_RESPONSE,
    _escape_decode,
@@ -18,12 +16,15 @@ from homeassistant.components.google_generative_ai_conversation.conversation imp
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, intent, llm
from homeassistant.helpers import intent

from . import CLIENT_ERROR_500
from . import API_ERROR_500, CLIENT_ERROR_BAD_REQUEST

from tests.common import MockConfigEntry
from tests.components.conversation import (
    MockChatLog,
    mock_chat_log,  # noqa: F401
)


@pytest.fixture(autouse=True)
@@ -40,396 +41,44 @@ def mock_ulid_tools():
    yield


@patch(
    "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
@pytest.fixture
def mock_send_message_stream() -> Generator[AsyncMock]:
    """Mock stream response."""

    async def mock_generator(stream):
        for value in stream:
            yield value

    with patch(
        "google.genai.chats.AsyncChat.send_message_stream",
        AsyncMock(),
    ) as mock_send_message_stream:
        mock_send_message_stream.side_effect = lambda **kwargs: mock_generator(
            mock_send_message_stream.return_value.pop(0)
        )

        yield mock_send_message_stream
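A brief illustration (not an actual test from this commit; the chunk contents are made up) of how the fixture above is meant to be fed: each inner list is the chunk stream for one send_message_stream call, and the side_effect pops the next list for each call.

# Hypothetical usage sketch for mock_send_message_stream; mirrors the rewritten tests below.
async def test_streamed_reply_sketch(
    hass, mock_send_message_stream: AsyncMock
) -> None:
    mock_send_message_stream.return_value = [
        [  # chunks consumed by the first (and only) send_message_stream call
            GenerateContentResponse(
                candidates=[
                    {
                        "content": {"parts": [{"text": "Hi there!"}], "role": "model"},
                        "finish_reason": "STOP",
                    }
                ]
            ),
        ],
    ]
    # ...then call conversation.async_converse(...) and assert on the spoken text,
    # exactly as the tests that follow do.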
@pytest.mark.parametrize(
|
||||
("error"),
|
||||
[
|
||||
(API_ERROR_500,),
|
||||
(CLIENT_ERROR_BAD_REQUEST,),
|
||||
],
|
||||
)
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
@pytest.mark.usefixtures("mock_ulid_tools")
|
||||
async def test_function_call(
|
||||
mock_get_tools,
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test function calling."""
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema(
|
||||
{
|
||||
vol.Optional("param1", description="Test parameters"): [
|
||||
vol.All(str, vol.Lower)
|
||||
],
|
||||
vol.Optional("param2"): vol.Any(float, int),
|
||||
vol.Optional("param3"): dict,
|
||||
}
|
||||
)
|
||||
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
||||
mock_chat = AsyncMock()
|
||||
mock_create.return_value.send_message = mock_chat
|
||||
chat_response = Mock(prompt_feedback=None)
|
||||
mock_chat.return_value = chat_response
|
||||
mock_part = Mock()
|
||||
mock_part.text = ""
|
||||
mock_part.function_call = FunctionCall(
|
||||
name="test_tool",
|
||||
args={
|
||||
"param1": ["test_value", "param1\\'s value"],
|
||||
"param2": 2.7,
|
||||
},
|
||||
)
|
||||
|
||||
def tool_call(
|
||||
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
|
||||
) -> dict[str, Any]:
|
||||
mock_part.function_call = None
|
||||
mock_part.text = "Hi there!"
|
||||
return {"result": "Test response"}
|
||||
|
||||
mock_tool.async_call.side_effect = tool_call
|
||||
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
None,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
|
||||
mock_tool_response_parts = mock_create.mock_calls[2][2]["message"]
|
||||
assert len(mock_tool_response_parts) == 1
|
||||
assert mock_tool_response_parts[0].model_dump() == {
|
||||
"code_execution_result": None,
|
||||
"executable_code": None,
|
||||
"file_data": None,
|
||||
"function_call": None,
|
||||
"function_response": {
|
||||
"id": None,
|
||||
"name": "test_tool",
|
||||
"response": {
|
||||
"result": "Test response",
|
||||
},
|
||||
},
|
||||
"inline_data": None,
|
||||
"text": None,
|
||||
"thought": None,
|
||||
"video_metadata": None,
|
||||
}
|
||||
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call",
|
||||
tool_name="test_tool",
|
||||
tool_args={
|
||||
"param1": ["test_value", "param1's value"],
|
||||
"param2": 2.7,
|
||||
},
|
||||
),
|
||||
llm.LLMContext(
|
||||
platform="google_generative_ai_conversation",
|
||||
context=context,
|
||||
user_prompt="Please call the test function",
|
||||
language="en",
|
||||
assistant="conversation",
|
||||
device_id="test_device",
|
||||
),
|
||||
)
|
||||
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
|
||||
|
||||
# Test conversating tracing
|
||||
traces = trace.async_get_traces()
|
||||
assert traces
|
||||
last_trace = traces[-1].as_dict()
|
||||
trace_events = last_trace.get("events", [])
|
||||
assert [event["event_type"] for event in trace_events] == [
|
||||
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL, # prompt and tools
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response
|
||||
trace.ConversationTraceEventType.TOOL_CALL,
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response
|
||||
]
|
||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||
detail_event = trace_events[1]
|
||||
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
|
||||
assert [
|
||||
p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
|
||||
] == ["test_tool"]
|
||||
|
||||
detail_event = trace_events[2]
|
||||
assert set(detail_event["data"]["stats"].keys()) == {
|
||||
"input_tokens",
|
||||
"cached_input_tokens",
|
||||
"output_tokens",
|
||||
}
|
||||
|
||||
|
||||
@patch(
|
||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||
)
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
@pytest.mark.usefixtures("mock_ulid_tools")
|
||||
async def test_use_google_search(
|
||||
mock_get_tools,
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_google_search: MockConfigEntry,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test function calling."""
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema(
|
||||
{
|
||||
vol.Optional("param1", description="Test parameters"): [
|
||||
vol.All(str, vol.Lower)
|
||||
],
|
||||
vol.Optional("param2"): vol.Any(float, int),
|
||||
vol.Optional("param3"): dict,
|
||||
}
|
||||
)
|
||||
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
||||
mock_chat = AsyncMock()
|
||||
mock_create.return_value.send_message = mock_chat
|
||||
chat_response = Mock(prompt_feedback=None)
|
||||
mock_chat.return_value = chat_response
|
||||
mock_part = Mock()
|
||||
mock_part.text = ""
|
||||
mock_part.function_call = FunctionCall(
|
||||
name="test_tool",
|
||||
args={
|
||||
"param1": ["test_value", "param1\\'s value"],
|
||||
"param2": 2.7,
|
||||
},
|
||||
)
|
||||
|
||||
def tool_call(
|
||||
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
|
||||
) -> dict[str, Any]:
|
||||
mock_part.function_call = None
|
||||
mock_part.text = "Hi there!"
|
||||
return {"result": "Test response"}
|
||||
|
||||
mock_tool.async_call.side_effect = tool_call
|
||||
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
|
||||
await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
None,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
|
||||
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
|
||||
|
||||
|
||||
@patch(
|
||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||
)
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_function_call_without_parameters(
|
||||
mock_get_tools,
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test function calling without parameters."""
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema({})
|
||||
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
||||
mock_chat = AsyncMock()
|
||||
mock_create.return_value.send_message = mock_chat
|
||||
chat_response = Mock(prompt_feedback=None)
|
||||
mock_chat.return_value = chat_response
|
||||
mock_part = Mock()
|
||||
mock_part.text = ""
|
||||
mock_part.function_call = FunctionCall(name="test_tool", args={})
|
||||
|
||||
def tool_call(
|
||||
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
|
||||
) -> dict[str, Any]:
|
||||
mock_part.function_call = None
|
||||
mock_part.text = "Hi there!"
|
||||
return {"result": "Test response"}
|
||||
|
||||
mock_tool.async_call.side_effect = tool_call
|
||||
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
None,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
|
||||
mock_tool_response_parts = mock_create.mock_calls[2][2]["message"]
|
||||
assert len(mock_tool_response_parts) == 1
|
||||
assert mock_tool_response_parts[0].model_dump() == {
|
||||
"code_execution_result": None,
|
||||
"executable_code": None,
|
||||
"file_data": None,
|
||||
"function_call": None,
|
||||
"function_response": {
|
||||
"id": None,
|
||||
"name": "test_tool",
|
||||
"response": {
|
||||
"result": "Test response",
|
||||
},
|
||||
},
|
||||
"inline_data": None,
|
||||
"text": None,
|
||||
"thought": None,
|
||||
"video_metadata": None,
|
||||
}
|
||||
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call",
|
||||
tool_name="test_tool",
|
||||
tool_args={},
|
||||
),
|
||||
llm.LLMContext(
|
||||
platform="google_generative_ai_conversation",
|
||||
context=context,
|
||||
user_prompt="Please call the test function",
|
||||
language="en",
|
||||
assistant="conversation",
|
||||
device_id="test_device",
|
||||
),
|
||||
)
|
||||
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
|
||||
|
||||
|
||||
@patch(
|
||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||
)
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_function_exception(
|
||||
mock_get_tools,
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test exception in function calling."""
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema(
|
||||
{
|
||||
vol.Optional("param1", description="Test parameters"): vol.All(
|
||||
vol.Coerce(int), vol.Range(0, 100)
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
||||
mock_chat = AsyncMock()
|
||||
mock_create.return_value.send_message = mock_chat
|
||||
chat_response = Mock(prompt_feedback=None)
|
||||
mock_chat.return_value = chat_response
|
||||
mock_part = Mock()
|
||||
mock_part.text = ""
|
||||
mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
|
||||
|
||||
def tool_call(
|
||||
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
|
||||
) -> dict[str, Any]:
|
||||
mock_part.function_call = None
|
||||
mock_part.text = "Hi there!"
|
||||
raise HomeAssistantError("Test tool exception")
|
||||
|
||||
mock_tool.async_call.side_effect = tool_call
|
||||
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
None,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
|
||||
mock_tool_response_parts = mock_create.mock_calls[2][2]["message"]
|
||||
assert len(mock_tool_response_parts) == 1
|
||||
assert mock_tool_response_parts[0].model_dump() == {
|
||||
"code_execution_result": None,
|
||||
"executable_code": None,
|
||||
"file_data": None,
|
||||
"function_call": None,
|
||||
"function_response": {
|
||||
"id": None,
|
||||
"name": "test_tool",
|
||||
"response": {
|
||||
"error": "HomeAssistantError",
|
||||
"error_text": "Test tool exception",
|
||||
},
|
||||
},
|
||||
"inline_data": None,
|
||||
"text": None,
|
||||
"thought": None,
|
||||
"video_metadata": None,
|
||||
}
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": 1},
|
||||
),
|
||||
llm.LLMContext(
|
||||
platform="google_generative_ai_conversation",
|
||||
context=context,
|
||||
user_prompt="Please call the test function",
|
||||
language="en",
|
||||
assistant="conversation",
|
||||
device_id="test_device",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_error_handling(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
error,
|
||||
) -> None:
|
||||
"""Test that client errors are caught."""
|
||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
||||
mock_chat = AsyncMock()
|
||||
mock_create.return_value.send_message = mock_chat
|
||||
mock_chat.side_effect = CLIENT_ERROR_500
|
||||
with patch(
|
||||
"google.genai.chats.AsyncChat.send_message_stream",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=error,
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
@ -437,32 +86,251 @@ async def test_error_handling(
|
||||
Context(),
|
||||
agent_id="conversation.google_generative_ai_conversation",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||
"Sorry, I had a problem talking to Google Generative AI: 500 internal-error. {'message': 'Internal Server Error', 'status': 'internal-error'}"
|
||||
assert (
|
||||
result.response.as_dict()["speech"]["plain"]["speech"] == ERROR_GETTING_RESPONSE
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
@pytest.mark.usefixtures("mock_ulid_tools")
|
||||
async def test_function_call(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test function calling."""
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
# Function call stream
|
||||
[
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "Hi there!",
|
||||
}
|
||||
],
|
||||
"role": "model",
|
||||
}
|
||||
}
|
||||
]
|
||||
),
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"function_call": {
|
||||
"name": "test_tool",
|
||||
"args": {
|
||||
"param1": [
|
||||
"test_value",
|
||||
"param1\\'s value",
|
||||
],
|
||||
"param2": 2.7,
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
"role": "model",
|
||||
},
|
||||
"finish_reason": "STOP",
|
||||
}
|
||||
]
|
||||
),
|
||||
],
|
||||
# Messages after function response is sent
|
||||
[
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "I've called the ",
|
||||
}
|
||||
],
|
||||
"role": "model",
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "test function with the provided parameters.",
|
||||
}
|
||||
],
|
||||
"role": "model",
|
||||
},
|
||||
"finish_reason": "STOP",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
mock_send_message_stream.return_value = messages
|
||||
|
||||
mock_chat_log.mock_tool_results(
|
||||
{
|
||||
"mock-tool-call": {"result": "Test response"},
|
||||
}
|
||||
)
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
mock_chat_log.conversation_id,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert (
|
||||
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||
== "I've called the test function with the provided parameters."
|
||||
)
|
||||
mock_tool_response_parts = mock_send_message_stream.mock_calls[1][2]["message"]
|
||||
assert len(mock_tool_response_parts) == 1
|
||||
assert mock_tool_response_parts[0].model_dump() == {
|
||||
"code_execution_result": None,
|
||||
"executable_code": None,
|
||||
"file_data": None,
|
||||
"function_call": None,
|
||||
"function_response": {
|
||||
"id": None,
|
||||
"name": "test_tool",
|
||||
"response": {
|
||||
"result": "Test response",
|
||||
},
|
||||
},
|
||||
"inline_data": None,
|
||||
"text": None,
|
||||
"thought": None,
|
||||
"video_metadata": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
@pytest.mark.usefixtures("mock_ulid_tools")
|
||||
async def test_google_search_tool_is_sent(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_google_search: MockConfigEntry,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test if the Google Search tool is sent to the model."""
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
# Messages from the model which contain the google search answer (the usage of the Google Search tool is server side)
|
||||
[
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "The last winner ",
|
||||
}
|
||||
],
|
||||
"role": "model",
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{"text": "of the 2024 FIFA World Cup was Argentina."}
|
||||
],
|
||||
"role": "model",
|
||||
},
|
||||
"finish_reason": "STOP",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
mock_send_message_stream.return_value = messages
|
||||
|
||||
with patch(
|
||||
"google.genai.chats.AsyncChats.create", return_value=AsyncMock()
|
||||
) as mock_create:
|
||||
mock_create.return_value.send_message_stream = mock_send_message_stream
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Who won the 2024 FIFA World Cup?",
|
||||
mock_chat_log.conversation_id,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert (
|
||||
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||
== "The last winner of the 2024 FIFA World Cup was Argentina."
|
||||
)
|
||||
assert mock_create.mock_calls[0][2]["config"].tools[-1].google_search is not None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_blocked_response(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test blocked response."""
|
||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
||||
mock_chat = AsyncMock()
|
||||
mock_create.return_value.send_message = mock_chat
|
||||
chat_response = Mock(prompt_feedback=Mock(block_reason_message="SAFETY"))
|
||||
mock_chat.return_value = chat_response
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
None,
|
||||
Context(),
|
||||
agent_id="conversation.google_generative_ai_conversation",
|
||||
)
|
||||
messages = [
|
||||
[
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "I've called the ",
|
||||
}
|
||||
],
|
||||
"role": "model",
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
GenerateContentResponse(prompt_feedback={"block_reason_message": "SAFETY"}),
|
||||
],
|
||||
]
|
||||
|
||||
mock_send_message_stream.return_value = messages
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
mock_chat_log.conversation_id,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
@ -473,23 +341,41 @@ async def test_blocked_response(
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_empty_response(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test empty response."""
|
||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
||||
mock_chat = AsyncMock()
|
||||
mock_create.return_value.send_message = mock_chat
|
||||
chat_response = Mock(prompt_feedback=None)
|
||||
mock_chat.return_value = chat_response
|
||||
chat_response.candidates = [Mock(content=Mock(parts=[]))]
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
None,
|
||||
Context(),
|
||||
agent_id="conversation.google_generative_ai_conversation",
|
||||
)
|
||||
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
[
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [],
|
||||
"role": "model",
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
mock_send_message_stream.return_value = messages
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Hello",
|
||||
mock_chat_log.conversation_id,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||
@ -499,27 +385,36 @@ async def test_empty_response(
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_none_response(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test empty response."""
|
||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
||||
mock_chat = AsyncMock()
|
||||
mock_create.return_value.send_message = mock_chat
|
||||
chat_response = Mock(prompt_feedback=None)
|
||||
mock_chat.return_value = chat_response
|
||||
chat_response.candidates = None
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
None,
|
||||
Context(),
|
||||
agent_id="conversation.google_generative_ai_conversation",
|
||||
)
|
||||
"""Test None response."""
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
[
|
||||
GenerateContentResponse(),
|
||||
],
|
||||
]
|
||||
|
||||
mock_send_message_stream.return_value = messages
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Hello",
|
||||
mock_chat_log.conversation_id,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||
ERROR_GETTING_RESPONSE
|
||||
"The message got blocked due to content violations, reason: unknown"
|
||||
)
|
||||
|
||||
|
||||
@ -712,69 +607,109 @@ async def test_format_schema(openapi, genai_schema) -> None:
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_empty_content_in_chat_history(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead."""
|
||||
with (
|
||||
patch("google.genai.chats.AsyncChats.create") as mock_create,
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session) as chat_log,
|
||||
):
|
||||
mock_chat = AsyncMock()
|
||||
mock_create.return_value.send_message = mock_chat
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
# Chat preparation with two inputs, one being an empty string
|
||||
first_input = "First request"
|
||||
second_input = ""
|
||||
chat_log.async_add_user_content(UserContent(first_input))
|
||||
chat_log.async_add_user_content(UserContent(second_input))
|
||||
messages = [
|
||||
[
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [{"text": "Hi there!"}],
|
||||
"role": "model",
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
mock_send_message_stream.return_value = messages
|
||||
|
||||
# Chat preparation with two inputs, one being an empty string
|
||||
first_input = "First request"
|
||||
second_input = ""
|
||||
mock_chat_log.async_add_user_content(UserContent(first_input))
|
||||
mock_chat_log.async_add_user_content(UserContent(second_input))
|
||||
|
||||
with patch(
|
||||
"google.genai.chats.AsyncChats.create", return_value=AsyncMock()
|
||||
) as mock_create:
|
||||
mock_create.return_value.send_message_stream = mock_send_message_stream
|
||||
await conversation.async_converse(
|
||||
hass,
|
||||
"Second request",
|
||||
session.conversation_id,
|
||||
Context(),
|
||||
agent_id="conversation.google_generative_ai_conversation",
|
||||
"Hello",
|
||||
mock_chat_log.conversation_id,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
|
||||
_, kwargs = mock_create.call_args
|
||||
actual_history = kwargs.get("history")
|
||||
_, kwargs = mock_create.call_args
|
||||
actual_history = kwargs.get("history")
|
||||
|
||||
assert actual_history[0].parts[0].text == first_input
|
||||
assert actual_history[1].parts[0].text == " "
|
||||
assert actual_history[0].parts[0].text == first_input
|
||||
assert actual_history[1].parts[0].text == " "
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_history_always_user_first_turn(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
snapshot: SnapshotAssertion,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that the user is always first in the chat history."""
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session) as chat_log,
|
||||
):
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
conversation.AssistantContent(
|
||||
agent_id="conversation.google_generative_ai_conversation",
|
||||
content="Garage door left open, do you want to close it?",
|
||||
)
|
||||
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
[
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": " Yes, I can help with that. ",
|
||||
}
|
||||
],
|
||||
"role": "model",
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
mock_send_message_stream.return_value = messages
|
||||
|
||||
mock_chat_log.async_add_assistant_content_without_tools(
|
||||
conversation.AssistantContent(
|
||||
agent_id="conversation.google_generative_ai_conversation",
|
||||
content="Garage door left open, do you want to close it?",
|
||||
)
|
||||
)
|
||||
|
||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
||||
mock_chat = AsyncMock()
|
||||
mock_create.return_value.send_message = mock_chat
|
||||
chat_response = Mock(prompt_feedback=None)
|
||||
mock_chat.return_value = chat_response
|
||||
chat_response.candidates = [Mock(content=Mock(parts=[]))]
|
||||
|
||||
with patch(
|
||||
"google.genai.chats.AsyncChats.create", return_value=AsyncMock()
|
||||
) as mock_create:
|
||||
mock_create.return_value.send_message_stream = mock_send_message_stream
|
||||
await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
chat_log.conversation_id,
|
||||
Context(),
|
||||
agent_id="conversation.google_generative_ai_conversation",
|
||||
"Hello",
|
||||
mock_chat_log.conversation_id,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
|
||||
_, kwargs = mock_create.call_args
|
||||
|
@@ -11,7 +11,7 @@ from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError

from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
from . import API_ERROR_500, CLIENT_ERROR_API_KEY_INVALID

from tests.common import MockConfigEntry

@@ -212,7 +212,7 @@ async def test_generate_content_service_error(
    with (
        patch(
            "google.genai.models.AsyncModels.generate_content",
            side_effect=CLIENT_ERROR_500,
            side_effect=API_ERROR_500,
        ),
        pytest.raises(
            HomeAssistantError,
@@ -311,7 +311,7 @@ async def test_generate_content_service_with_image_not_exists(
    ("side_effect", "state", "reauth"),
    [
        (
            CLIENT_ERROR_500,
            API_ERROR_500,
            ConfigEntryState.SETUP_ERROR,
            False,
        ),