mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 15:47:12 +00:00
Enable message Streaming in the Gemini integration. (#144937)
* Added streaming implementation * Indicate the entity supports streaming * Added tests * Removed unused snapshots --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
e4b519d77a
commit
32eb4af6ef
@ -3,16 +3,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
from collections.abc import Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from typing import Any, Literal, cast
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
from google.genai.errors import APIError
|
from google.genai.errors import APIError, ClientError
|
||||||
from google.genai.types import (
|
from google.genai.types import (
|
||||||
AutomaticFunctionCallingConfig,
|
AutomaticFunctionCallingConfig,
|
||||||
Content,
|
Content,
|
||||||
FunctionDeclaration,
|
FunctionDeclaration,
|
||||||
GenerateContentConfig,
|
GenerateContentConfig,
|
||||||
|
GenerateContentResponse,
|
||||||
GoogleSearch,
|
GoogleSearch,
|
||||||
HarmCategory,
|
HarmCategory,
|
||||||
Part,
|
Part,
|
||||||
@ -233,6 +234,81 @@ def _convert_content(
|
|||||||
return Content(role="model", parts=parts)
|
return Content(role="model", parts=parts)
|
||||||
|
|
||||||
|
|
||||||
|
async def _transform_stream(
|
||||||
|
result: AsyncGenerator[GenerateContentResponse],
|
||||||
|
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||||
|
new_message = True
|
||||||
|
try:
|
||||||
|
async for response in result:
|
||||||
|
LOGGER.debug("Received response chunk: %s", response)
|
||||||
|
chunk: conversation.AssistantContentDeltaDict = {}
|
||||||
|
|
||||||
|
if new_message:
|
||||||
|
chunk["role"] = "assistant"
|
||||||
|
new_message = False
|
||||||
|
|
||||||
|
# According to the API docs, this would mean no candidate is returned, so we can safely throw an error here.
|
||||||
|
if response.prompt_feedback or not response.candidates:
|
||||||
|
reason = (
|
||||||
|
response.prompt_feedback.block_reason_message
|
||||||
|
if response.prompt_feedback
|
||||||
|
else "unknown"
|
||||||
|
)
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"The message got blocked due to content violations, reason: {reason}"
|
||||||
|
)
|
||||||
|
|
||||||
|
candidate = response.candidates[0]
|
||||||
|
|
||||||
|
if (
|
||||||
|
candidate.finish_reason is not None
|
||||||
|
and candidate.finish_reason != "STOP"
|
||||||
|
):
|
||||||
|
# The message ended due to a content error as explained in: https://ai.google.dev/api/generate-content#FinishReason
|
||||||
|
LOGGER.error(
|
||||||
|
"Error in Google Generative AI response: %s, see: https://ai.google.dev/api/generate-content#FinishReason",
|
||||||
|
candidate.finish_reason,
|
||||||
|
)
|
||||||
|
raise HomeAssistantError(
|
||||||
|
f"{ERROR_GETTING_RESPONSE} Reason: {candidate.finish_reason}"
|
||||||
|
)
|
||||||
|
|
||||||
|
response_parts = (
|
||||||
|
candidate.content.parts
|
||||||
|
if candidate.content is not None and candidate.content.parts is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
content = "".join([part.text for part in response_parts if part.text])
|
||||||
|
tool_calls = []
|
||||||
|
for part in response_parts:
|
||||||
|
if not part.function_call:
|
||||||
|
continue
|
||||||
|
tool_call = part.function_call
|
||||||
|
tool_name = tool_call.name if tool_call.name else ""
|
||||||
|
tool_args = _escape_decode(tool_call.args)
|
||||||
|
tool_calls.append(
|
||||||
|
llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
|
||||||
|
)
|
||||||
|
|
||||||
|
if tool_calls:
|
||||||
|
chunk["tool_calls"] = tool_calls
|
||||||
|
|
||||||
|
chunk["content"] = content
|
||||||
|
yield chunk
|
||||||
|
except (
|
||||||
|
APIError,
|
||||||
|
ValueError,
|
||||||
|
) as err:
|
||||||
|
LOGGER.error("Error sending message: %s %s", type(err), err)
|
||||||
|
if isinstance(err, APIError):
|
||||||
|
message = err.message
|
||||||
|
else:
|
||||||
|
message = type(err).__name__
|
||||||
|
error = f"{ERROR_GETTING_RESPONSE}: {message}"
|
||||||
|
raise HomeAssistantError(error) from err
|
||||||
|
|
||||||
|
|
||||||
class GoogleGenerativeAIConversationEntity(
|
class GoogleGenerativeAIConversationEntity(
|
||||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||||
):
|
):
|
||||||
@ -240,6 +316,7 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
|
|
||||||
_attr_has_entity_name = True
|
_attr_has_entity_name = True
|
||||||
_attr_name = None
|
_attr_name = None
|
||||||
|
_attr_supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, entry: ConfigEntry) -> None:
|
def __init__(self, entry: ConfigEntry) -> None:
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
@ -426,80 +503,40 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
# To prevent infinite loops, we limit the number of iterations
|
# To prevent infinite loops, we limit the number of iterations
|
||||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||||
try:
|
try:
|
||||||
chat_response = await chat.send_message(message=chat_request)
|
chat_response_generator = await chat.send_message_stream(
|
||||||
|
message=chat_request
|
||||||
if chat_response.prompt_feedback:
|
)
|
||||||
raise HomeAssistantError(
|
|
||||||
f"The message got blocked due to content violations, reason: {chat_response.prompt_feedback.block_reason_message}"
|
|
||||||
)
|
|
||||||
if not chat_response.candidates:
|
|
||||||
LOGGER.error(
|
|
||||||
"No candidates found in the response: %s",
|
|
||||||
chat_response,
|
|
||||||
)
|
|
||||||
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
|
|
||||||
|
|
||||||
except (
|
except (
|
||||||
APIError,
|
APIError,
|
||||||
|
ClientError,
|
||||||
ValueError,
|
ValueError,
|
||||||
) as err:
|
) as err:
|
||||||
LOGGER.error("Error sending message: %s %s", type(err), err)
|
LOGGER.error("Error sending message: %s %s", type(err), err)
|
||||||
error = f"Sorry, I had a problem talking to Google Generative AI: {err}"
|
error = ERROR_GETTING_RESPONSE
|
||||||
raise HomeAssistantError(error) from err
|
raise HomeAssistantError(error) from err
|
||||||
|
|
||||||
if (usage_metadata := chat_response.usage_metadata) is not None:
|
|
||||||
chat_log.async_trace(
|
|
||||||
{
|
|
||||||
"stats": {
|
|
||||||
"input_tokens": usage_metadata.prompt_token_count,
|
|
||||||
"cached_input_tokens": usage_metadata.cached_content_token_count
|
|
||||||
or 0,
|
|
||||||
"output_tokens": usage_metadata.candidates_token_count,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
response_parts = chat_response.candidates[0].content.parts
|
|
||||||
if not response_parts:
|
|
||||||
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
|
|
||||||
content = " ".join(
|
|
||||||
[part.text.strip() for part in response_parts if part.text]
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_calls = []
|
|
||||||
for part in response_parts:
|
|
||||||
if not part.function_call:
|
|
||||||
continue
|
|
||||||
tool_call = part.function_call
|
|
||||||
tool_name = tool_call.name
|
|
||||||
tool_args = _escape_decode(tool_call.args)
|
|
||||||
tool_calls.append(
|
|
||||||
llm.ToolInput(
|
|
||||||
tool_name=self._fix_tool_name(tool_name),
|
|
||||||
tool_args=tool_args,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_request = _create_google_tool_response_parts(
|
chat_request = _create_google_tool_response_parts(
|
||||||
[
|
[
|
||||||
tool_response
|
content
|
||||||
async for tool_response in chat_log.async_add_assistant_content(
|
async for content in chat_log.async_add_delta_content_stream(
|
||||||
conversation.AssistantContent(
|
user_input.agent_id,
|
||||||
agent_id=user_input.agent_id,
|
_transform_stream(chat_response_generator),
|
||||||
content=content,
|
|
||||||
tool_calls=tool_calls or None,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
if isinstance(content, conversation.ToolResultContent)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
if not tool_calls:
|
if not chat_log.unresponded_tool_results:
|
||||||
break
|
break
|
||||||
|
|
||||||
response = intent.IntentResponse(language=user_input.language)
|
response = intent.IntentResponse(language=user_input.language)
|
||||||
response.async_set_speech(
|
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||||
" ".join([part.text.strip() for part in response_parts if part.text])
|
LOGGER.error(
|
||||||
)
|
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
|
||||||
|
chat_log.content[-1],
|
||||||
|
)
|
||||||
|
raise HomeAssistantError(f"{ERROR_GETTING_RESPONSE}")
|
||||||
|
response.async_set_speech(chat_log.content[-1].content or "")
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=response,
|
response=response,
|
||||||
conversation_id=chat_log.conversation_id,
|
conversation_id=chat_log.conversation_id,
|
||||||
|
@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from google.genai.errors import ClientError
|
from google.genai.errors import APIError, ClientError
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
CLIENT_ERROR_500 = ClientError(
|
API_ERROR_500 = APIError(
|
||||||
500,
|
500,
|
||||||
Mock(
|
Mock(
|
||||||
__class__=httpx.Response,
|
__class__=httpx.Response,
|
||||||
@ -17,6 +17,18 @@ CLIENT_ERROR_500 = ClientError(
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
CLIENT_ERROR_BAD_REQUEST = ClientError(
|
||||||
|
400,
|
||||||
|
Mock(
|
||||||
|
__class__=httpx.Response,
|
||||||
|
json=Mock(
|
||||||
|
return_value={
|
||||||
|
"message": "Bad Request",
|
||||||
|
"status": "invalid-argument",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
CLIENT_ERROR_API_KEY_INVALID = ClientError(
|
CLIENT_ERROR_API_KEY_INVALID = ClientError(
|
||||||
400,
|
400,
|
||||||
Mock(
|
Mock(
|
||||||
|
@ -1,100 +0,0 @@
|
|||||||
# serializer version: 1
|
|
||||||
# name: test_function_call
|
|
||||||
list([
|
|
||||||
tuple(
|
|
||||||
'',
|
|
||||||
tuple(
|
|
||||||
),
|
|
||||||
dict({
|
|
||||||
'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'param1': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description='Test parameters', enum=None, format=None, items=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>), max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.ARRAY: 'ARRAY'>), 'param2': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=None), 'param3': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'json': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
|
|
||||||
'history': list([
|
|
||||||
]),
|
|
||||||
'model': 'models/gemini-2.0-flash',
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
tuple(
|
|
||||||
'().send_message',
|
|
||||||
tuple(
|
|
||||||
),
|
|
||||||
dict({
|
|
||||||
'message': 'Please call the test function',
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
tuple(
|
|
||||||
'().send_message',
|
|
||||||
tuple(
|
|
||||||
),
|
|
||||||
dict({
|
|
||||||
'message': list([
|
|
||||||
Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None),
|
|
||||||
]),
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
])
|
|
||||||
# ---
|
|
||||||
# name: test_function_call_without_parameters
|
|
||||||
list([
|
|
||||||
tuple(
|
|
||||||
'',
|
|
||||||
tuple(
|
|
||||||
),
|
|
||||||
dict({
|
|
||||||
'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=None)], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
|
|
||||||
'history': list([
|
|
||||||
]),
|
|
||||||
'model': 'models/gemini-2.0-flash',
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
tuple(
|
|
||||||
'().send_message',
|
|
||||||
tuple(
|
|
||||||
),
|
|
||||||
dict({
|
|
||||||
'message': 'Please call the test function',
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
tuple(
|
|
||||||
'().send_message',
|
|
||||||
tuple(
|
|
||||||
),
|
|
||||||
dict({
|
|
||||||
'message': list([
|
|
||||||
Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None),
|
|
||||||
]),
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
])
|
|
||||||
# ---
|
|
||||||
# name: test_use_google_search
|
|
||||||
list([
|
|
||||||
tuple(
|
|
||||||
'',
|
|
||||||
tuple(
|
|
||||||
),
|
|
||||||
dict({
|
|
||||||
'config': GenerateContentConfig(http_options=None, system_instruction="You are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.\nCurrent time is 05:00:00. Today's date is 2024-05-24.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=1500, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'param1': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description='Test parameters', enum=None, format=None, items=Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>), max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.ARRAY: 'ARRAY'>), 'param2': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=None), 'param3': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties={'json': Schema(example=None, pattern=None, default=None, max_length=None, title=None, min_length=None, min_properties=None, max_properties=None, any_of=None, description=None, enum=None, format=None, items=None, max_items=None, maximum=None, min_items=None, minimum=None, nullable=None, properties=None, property_ordering=None, required=None, type=<Type.STRING: 'STRING'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>)}, property_ordering=None, required=[], type=<Type.OBJECT: 'OBJECT'>))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None), Tool(function_declarations=None, retrieval=None, google_search=GoogleSearch(), google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
|
|
||||||
'history': list([
|
|
||||||
]),
|
|
||||||
'model': 'models/gemini-2.0-flash',
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
tuple(
|
|
||||||
'().send_message',
|
|
||||||
tuple(
|
|
||||||
),
|
|
||||||
dict({
|
|
||||||
'message': 'Please call the test function',
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
tuple(
|
|
||||||
'().send_message',
|
|
||||||
tuple(
|
|
||||||
),
|
|
||||||
dict({
|
|
||||||
'message': list([
|
|
||||||
Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None),
|
|
||||||
]),
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
])
|
|
||||||
# ---
|
|
@ -34,7 +34,7 @@ from homeassistant.const import CONF_LLM_HASS_API
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.data_entry_flow import FlowResultType
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
|
|
||||||
from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
|
from . import API_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
@ -339,7 +339,7 @@ async def test_options_switching(
|
|||||||
("side_effect", "error"),
|
("side_effect", "error"),
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
CLIENT_ERROR_500,
|
API_ERROR_500,
|
||||||
"cannot_connect",
|
"cannot_connect",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
"""Tests for the Google Generative AI Conversation integration conversation platform."""
|
"""Tests for the Google Generative AI Conversation integration conversation platform."""
|
||||||
|
|
||||||
from typing import Any
|
from collections.abc import Generator
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
from google.genai.types import FunctionCall
|
from google.genai.types import GenerateContentResponse
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
|
||||||
import voluptuous as vol
|
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.conversation import UserContent, async_get_chat_log, trace
|
from homeassistant.components.conversation import UserContent
|
||||||
from homeassistant.components.google_generative_ai_conversation.conversation import (
|
from homeassistant.components.google_generative_ai_conversation.conversation import (
|
||||||
ERROR_GETTING_RESPONSE,
|
ERROR_GETTING_RESPONSE,
|
||||||
_escape_decode,
|
_escape_decode,
|
||||||
@ -18,12 +16,15 @@ from homeassistant.components.google_generative_ai_conversation.conversation imp
|
|||||||
)
|
)
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.helpers import intent
|
||||||
from homeassistant.helpers import chat_session, intent, llm
|
|
||||||
|
|
||||||
from . import CLIENT_ERROR_500
|
from . import API_ERROR_500, CLIENT_ERROR_BAD_REQUEST
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
from tests.components.conversation import (
|
||||||
|
MockChatLog,
|
||||||
|
mock_chat_log, # noqa: F401
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@ -40,396 +41,44 @@ def mock_ulid_tools():
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@pytest.fixture
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
def mock_send_message_stream() -> Generator[AsyncMock]:
|
||||||
|
"""Mock stream response."""
|
||||||
|
|
||||||
|
async def mock_generator(stream):
|
||||||
|
for value in stream:
|
||||||
|
yield value
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"google.genai.chats.AsyncChat.send_message_stream",
|
||||||
|
AsyncMock(),
|
||||||
|
) as mock_send_message_stream:
|
||||||
|
mock_send_message_stream.side_effect = lambda **kwargs: mock_generator(
|
||||||
|
mock_send_message_stream.return_value.pop(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield mock_send_message_stream
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("error"),
|
||||||
|
[
|
||||||
|
(API_ERROR_500,),
|
||||||
|
(CLIENT_ERROR_BAD_REQUEST,),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
|
||||||
@pytest.mark.usefixtures("mock_ulid_tools")
|
|
||||||
async def test_function_call(
|
|
||||||
mock_get_tools,
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_config_entry_with_assist: MockConfigEntry,
|
|
||||||
snapshot: SnapshotAssertion,
|
|
||||||
) -> None:
|
|
||||||
"""Test function calling."""
|
|
||||||
agent_id = "conversation.google_generative_ai_conversation"
|
|
||||||
context = Context()
|
|
||||||
|
|
||||||
mock_tool = AsyncMock()
|
|
||||||
mock_tool.name = "test_tool"
|
|
||||||
mock_tool.description = "Test function"
|
|
||||||
mock_tool.parameters = vol.Schema(
|
|
||||||
{
|
|
||||||
vol.Optional("param1", description="Test parameters"): [
|
|
||||||
vol.All(str, vol.Lower)
|
|
||||||
],
|
|
||||||
vol.Optional("param2"): vol.Any(float, int),
|
|
||||||
vol.Optional("param3"): dict,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_get_tools.return_value = [mock_tool]
|
|
||||||
|
|
||||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
|
||||||
mock_chat = AsyncMock()
|
|
||||||
mock_create.return_value.send_message = mock_chat
|
|
||||||
chat_response = Mock(prompt_feedback=None)
|
|
||||||
mock_chat.return_value = chat_response
|
|
||||||
mock_part = Mock()
|
|
||||||
mock_part.text = ""
|
|
||||||
mock_part.function_call = FunctionCall(
|
|
||||||
name="test_tool",
|
|
||||||
args={
|
|
||||||
"param1": ["test_value", "param1\\'s value"],
|
|
||||||
"param2": 2.7,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def tool_call(
|
|
||||||
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
mock_part.function_call = None
|
|
||||||
mock_part.text = "Hi there!"
|
|
||||||
return {"result": "Test response"}
|
|
||||||
|
|
||||||
mock_tool.async_call.side_effect = tool_call
|
|
||||||
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
"Please call the test function",
|
|
||||||
None,
|
|
||||||
context,
|
|
||||||
agent_id=agent_id,
|
|
||||||
device_id="test_device",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
|
||||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
|
|
||||||
mock_tool_response_parts = mock_create.mock_calls[2][2]["message"]
|
|
||||||
assert len(mock_tool_response_parts) == 1
|
|
||||||
assert mock_tool_response_parts[0].model_dump() == {
|
|
||||||
"code_execution_result": None,
|
|
||||||
"executable_code": None,
|
|
||||||
"file_data": None,
|
|
||||||
"function_call": None,
|
|
||||||
"function_response": {
|
|
||||||
"id": None,
|
|
||||||
"name": "test_tool",
|
|
||||||
"response": {
|
|
||||||
"result": "Test response",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"inline_data": None,
|
|
||||||
"text": None,
|
|
||||||
"thought": None,
|
|
||||||
"video_metadata": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_tool.async_call.assert_awaited_once_with(
|
|
||||||
hass,
|
|
||||||
llm.ToolInput(
|
|
||||||
id="mock-tool-call",
|
|
||||||
tool_name="test_tool",
|
|
||||||
tool_args={
|
|
||||||
"param1": ["test_value", "param1's value"],
|
|
||||||
"param2": 2.7,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
llm.LLMContext(
|
|
||||||
platform="google_generative_ai_conversation",
|
|
||||||
context=context,
|
|
||||||
user_prompt="Please call the test function",
|
|
||||||
language="en",
|
|
||||||
assistant="conversation",
|
|
||||||
device_id="test_device",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
|
|
||||||
|
|
||||||
# Test conversating tracing
|
|
||||||
traces = trace.async_get_traces()
|
|
||||||
assert traces
|
|
||||||
last_trace = traces[-1].as_dict()
|
|
||||||
trace_events = last_trace.get("events", [])
|
|
||||||
assert [event["event_type"] for event in trace_events] == [
|
|
||||||
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
|
||||||
trace.ConversationTraceEventType.AGENT_DETAIL, # prompt and tools
|
|
||||||
trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response
|
|
||||||
trace.ConversationTraceEventType.TOOL_CALL,
|
|
||||||
trace.ConversationTraceEventType.AGENT_DETAIL, # stats for response
|
|
||||||
]
|
|
||||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
|
||||||
detail_event = trace_events[1]
|
|
||||||
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
|
|
||||||
assert [
|
|
||||||
p["tool_name"] for p in detail_event["data"]["messages"][2]["tool_calls"]
|
|
||||||
] == ["test_tool"]
|
|
||||||
|
|
||||||
detail_event = trace_events[2]
|
|
||||||
assert set(detail_event["data"]["stats"].keys()) == {
|
|
||||||
"input_tokens",
|
|
||||||
"cached_input_tokens",
|
|
||||||
"output_tokens",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
|
||||||
)
|
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
|
||||||
@pytest.mark.usefixtures("mock_ulid_tools")
|
|
||||||
async def test_use_google_search(
|
|
||||||
mock_get_tools,
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_config_entry_with_google_search: MockConfigEntry,
|
|
||||||
snapshot: SnapshotAssertion,
|
|
||||||
) -> None:
|
|
||||||
"""Test function calling."""
|
|
||||||
agent_id = "conversation.google_generative_ai_conversation"
|
|
||||||
context = Context()
|
|
||||||
|
|
||||||
mock_tool = AsyncMock()
|
|
||||||
mock_tool.name = "test_tool"
|
|
||||||
mock_tool.description = "Test function"
|
|
||||||
mock_tool.parameters = vol.Schema(
|
|
||||||
{
|
|
||||||
vol.Optional("param1", description="Test parameters"): [
|
|
||||||
vol.All(str, vol.Lower)
|
|
||||||
],
|
|
||||||
vol.Optional("param2"): vol.Any(float, int),
|
|
||||||
vol.Optional("param3"): dict,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_get_tools.return_value = [mock_tool]
|
|
||||||
|
|
||||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
|
||||||
mock_chat = AsyncMock()
|
|
||||||
mock_create.return_value.send_message = mock_chat
|
|
||||||
chat_response = Mock(prompt_feedback=None)
|
|
||||||
mock_chat.return_value = chat_response
|
|
||||||
mock_part = Mock()
|
|
||||||
mock_part.text = ""
|
|
||||||
mock_part.function_call = FunctionCall(
|
|
||||||
name="test_tool",
|
|
||||||
args={
|
|
||||||
"param1": ["test_value", "param1\\'s value"],
|
|
||||||
"param2": 2.7,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def tool_call(
|
|
||||||
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
mock_part.function_call = None
|
|
||||||
mock_part.text = "Hi there!"
|
|
||||||
return {"result": "Test response"}
|
|
||||||
|
|
||||||
mock_tool.async_call.side_effect = tool_call
|
|
||||||
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
|
|
||||||
await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
"Please call the test function",
|
|
||||||
None,
|
|
||||||
context,
|
|
||||||
agent_id=agent_id,
|
|
||||||
device_id="test_device",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
|
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
|
||||||
)
|
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
|
||||||
async def test_function_call_without_parameters(
|
|
||||||
mock_get_tools,
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_config_entry_with_assist: MockConfigEntry,
|
|
||||||
snapshot: SnapshotAssertion,
|
|
||||||
) -> None:
|
|
||||||
"""Test function calling without parameters."""
|
|
||||||
agent_id = "conversation.google_generative_ai_conversation"
|
|
||||||
context = Context()
|
|
||||||
|
|
||||||
mock_tool = AsyncMock()
|
|
||||||
mock_tool.name = "test_tool"
|
|
||||||
mock_tool.description = "Test function"
|
|
||||||
mock_tool.parameters = vol.Schema({})
|
|
||||||
|
|
||||||
mock_get_tools.return_value = [mock_tool]
|
|
||||||
|
|
||||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
|
||||||
mock_chat = AsyncMock()
|
|
||||||
mock_create.return_value.send_message = mock_chat
|
|
||||||
chat_response = Mock(prompt_feedback=None)
|
|
||||||
mock_chat.return_value = chat_response
|
|
||||||
mock_part = Mock()
|
|
||||||
mock_part.text = ""
|
|
||||||
mock_part.function_call = FunctionCall(name="test_tool", args={})
|
|
||||||
|
|
||||||
def tool_call(
|
|
||||||
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
mock_part.function_call = None
|
|
||||||
mock_part.text = "Hi there!"
|
|
||||||
return {"result": "Test response"}
|
|
||||||
|
|
||||||
mock_tool.async_call.side_effect = tool_call
|
|
||||||
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
"Please call the test function",
|
|
||||||
None,
|
|
||||||
context,
|
|
||||||
agent_id=agent_id,
|
|
||||||
device_id="test_device",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
|
||||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
|
|
||||||
mock_tool_response_parts = mock_create.mock_calls[2][2]["message"]
|
|
||||||
assert len(mock_tool_response_parts) == 1
|
|
||||||
assert mock_tool_response_parts[0].model_dump() == {
|
|
||||||
"code_execution_result": None,
|
|
||||||
"executable_code": None,
|
|
||||||
"file_data": None,
|
|
||||||
"function_call": None,
|
|
||||||
"function_response": {
|
|
||||||
"id": None,
|
|
||||||
"name": "test_tool",
|
|
||||||
"response": {
|
|
||||||
"result": "Test response",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"inline_data": None,
|
|
||||||
"text": None,
|
|
||||||
"thought": None,
|
|
||||||
"video_metadata": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_tool.async_call.assert_awaited_once_with(
|
|
||||||
hass,
|
|
||||||
llm.ToolInput(
|
|
||||||
id="mock-tool-call",
|
|
||||||
tool_name="test_tool",
|
|
||||||
tool_args={},
|
|
||||||
),
|
|
||||||
llm.LLMContext(
|
|
||||||
platform="google_generative_ai_conversation",
|
|
||||||
context=context,
|
|
||||||
user_prompt="Please call the test function",
|
|
||||||
language="en",
|
|
||||||
assistant="conversation",
|
|
||||||
device_id="test_device",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
|
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
|
||||||
)
|
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
|
||||||
async def test_function_exception(
|
|
||||||
mock_get_tools,
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_config_entry_with_assist: MockConfigEntry,
|
|
||||||
) -> None:
|
|
||||||
"""Test exception in function calling."""
|
|
||||||
agent_id = "conversation.google_generative_ai_conversation"
|
|
||||||
context = Context()
|
|
||||||
|
|
||||||
mock_tool = AsyncMock()
|
|
||||||
mock_tool.name = "test_tool"
|
|
||||||
mock_tool.description = "Test function"
|
|
||||||
mock_tool.parameters = vol.Schema(
|
|
||||||
{
|
|
||||||
vol.Optional("param1", description="Test parameters"): vol.All(
|
|
||||||
vol.Coerce(int), vol.Range(0, 100)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_get_tools.return_value = [mock_tool]
|
|
||||||
|
|
||||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
|
||||||
mock_chat = AsyncMock()
|
|
||||||
mock_create.return_value.send_message = mock_chat
|
|
||||||
chat_response = Mock(prompt_feedback=None)
|
|
||||||
mock_chat.return_value = chat_response
|
|
||||||
mock_part = Mock()
|
|
||||||
mock_part.text = ""
|
|
||||||
mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
|
|
||||||
|
|
||||||
def tool_call(
|
|
||||||
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
mock_part.function_call = None
|
|
||||||
mock_part.text = "Hi there!"
|
|
||||||
raise HomeAssistantError("Test tool exception")
|
|
||||||
|
|
||||||
mock_tool.async_call.side_effect = tool_call
|
|
||||||
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
"Please call the test function",
|
|
||||||
None,
|
|
||||||
context,
|
|
||||||
agent_id=agent_id,
|
|
||||||
device_id="test_device",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
|
||||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
|
|
||||||
mock_tool_response_parts = mock_create.mock_calls[2][2]["message"]
|
|
||||||
assert len(mock_tool_response_parts) == 1
|
|
||||||
assert mock_tool_response_parts[0].model_dump() == {
|
|
||||||
"code_execution_result": None,
|
|
||||||
"executable_code": None,
|
|
||||||
"file_data": None,
|
|
||||||
"function_call": None,
|
|
||||||
"function_response": {
|
|
||||||
"id": None,
|
|
||||||
"name": "test_tool",
|
|
||||||
"response": {
|
|
||||||
"error": "HomeAssistantError",
|
|
||||||
"error_text": "Test tool exception",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"inline_data": None,
|
|
||||||
"text": None,
|
|
||||||
"thought": None,
|
|
||||||
"video_metadata": None,
|
|
||||||
}
|
|
||||||
mock_tool.async_call.assert_awaited_once_with(
|
|
||||||
hass,
|
|
||||||
llm.ToolInput(
|
|
||||||
id="mock-tool-call",
|
|
||||||
tool_name="test_tool",
|
|
||||||
tool_args={"param1": 1},
|
|
||||||
),
|
|
||||||
llm.LLMContext(
|
|
||||||
platform="google_generative_ai_conversation",
|
|
||||||
context=context,
|
|
||||||
user_prompt="Please call the test function",
|
|
||||||
language="en",
|
|
||||||
assistant="conversation",
|
|
||||||
device_id="test_device",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
|
||||||
async def test_error_handling(
|
async def test_error_handling(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
error,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that client errors are caught."""
|
"""Test that client errors are caught."""
|
||||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
with patch(
|
||||||
mock_chat = AsyncMock()
|
"google.genai.chats.AsyncChat.send_message_stream",
|
||||||
mock_create.return_value.send_message = mock_chat
|
new_callable=AsyncMock,
|
||||||
mock_chat.side_effect = CLIENT_ERROR_500
|
side_effect=error,
|
||||||
|
):
|
||||||
result = await conversation.async_converse(
|
result = await conversation.async_converse(
|
||||||
hass,
|
hass,
|
||||||
"hello",
|
"hello",
|
||||||
@ -437,32 +86,251 @@ async def test_error_handling(
|
|||||||
Context(),
|
Context(),
|
||||||
agent_id="conversation.google_generative_ai_conversation",
|
agent_id="conversation.google_generative_ai_conversation",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
assert (
|
||||||
"Sorry, I had a problem talking to Google Generative AI: 500 internal-error. {'message': 'Internal Server Error', 'status': 'internal-error'}"
|
result.response.as_dict()["speech"]["plain"]["speech"] == ERROR_GETTING_RESPONSE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
@pytest.mark.usefixtures("mock_ulid_tools")
|
||||||
|
async def test_function_call(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
|
mock_send_message_stream: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test function calling."""
|
||||||
|
agent_id = "conversation.google_generative_ai_conversation"
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
# Function call stream
|
||||||
|
[
|
||||||
|
GenerateContentResponse(
|
||||||
|
candidates=[
|
||||||
|
{
|
||||||
|
"content": {
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"text": "Hi there!",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"role": "model",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
),
|
||||||
|
GenerateContentResponse(
|
||||||
|
candidates=[
|
||||||
|
{
|
||||||
|
"content": {
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"function_call": {
|
||||||
|
"name": "test_tool",
|
||||||
|
"args": {
|
||||||
|
"param1": [
|
||||||
|
"test_value",
|
||||||
|
"param1\\'s value",
|
||||||
|
],
|
||||||
|
"param2": 2.7,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
"finish_reason": "STOP",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
# Messages after function response is sent
|
||||||
|
[
|
||||||
|
GenerateContentResponse(
|
||||||
|
candidates=[
|
||||||
|
{
|
||||||
|
"content": {
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"text": "I've called the ",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
GenerateContentResponse(
|
||||||
|
candidates=[
|
||||||
|
{
|
||||||
|
"content": {
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"text": "test function with the provided parameters.",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
"finish_reason": "STOP",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_send_message_stream.return_value = messages
|
||||||
|
|
||||||
|
mock_chat_log.mock_tool_results(
|
||||||
|
{
|
||||||
|
"mock-tool-call": {"result": "Test response"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Please call the test function",
|
||||||
|
mock_chat_log.conversation_id,
|
||||||
|
context,
|
||||||
|
agent_id=agent_id,
|
||||||
|
device_id="test_device",
|
||||||
|
)
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert (
|
||||||
|
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||||
|
== "I've called the test function with the provided parameters."
|
||||||
|
)
|
||||||
|
mock_tool_response_parts = mock_send_message_stream.mock_calls[1][2]["message"]
|
||||||
|
assert len(mock_tool_response_parts) == 1
|
||||||
|
assert mock_tool_response_parts[0].model_dump() == {
|
||||||
|
"code_execution_result": None,
|
||||||
|
"executable_code": None,
|
||||||
|
"file_data": None,
|
||||||
|
"function_call": None,
|
||||||
|
"function_response": {
|
||||||
|
"id": None,
|
||||||
|
"name": "test_tool",
|
||||||
|
"response": {
|
||||||
|
"result": "Test response",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"inline_data": None,
|
||||||
|
"text": None,
|
||||||
|
"thought": None,
|
||||||
|
"video_metadata": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
@pytest.mark.usefixtures("mock_ulid_tools")
|
||||||
|
async def test_google_search_tool_is_sent(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry_with_google_search: MockConfigEntry,
|
||||||
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
|
mock_send_message_stream: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test if the Google Search tool is sent to the model."""
|
||||||
|
agent_id = "conversation.google_generative_ai_conversation"
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
# Messages from the model which contain the google search answer (the usage of the Google Search tool is server side)
|
||||||
|
[
|
||||||
|
GenerateContentResponse(
|
||||||
|
candidates=[
|
||||||
|
{
|
||||||
|
"content": {
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"text": "The last winner ",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
GenerateContentResponse(
|
||||||
|
candidates=[
|
||||||
|
{
|
||||||
|
"content": {
|
||||||
|
"parts": [
|
||||||
|
{"text": "of the 2024 FIFA World Cup was Argentina."}
|
||||||
|
],
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
"finish_reason": "STOP",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_send_message_stream.return_value = messages
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"google.genai.chats.AsyncChats.create", return_value=AsyncMock()
|
||||||
|
) as mock_create:
|
||||||
|
mock_create.return_value.send_message_stream = mock_send_message_stream
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Who won the 2024 FIFA World Cup?",
|
||||||
|
mock_chat_log.conversation_id,
|
||||||
|
context,
|
||||||
|
agent_id=agent_id,
|
||||||
|
device_id="test_device",
|
||||||
|
)
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert (
|
||||||
|
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||||
|
== "The last winner of the 2024 FIFA World Cup was Argentina."
|
||||||
|
)
|
||||||
|
assert mock_create.mock_calls[0][2]["config"].tools[-1].google_search is not None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
async def test_blocked_response(
|
async def test_blocked_response(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
|
mock_send_message_stream: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test blocked response."""
|
"""Test blocked response."""
|
||||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
agent_id = "conversation.google_generative_ai_conversation"
|
||||||
mock_chat = AsyncMock()
|
context = Context()
|
||||||
mock_create.return_value.send_message = mock_chat
|
|
||||||
chat_response = Mock(prompt_feedback=Mock(block_reason_message="SAFETY"))
|
|
||||||
mock_chat.return_value = chat_response
|
|
||||||
|
|
||||||
result = await conversation.async_converse(
|
messages = [
|
||||||
hass,
|
[
|
||||||
"hello",
|
GenerateContentResponse(
|
||||||
None,
|
candidates=[
|
||||||
Context(),
|
{
|
||||||
agent_id="conversation.google_generative_ai_conversation",
|
"content": {
|
||||||
)
|
"parts": [
|
||||||
|
{
|
||||||
|
"text": "I've called the ",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
GenerateContentResponse(prompt_feedback={"block_reason_message": "SAFETY"}),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_send_message_stream.return_value = messages
|
||||||
|
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Please call the test function",
|
||||||
|
mock_chat_log.conversation_id,
|
||||||
|
context,
|
||||||
|
agent_id=agent_id,
|
||||||
|
device_id="test_device",
|
||||||
|
)
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
@ -473,23 +341,41 @@ async def test_blocked_response(
|
|||||||
|
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
async def test_empty_response(
|
async def test_empty_response(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
|
mock_send_message_stream: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test empty response."""
|
"""Test empty response."""
|
||||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
|
||||||
mock_chat = AsyncMock()
|
|
||||||
mock_create.return_value.send_message = mock_chat
|
|
||||||
chat_response = Mock(prompt_feedback=None)
|
|
||||||
mock_chat.return_value = chat_response
|
|
||||||
chat_response.candidates = [Mock(content=Mock(parts=[]))]
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
"hello",
|
|
||||||
None,
|
|
||||||
Context(),
|
|
||||||
agent_id="conversation.google_generative_ai_conversation",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
agent_id = "conversation.google_generative_ai_conversation"
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
GenerateContentResponse(
|
||||||
|
candidates=[
|
||||||
|
{
|
||||||
|
"content": {
|
||||||
|
"parts": [],
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_send_message_stream.return_value = messages
|
||||||
|
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Hello",
|
||||||
|
mock_chat_log.conversation_id,
|
||||||
|
context,
|
||||||
|
agent_id=agent_id,
|
||||||
|
device_id="test_device",
|
||||||
|
)
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||||
@ -499,27 +385,36 @@ async def test_empty_response(
|
|||||||
|
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
async def test_none_response(
|
async def test_none_response(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
|
mock_send_message_stream: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test empty response."""
|
"""Test None response."""
|
||||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
agent_id = "conversation.google_generative_ai_conversation"
|
||||||
mock_chat = AsyncMock()
|
context = Context()
|
||||||
mock_create.return_value.send_message = mock_chat
|
|
||||||
chat_response = Mock(prompt_feedback=None)
|
messages = [
|
||||||
mock_chat.return_value = chat_response
|
[
|
||||||
chat_response.candidates = None
|
GenerateContentResponse(),
|
||||||
result = await conversation.async_converse(
|
],
|
||||||
hass,
|
]
|
||||||
"hello",
|
|
||||||
None,
|
mock_send_message_stream.return_value = messages
|
||||||
Context(),
|
|
||||||
agent_id="conversation.google_generative_ai_conversation",
|
result = await conversation.async_converse(
|
||||||
)
|
hass,
|
||||||
|
"Hello",
|
||||||
|
mock_chat_log.conversation_id,
|
||||||
|
context,
|
||||||
|
agent_id=agent_id,
|
||||||
|
device_id="test_device",
|
||||||
|
)
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||||
ERROR_GETTING_RESPONSE
|
"The message got blocked due to content violations, reason: unknown"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -712,69 +607,109 @@ async def test_format_schema(openapi, genai_schema) -> None:
|
|||||||
|
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
async def test_empty_content_in_chat_history(
|
async def test_empty_content_in_chat_history(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
|
mock_send_message_stream: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead."""
|
"""Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead."""
|
||||||
with (
|
agent_id = "conversation.google_generative_ai_conversation"
|
||||||
patch("google.genai.chats.AsyncChats.create") as mock_create,
|
context = Context()
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
|
||||||
async_get_chat_log(hass, session) as chat_log,
|
|
||||||
):
|
|
||||||
mock_chat = AsyncMock()
|
|
||||||
mock_create.return_value.send_message = mock_chat
|
|
||||||
|
|
||||||
# Chat preparation with two inputs, one being an empty string
|
messages = [
|
||||||
first_input = "First request"
|
[
|
||||||
second_input = ""
|
GenerateContentResponse(
|
||||||
chat_log.async_add_user_content(UserContent(first_input))
|
candidates=[
|
||||||
chat_log.async_add_user_content(UserContent(second_input))
|
{
|
||||||
|
"content": {
|
||||||
|
"parts": [{"text": "Hi there!"}],
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_send_message_stream.return_value = messages
|
||||||
|
|
||||||
|
# Chat preparation with two inputs, one being an empty string
|
||||||
|
first_input = "First request"
|
||||||
|
second_input = ""
|
||||||
|
mock_chat_log.async_add_user_content(UserContent(first_input))
|
||||||
|
mock_chat_log.async_add_user_content(UserContent(second_input))
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"google.genai.chats.AsyncChats.create", return_value=AsyncMock()
|
||||||
|
) as mock_create:
|
||||||
|
mock_create.return_value.send_message_stream = mock_send_message_stream
|
||||||
await conversation.async_converse(
|
await conversation.async_converse(
|
||||||
hass,
|
hass,
|
||||||
"Second request",
|
"Hello",
|
||||||
session.conversation_id,
|
mock_chat_log.conversation_id,
|
||||||
Context(),
|
context,
|
||||||
agent_id="conversation.google_generative_ai_conversation",
|
agent_id=agent_id,
|
||||||
|
device_id="test_device",
|
||||||
)
|
)
|
||||||
|
|
||||||
_, kwargs = mock_create.call_args
|
_, kwargs = mock_create.call_args
|
||||||
actual_history = kwargs.get("history")
|
actual_history = kwargs.get("history")
|
||||||
|
|
||||||
assert actual_history[0].parts[0].text == first_input
|
assert actual_history[0].parts[0].text == first_input
|
||||||
assert actual_history[1].parts[0].text == " "
|
assert actual_history[1].parts[0].text == " "
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("mock_init_component")
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
async def test_history_always_user_first_turn(
|
async def test_history_always_user_first_turn(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
snapshot: SnapshotAssertion,
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
|
mock_send_message_stream: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the user is always first in the chat history."""
|
"""Test that the user is always first in the chat history."""
|
||||||
with (
|
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
agent_id = "conversation.google_generative_ai_conversation"
|
||||||
async_get_chat_log(hass, session) as chat_log,
|
context = Context()
|
||||||
):
|
|
||||||
chat_log.async_add_assistant_content_without_tools(
|
messages = [
|
||||||
conversation.AssistantContent(
|
[
|
||||||
agent_id="conversation.google_generative_ai_conversation",
|
GenerateContentResponse(
|
||||||
content="Garage door left open, do you want to close it?",
|
candidates=[
|
||||||
)
|
{
|
||||||
|
"content": {
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"text": " Yes, I can help with that. ",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_send_message_stream.return_value = messages
|
||||||
|
|
||||||
|
mock_chat_log.async_add_assistant_content_without_tools(
|
||||||
|
conversation.AssistantContent(
|
||||||
|
agent_id="conversation.google_generative_ai_conversation",
|
||||||
|
content="Garage door left open, do you want to close it?",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
with patch(
|
||||||
mock_chat = AsyncMock()
|
"google.genai.chats.AsyncChats.create", return_value=AsyncMock()
|
||||||
mock_create.return_value.send_message = mock_chat
|
) as mock_create:
|
||||||
chat_response = Mock(prompt_feedback=None)
|
mock_create.return_value.send_message_stream = mock_send_message_stream
|
||||||
mock_chat.return_value = chat_response
|
|
||||||
chat_response.candidates = [Mock(content=Mock(parts=[]))]
|
|
||||||
|
|
||||||
await conversation.async_converse(
|
await conversation.async_converse(
|
||||||
hass,
|
hass,
|
||||||
"hello",
|
"Hello",
|
||||||
chat_log.conversation_id,
|
mock_chat_log.conversation_id,
|
||||||
Context(),
|
context,
|
||||||
agent_id="conversation.google_generative_ai_conversation",
|
agent_id=agent_id,
|
||||||
|
device_id="test_device",
|
||||||
)
|
)
|
||||||
|
|
||||||
_, kwargs = mock_create.call_args
|
_, kwargs = mock_create.call_args
|
||||||
|
@ -11,7 +11,7 @@ from homeassistant.config_entries import ConfigEntryState
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
|
from . import API_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
@ -212,7 +212,7 @@ async def test_generate_content_service_error(
|
|||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"google.genai.models.AsyncModels.generate_content",
|
"google.genai.models.AsyncModels.generate_content",
|
||||||
side_effect=CLIENT_ERROR_500,
|
side_effect=API_ERROR_500,
|
||||||
),
|
),
|
||||||
pytest.raises(
|
pytest.raises(
|
||||||
HomeAssistantError,
|
HomeAssistantError,
|
||||||
@ -311,7 +311,7 @@ async def test_generate_content_service_with_image_not_exists(
|
|||||||
("side_effect", "state", "reauth"),
|
("side_effect", "state", "reauth"),
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
CLIENT_ERROR_500,
|
API_ERROR_500,
|
||||||
ConfigEntryState.SETUP_ERROR,
|
ConfigEntryState.SETUP_ERROR,
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user