Swap the Gemini SDK for the newly released Unified SDK (#138246)

* Swapped the old GenAI client for the newly released one (see the call-pattern sketch after this list)

* Fixed the Generate Content action and config flow loading; general code cleanup

* Added a function to mask the issues with tools whose names start with Hass

* Fixed most tests

* google-genai==1.1.0

* Miscellaneous fixes

* Fixed the remaining tests

* Addressed review comments
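
For reviewers unfamiliar with the new SDK, the heart of the change is that a single genai.Client replaces module-level configuration (genai.configure) and per-model objects (genai.GenerativeModel), with async calls grouped under client.aio. A minimal sketch of the new call pattern, using only calls that appear in this diff (the API key and prompt are illustrative):

    from google import genai

    client = genai.Client(api_key="YOUR_API_KEY")  # illustrative key

    async def demo() -> None:
        # The model is a string parameter now, not a separate model object.
        response = await client.aio.models.generate_content(
            model="models/gemini-2.0-flash",
            contents=["Write a one-line greeting"],
        )
        print(response.text)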

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: tronikos <tronikos@users.noreply.github.com>
Authored by Ivan Lopez Hernandez, 2025-02-21 22:41:05 -08:00; committed via GitHub
parent baa3b15dbc
commit 3160b7baa0
14 changed files with 513 additions and 508 deletions

View File

@ -5,11 +5,10 @@ from __future__ import annotations
import mimetypes
from pathlib import Path
from google.ai import generativelanguage_v1beta
from google.api_core.client_options import ClientOptions
from google.api_core.exceptions import ClientError, DeadlineExceeded, GoogleAPIError
import google.generativeai as genai
import google.generativeai.types as genai_types
from google import genai # type: ignore[attr-defined]
from google.genai.errors import APIError, ClientError
from PIL import Image
from requests.exceptions import Timeout
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
@ -29,7 +28,13 @@ from homeassistant.exceptions import (
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType
from .const import CONF_CHAT_MODEL, CONF_PROMPT, DOMAIN, RECOMMENDED_CHAT_MODEL
from .const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
TIMEOUT_MILLIS,
)
SERVICE_GENERATE_CONTENT = "generate_content"
CONF_IMAGE_FILENAME = "image_filename"
@ -37,6 +42,8 @@ CONF_IMAGE_FILENAME = "image_filename"
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,)
type GoogleGenerativeAIConfigEntry = ConfigEntry[genai.Client]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Google Generative AI Conversation."""
@ -44,42 +51,47 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def generate_content(call: ServiceCall) -> ServiceResponse:
"""Generate content from text and optionally images."""
prompt_parts = [call.data[CONF_PROMPT]]
image_filenames = call.data[CONF_IMAGE_FILENAME]
for image_filename in image_filenames:
if not hass.config.is_allowed_path(image_filename):
raise HomeAssistantError(
f"Cannot read `{image_filename}`, no access to path; "
"`allowlist_external_dirs` may need to be adjusted in "
"`configuration.yaml`"
)
if not Path(image_filename).exists():
raise HomeAssistantError(f"`{image_filename}` does not exist")
mime_type, _ = mimetypes.guess_type(image_filename)
if mime_type is None or not mime_type.startswith("image"):
raise HomeAssistantError(f"`{image_filename}` is not an image")
prompt_parts.append(
{
"mime_type": mime_type,
"data": await hass.async_add_executor_job(
Path(image_filename).read_bytes
),
}
)
model = genai.GenerativeModel(model_name=RECOMMENDED_CHAT_MODEL)
def append_images_to_prompt():
image_filenames = call.data[CONF_IMAGE_FILENAME]
for image_filename in image_filenames:
if not hass.config.is_allowed_path(image_filename):
raise HomeAssistantError(
f"Cannot read `{image_filename}`, no access to path; "
"`allowlist_external_dirs` may need to be adjusted in "
"`configuration.yaml`"
)
if not Path(image_filename).exists():
raise HomeAssistantError(f"`{image_filename}` does not exist")
mime_type, _ = mimetypes.guess_type(image_filename)
if mime_type is None or not mime_type.startswith("image"):
raise HomeAssistantError(f"`{image_filename}` is not an image")
prompt_parts.append(Image.open(image_filename))
await hass.async_add_executor_job(append_images_to_prompt)
config_entry: GoogleGenerativeAIConfigEntry = hass.config_entries.async_entries(
DOMAIN
)[0]
client = config_entry.runtime_data
try:
response = await model.generate_content_async(prompt_parts)
response = await client.aio.models.generate_content(
model=RECOMMENDED_CHAT_MODEL, contents=prompt_parts
)
except (
GoogleAPIError,
APIError,
ValueError,
genai_types.BlockedPromptException,
genai_types.StopCandidateException,
) as err:
raise HomeAssistantError(f"Error generating content: {err}") from err
if not response.parts:
raise HomeAssistantError("Error generating content")
if response.prompt_feedback:
raise HomeAssistantError(
f"Error generating content due to content violations, reason: {response.prompt_feedback.block_reason_message}"
)
if not response.candidates[0].content.parts:
raise HomeAssistantError("Unknown error generating content")
return {"text": response.text}
@ -100,30 +112,34 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_setup_entry(
hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
) -> bool:
"""Set up Google Generative AI Conversation from a config entry."""
genai.configure(api_key=entry.data[CONF_API_KEY])
try:
client = generativelanguage_v1beta.ModelServiceAsyncClient(
client_options=ClientOptions(api_key=entry.data[CONF_API_KEY])
client = genai.Client(api_key=entry.data[CONF_API_KEY])
await client.aio.models.get(
model=entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
)
await client.get_model(
name=entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), timeout=5.0
)
except (GoogleAPIError, ValueError) as err:
if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID":
raise ConfigEntryAuthFailed(err) from err
if isinstance(err, DeadlineExceeded):
except (APIError, Timeout) as err:
if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err):
raise ConfigEntryAuthFailed(err.message) from err
if isinstance(err, Timeout):
raise ConfigEntryNotReady(err) from err
raise ConfigEntryError(err) from err
else:
entry.runtime_data = client
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(
hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
) -> bool:
"""Unload GoogleGenerativeAI."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
return False
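
Error handling during setup changes shape as well: the new SDK raises APIError/ClientError from google.genai.errors, and the old err.reason attribute is gone, so an invalid key is detected by matching the error text. A sketch of the pattern, assuming a client built as above:

    from google.genai.errors import ClientError

    async def check_entry(client, model: str) -> None:
        try:
            await client.aio.models.get(
                model=model,
                config={"http_options": {"timeout": 10000}},  # TIMEOUT_MILLIS
            )
        except ClientError as err:
            # No structured reason field anymore; match on the message.
            if "API_KEY_INVALID" in str(err):
                ...  # the integration raises ConfigEntryAuthFailed here
            raise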

View File

@ -3,15 +3,13 @@
from __future__ import annotations
from collections.abc import Mapping
from functools import partial
import logging
from types import MappingProxyType
from typing import Any
from google.ai import generativelanguage_v1beta
from google.api_core.client_options import ClientOptions
from google.api_core.exceptions import ClientError, GoogleAPIError
import google.generativeai as genai
from google import genai # type: ignore[attr-defined]
from google.genai.errors import APIError, ClientError
from requests.exceptions import Timeout
import voluptuous as vol
from homeassistant.config_entries import (
@ -53,6 +51,7 @@ from .const import (
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
TIMEOUT_MILLIS,
)
_LOGGER = logging.getLogger(__name__)
@ -70,15 +69,20 @@ RECOMMENDED_OPTIONS = {
}
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
async def validate_input(data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect.
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
"""
client = generativelanguage_v1beta.ModelServiceAsyncClient(
client_options=ClientOptions(api_key=data[CONF_API_KEY])
client = genai.Client(api_key=data[CONF_API_KEY])
await client.aio.models.list(
config={
"http_options": {
"timeout": TIMEOUT_MILLIS,
},
"query_base": True,
}
)
await client.list_models(timeout=5.0)
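
Note that validate_input no longer needs hass: the check is a single client call, with the request timeout passed via http_options and query_base restricting the listing (to base, non-tuned models, as far as I can tell; the flag comes straight from this diff). A standalone sketch:

    from google import genai

    async def validate_key(api_key: str) -> None:
        client = genai.Client(api_key=api_key)
        # An invalid key surfaces here as a ClientError.
        await client.aio.models.list(
            config={
                "http_options": {"timeout": 10000},  # TIMEOUT_MILLIS
                "query_base": True,
            }
        )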
class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
@ -93,9 +97,9 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {}
if user_input is not None:
try:
await validate_input(self.hass, user_input)
except GoogleAPIError as err:
if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID":
await validate_input(user_input)
except (APIError, Timeout) as err:
if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err):
errors["base"] = "invalid_auth"
else:
errors["base"] = "cannot_connect"
@ -166,6 +170,7 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
self.last_rendered_recommended = config_entry.options.get(
CONF_RECOMMENDED, False
)
self._genai_client = config_entry.runtime_data
async def async_step_init(
self, user_input: dict[str, Any] | None = None
@ -188,7 +193,9 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}
schema = await google_generative_ai_config_option_schema(self.hass, options)
schema = await google_generative_ai_config_option_schema(
self.hass, options, self._genai_client
)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
@ -198,6 +205,7 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
async def google_generative_ai_config_option_schema(
hass: HomeAssistant,
options: dict[str, Any] | MappingProxyType[str, Any],
genai_client: genai.Client,
) -> dict:
"""Return a schema for Google Generative AI completion options."""
hass_apis: list[SelectOptionDict] = [
@ -236,18 +244,21 @@ async def google_generative_ai_config_option_schema(
if options.get(CONF_RECOMMENDED):
return schema
api_models = await hass.async_add_executor_job(partial(genai.list_models))
api_models_pager = await genai_client.aio.models.list(config={"query_base": True})
api_models = [api_model async for api_model in api_models_pager]
models = [
SelectOptionDict(
label=api_model.display_name,
value=api_model.name,
)
for api_model in sorted(api_models, key=lambda x: x.display_name)
for api_model in sorted(api_models, key=lambda x: x.display_name or "")
if (
api_model.name != "models/gemini-1.0-pro" # duplicate of gemini-pro
and api_model.display_name
and api_model.name
and api_model.supported_actions
and "vision" not in api_model.name
and "generateContent" in api_model.supported_generation_methods
and "generateContent" in api_model.supported_actions
)
]
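
client.aio.models.list now returns an async pager rather than a plain iterable, so the options schema consumes it with an async comprehension and filters on the renamed supported_actions field (formerly supported_generation_methods). A sketch of that consumption, assuming a client as above:

    async def chat_model_names(client) -> list[str]:
        pager = await client.aio.models.list(config={"query_base": True})
        models = [model async for model in pager]  # pager supports async iteration
        return [
            m.name
            for m in models
            if m.name and m.supported_actions and "generateContent" in m.supported_actions
        ]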

View File

@ -22,3 +22,5 @@ CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"
TIMEOUT_MILLIS = 10000

View File

@ -6,11 +6,18 @@ import codecs
from collections.abc import Callable
from typing import Any, Literal, cast
from google.api_core.exceptions import GoogleAPIError
import google.generativeai as genai
from google.generativeai import protos
import google.generativeai.types as genai_types
from google.protobuf.json_format import MessageToDict
from google.genai.errors import APIError
from google.genai.types import (
AutomaticFunctionCallingConfig,
Content,
FunctionDeclaration,
GenerateContentConfig,
HarmCategory,
Part,
SafetySetting,
Schema,
Tool,
)
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
@ -57,21 +64,40 @@ async def async_setup_entry(
SUPPORTED_SCHEMA_KEYS = {
"type",
"format",
"description",
"min_items",
"example",
"property_ordering",
"pattern",
"minimum",
"default",
"any_of",
"max_length",
"title",
"min_properties",
"min_length",
"max_items",
"maximum",
"nullable",
"max_properties",
"type",
"description",
"enum",
"format",
"items",
"properties",
"required",
}
def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
"""Format the schema to protobuf."""
if (subschemas := schema.get("anyOf")) or (subschemas := schema.get("allOf")):
for subschema in subschemas: # Gemini API does not support anyOf and allOf keys
def _camel_to_snake(name: str) -> str:
"""Convert camel case to snake case."""
return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_")
def _format_schema(schema: dict[str, Any]) -> Schema:
"""Format the schema to be compatible with Gemini API."""
if subschemas := schema.get("allOf"):
for subschema in subschemas: # Gemini API does not support allOf keys
if "type" in subschema: # Fallback to first subschema with 'type' field
return _format_schema(subschema)
return _format_schema(
@ -80,42 +106,38 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
result = {}
for key, val in schema.items():
key = _camel_to_snake(key)
if key not in SUPPORTED_SCHEMA_KEYS:
continue
if key == "any_of":
val = [_format_schema(subschema) for subschema in val]
if key == "type":
key = "type_"
val = val.upper()
elif key == "format":
if schema.get("type") == "string" and val != "enum":
continue
if schema.get("type") not in ("number", "integer", "string"):
continue
key = "format_"
elif key == "items":
if key == "items":
val = _format_schema(val)
elif key == "properties":
val = {k: _format_schema(v) for k, v in val.items()}
result[key] = val
if result.get("enum") and result.get("type_") != "STRING":
if result.get("enum") and result.get("type") != "STRING":
# enum is only allowed for STRING type. This is safe as long as the schema
# contains vol.Coerce for the respective type, for example:
# vol.All(vol.Coerce(int), vol.In([1, 2, 3]))
result["type_"] = "STRING"
result["type"] = "STRING"
result["enum"] = [str(item) for item in result["enum"]]
if result.get("type_") == "OBJECT" and not result.get("properties"):
if result.get("type") == "OBJECT" and not result.get("properties"):
# An object with undefined properties is not supported by Gemini API.
# Fallback to JSON string. This will probably fail for most tools that want it,
# but we don't have a better fallback strategy so far.
result["properties"] = {"json": {"type_": "STRING"}}
result["properties"] = {"json": {"type": "STRING"}}
result["required"] = []
return result
return cast(Schema, result)
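
Two worked examples of _format_schema, taken from the updated test cases later in this PR:

    _format_schema({"type": "integer", "enum": [1, 2, 3]})
    # -> {"type": "STRING", "enum": ["1", "2", "3"]}  (enum is STRING-only)

    _format_schema({"type": "object", "additionalProperties": True})
    # -> {"type": "OBJECT", "properties": {"json": {"type": "STRING"}},
    #     "required": []}  (fallback for objects with undefined properties)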
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
) -> Tool:
"""Format tool specification."""
if tool.parameters.schema:
@ -125,16 +147,14 @@ def _format_tool(
else:
parameters = None
return protos.Tool(
{
"function_declarations": [
{
"name": tool.name,
"description": tool.description,
"parameters": parameters,
}
]
}
return Tool(
function_declarations=[
FunctionDeclaration(
name=tool.name,
description=tool.description,
parameters=parameters,
)
]
)
@ -151,14 +171,12 @@ def _escape_decode(value: Any) -> Any:
def _create_google_tool_response_content(
content: list[conversation.ToolResultContent],
) -> protos.Content:
) -> Content:
"""Create a Google tool response content."""
return protos.Content(
return Content(
parts=[
protos.Part(
function_response=protos.FunctionResponse(
name=tool_result.tool_name, response=tool_result.tool_result
)
Part.from_function_response(
name=tool_result.tool_name, response=tool_result.tool_result
)
for tool_result in content
]
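
Part.from_function_response replaces the hand-built protos.Part/FunctionResponse pair. A small sketch of what it produces, matching the model_dump() shapes asserted in the updated tests:

    from google.genai.types import Content, Part

    part = Part.from_function_response(
        name="test_tool", response={"result": "Test response"}
    )
    message = Content(parts=[part])
    # message.model_dump()["parts"][0]["function_response"] ==
    #   {"id": None, "name": "test_tool", "response": {"result": "Test response"}}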
@ -169,33 +187,36 @@ def _convert_content(
content: conversation.UserContent
| conversation.AssistantContent
| conversation.SystemContent,
) -> genai_types.ContentDict:
) -> Content:
"""Convert HA content to Google content."""
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
role = "model" if content.role == "assistant" else content.role
return {"role": role, "parts": content.content}
return Content(
role=role,
parts=[
Part.from_text(text=content.content if content.content else ""),
],
)
# Handle the Assistant content with tool calls.
assert type(content) is conversation.AssistantContent
parts = []
parts: list[Part] = []
if content.content:
parts.append(protos.Part(text=content.content))
parts.append(Part.from_text(text=content.content))
if content.tool_calls:
parts.extend(
[
protos.Part(
function_call=protos.FunctionCall(
name=tool_call.tool_name,
args=_escape_decode(tool_call.tool_args),
)
Part.from_function_call(
name=tool_call.tool_name,
args=_escape_decode(tool_call.tool_args),
)
for tool_call in content.tool_calls
]
)
return protos.Content({"role": "model", "parts": parts})
return Content(role="model", parts=parts)
class GoogleGenerativeAIConversationEntity(
@ -209,6 +230,7 @@ class GoogleGenerativeAIConversationEntity(
def __init__(self, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.entry = entry
self._genai_client = entry.runtime_data
self._attr_unique_id = entry.entry_id
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)},
@ -273,7 +295,7 @@ class GoogleGenerativeAIConversationEntity(
except conversation.ConverseError as err:
return err.as_conversation_result()
tools: list[dict[str, Any]] | None = None
tools: list[Tool | Callable[..., Any]] | None = None
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
@ -288,13 +310,22 @@ class GoogleGenerativeAIConversationEntity(
"gemini-1.0" not in model_name and "gemini-pro" not in model_name
)
prompt = chat_log.content[0].content # type: ignore[union-attr]
messages: list[genai_types.ContentDict] = []
prompt_content = cast(
conversation.SystemContent,
chat_log.content[0],
)
if prompt_content.content:
prompt = prompt_content.content
else:
raise HomeAssistantError("Invalid prompt content")
messages: list[Content] = []
# Google groups tool results, we do not. Group them before sending.
tool_results: list[conversation.ToolResultContent] = []
for chat_content in chat_log.content[1:]:
for chat_content in chat_log.content[1:-1]:
if chat_content.role == "tool_result":
# mypy doesn't like picking a type based on checking shared property 'role'
tool_results.append(cast(conversation.ToolResultContent, chat_content))
@ -317,85 +348,93 @@ class GoogleGenerativeAIConversationEntity(
if tool_results:
messages.append(_create_google_tool_response_content(tool_results))
model = genai.GenerativeModel(
model_name=model_name,
generation_config={
"temperature": self.entry.options.get(
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
generateContentConfig = GenerateContentConfig(
temperature=self.entry.options.get(
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
),
top_k=self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
top_p=self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
max_output_tokens=self.entry.options.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
safety_settings=[
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=self.entry.options.get(
CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
"top_p": self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"top_k": self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
"max_output_tokens": self.entry.options.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=self.entry.options.get(
CONF_HARASSMENT_BLOCK_THRESHOLD,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
),
},
safety_settings={
"HARASSMENT": self.entry.options.get(
CONF_HARASSMENT_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
SafetySetting(
category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=self.entry.options.get(
CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
"HATE": self.entry.options.get(
CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
SafetySetting(
category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=self.entry.options.get(
CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
"SEXUAL": self.entry.options.get(
CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
"DANGEROUS": self.entry.options.get(
CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
},
],
tools=tools or None,
system_instruction=prompt if supports_system_instruction else None,
automatic_function_calling=AutomaticFunctionCallingConfig(
disable=True, maximum_remote_calls=None
),
)
if not supports_system_instruction:
messages = [
{"role": "user", "parts": prompt},
{"role": "model", "parts": "Ok"},
Content(role="user", parts=[Part.from_text(text=prompt)]),
Content(role="model", parts=[Part.from_text(text="Ok")]),
*messages,
]
chat = model.start_chat(history=messages)
chat_request = user_input.text
chat = self._genai_client.aio.chats.create(
model=model_name, history=messages, config=generateContentConfig
)
chat_request: str | Content = user_input.text
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
chat_response = await chat.send_message_async(chat_request)
except (
GoogleAPIError,
ValueError,
genai_types.BlockedPromptException,
genai_types.StopCandidateException,
) as err:
LOGGER.error("Error sending message: %s %s", type(err), err)
chat_response = await chat.send_message(message=chat_request)
if isinstance(
err, genai_types.StopCandidateException
) and "finish_reason: SAFETY\n" in str(err):
error = "The message got blocked by your safety settings"
else:
error = (
f"Sorry, I had a problem talking to Google Generative AI: {err}"
if chat_response.prompt_feedback:
raise HomeAssistantError(
f"The message got blocked due to content violations, reason: {chat_response.prompt_feedback.block_reason_message}"
)
except (
APIError,
ValueError,
) as err:
LOGGER.error("Error sending message: %s %s", type(err), err)
error = f"Sorry, I had a problem talking to Google Generative AI: {err}"
raise HomeAssistantError(error) from err
LOGGER.debug("Response: %s", chat_response.parts)
if not chat_response.parts:
response_parts = chat_response.candidates[0].content.parts
if not response_parts:
raise HomeAssistantError(
"Sorry, I had a problem getting a response from Google Generative AI."
)
content = " ".join(
[part.text.strip() for part in chat_response.parts if part.text]
[part.text.strip() for part in response_parts if part.text]
)
tool_calls = []
for part in chat_response.parts:
for part in response_parts:
if not part.function_call:
continue
tool_call = MessageToDict(part.function_call._pb) # noqa: SLF001
tool_name = tool_call["name"]
tool_args = _escape_decode(tool_call["args"])
tool_call = part.function_call
tool_name = tool_call.name
tool_args = _escape_decode(tool_call.args)
tool_calls.append(
llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
)
@ -418,7 +457,7 @@ class GoogleGenerativeAIConversationEntity(
response = intent.IntentResponse(language=user_input.language)
response.async_set_speech(
" ".join([part.text.strip() for part in chat_response.parts if part.text])
" ".join([part.text.strip() for part in response_parts if part.text])
)
return conversation.ConversationResult(
response=response, conversation_id=chat_log.conversation_id
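
Putting the conversation pieces together: chats.create is synchronous and returns a chat whose send_message is awaited, and the response parts now live under candidates[0].content.parts (the old chat_response.parts shortcut is gone). A condensed sketch with assumed messages/config objects:

    async def converse(client, messages, config, text: str):
        chat = client.aio.chats.create(
            model="models/gemini-2.0-flash",
            history=messages,   # list[Content]
            config=config,      # GenerateContentConfig
        )
        response = await chat.send_message(message=text)
        if response.prompt_feedback:  # prompt was blocked by safety settings
            raise RuntimeError(response.prompt_feedback.block_reason_message)
        return response.candidates[0].content.parts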

View File

@ -8,5 +8,5 @@
"documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation",
"integration_type": "service",
"iot_class": "cloud_polling",
"requirements": ["google-generativeai==0.8.2"]
"requirements": ["google-genai==1.1.0"]
}

requirements_all.txt generated
View File

@ -1033,7 +1033,7 @@ google-cloud-speech==2.27.0
google-cloud-texttospeech==2.17.2
# homeassistant.components.google_generative_ai_conversation
google-generativeai==0.8.2
google-genai==1.1.0
# homeassistant.components.nest
google-nest-sdm==7.1.3

View File

@ -883,7 +883,7 @@ google-cloud-speech==2.27.0
google-cloud-texttospeech==2.17.2
# homeassistant.components.google_generative_ai_conversation
google-generativeai==0.8.2
google-genai==1.1.0
# homeassistant.components.nest
google-nest-sdm==7.1.3

View File

@ -1 +1,31 @@
"""Tests for the Google Generative AI Conversation integration."""
from unittest.mock import Mock
from google.genai.errors import ClientError
import requests
CLIENT_ERROR_500 = ClientError(
500,
Mock(
__class__=requests.Response,
json=Mock(
return_value={
"message": "Internal Server Error",
"status": "internal-error",
}
),
),
)
CLIENT_ERROR_API_KEY_INVALID = ClientError(
400,
Mock(
__class__=requests.Response,
json=Mock(
return_value={
"message": "'reason': API_KEY_INVALID",
"status": "unauthorized",
}
),
),
)
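
These two constants mimic the new SDK's ClientError, which wraps an HTTP status code and a requests.Response-like object; the tests feed them in as side effects, e.g.:

    from unittest.mock import patch

    with patch(
        "google.genai.models.AsyncModels.list",
        side_effect=CLIENT_ERROR_API_KEY_INVALID,
    ):
        ...  # the config flow maps this to errors["base"] = "invalid_auth"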

View File

@ -1,7 +1,6 @@
"""Tests helpers."""
from collections.abc import Generator
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
@ -15,14 +14,7 @@ from tests.common import MockConfigEntry
@pytest.fixture
def mock_genai() -> Generator[None]:
"""Mock the genai call in async_setup_entry."""
with patch("google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.get_model"):
yield
@pytest.fixture
def mock_config_entry(hass: HomeAssistant, mock_genai: None) -> MockConfigEntry:
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
"""Mock a config entry."""
entry = MockConfigEntry(
domain="google_generative_ai_conversation",
@ -31,18 +23,21 @@ def mock_config_entry(hass: HomeAssistant, mock_genai: None) -> MockConfigEntry:
"api_key": "bla",
},
)
entry.runtime_data = Mock()
entry.add_to_hass(hass)
return entry
@pytest.fixture
def mock_config_entry_with_assist(
async def mock_config_entry_with_assist(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry:
"""Mock a config entry with assist."""
hass.config_entries.async_update_entry(
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
)
with patch("google.genai.models.AsyncModels.get"):
hass.config_entries.async_update_entry(
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
)
await hass.async_block_till_done()
return mock_config_entry
@ -51,8 +46,11 @@ async def mock_init_component(
hass: HomeAssistant, mock_config_entry: ConfigEntry
) -> None:
"""Initialize integration."""
assert await async_setup_component(hass, "google_generative_ai_conversation", {})
await hass.async_block_till_done()
with patch("google.genai.models.AsyncModels.get"):
assert await async_setup_component(
hass, "google_generative_ai_conversation", {}
)
await hass.async_block_till_done()
@pytest.fixture(autouse=True)

View File

@ -6,106 +6,26 @@
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-2.0-flash',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
''',
'tools': list([
function_declarations {
name: "test_tool"
description: "Test function"
parameters {
type_: OBJECT
properties {
key: "param3"
value {
type_: OBJECT
properties {
key: "json"
value {
type_: STRING
}
}
}
}
properties {
key: "param2"
value {
type_: NUMBER
}
}
properties {
key: "param1"
value {
type_: ARRAY
description: "Test parameters"
items {
type_: STRING
}
}
}
}
}
,
]),
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'config': GenerateContentConfig(http_options=None, system_instruction="Current time is 05:00:00. Today's date is 2024-05-24.\nYou are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=150, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'param1': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.ARRAY: 'ARRAY'>, description='Test parameters', enum=None, format=None, items=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format='lower', items=None, properties=None, required=None), properties=None, required=None), 'param2': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=[Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.NUMBER: 'NUMBER'>, description=None, enum=None, format=None, items=None, properties=None, required=None), Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.INTEGER: 'INTEGER'>, description=None, enum=None, format=None, items=None, properties=None, required=None)], 
max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=None, description=None, enum=None, format=None, items=None, properties=None, required=None), 'param3': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'json': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format=None, items=None, properties=None, required=None)}, required=[])}, required=[]))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
'history': list([
dict({
'parts': 'Please call the test function',
'role': 'user',
}),
]),
'model': 'models/gemini-2.0-flash',
}),
),
tuple(
'().start_chat().send_message_async',
'().send_message',
tuple(
'Please call the test function',
),
dict({
'message': 'Please call the test function',
}),
),
tuple(
'().start_chat().send_message_async',
'().send_message',
tuple(
parts {
function_response {
name: "test_tool"
response {
fields {
key: "result"
value {
string_value: "Test response"
}
}
}
}
}
,
),
dict({
'message': Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None)], role=None),
}),
),
])
@ -117,75 +37,26 @@
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-2.0-flash',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
''',
'tools': list([
function_declarations {
name: "test_tool"
description: "Test function"
}
,
]),
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'config': GenerateContentConfig(http_options=None, system_instruction="Current time is 05:00:00. Today's date is 2024-05-24.\nYou are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=150, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=None)], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
'history': list([
dict({
'parts': 'Please call the test function',
'role': 'user',
}),
]),
'model': 'models/gemini-2.0-flash',
}),
),
tuple(
'().start_chat().send_message_async',
'().send_message',
tuple(
'Please call the test function',
),
dict({
'message': 'Please call the test function',
}),
),
tuple(
'().start_chat().send_message_async',
'().send_message',
tuple(
parts {
function_response {
name: "test_tool"
response {
fields {
key: "result"
value {
string_value: "Test response"
}
}
}
}
}
,
),
dict({
'message': Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None)], role=None),
}),
),
])

View File

@ -6,21 +6,11 @@
tuple(
),
dict({
'model_name': 'models/gemini-2.0-flash',
}),
),
tuple(
'().generate_content_async',
tuple(
list([
'contents': list([
'Describe this image from my doorbell camera',
dict({
'data': b'image bytes',
'mime_type': 'image/jpeg',
}),
b'image bytes',
]),
),
dict({
'model': 'models/gemini-2.0-flash',
}),
),
])
@ -32,17 +22,10 @@
tuple(
),
dict({
'model_name': 'models/gemini-2.0-flash',
}),
),
tuple(
'().generate_content_async',
tuple(
list([
'contents': list([
'Write an opening speech for a Home Assistant release party',
]),
),
dict({
'model': 'models/gemini-2.0-flash',
}),
),
])

View File

@ -1,10 +1,9 @@
"""Test the Google Generative AI Conversation config flow."""
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import Mock, patch
from google.api_core.exceptions import ClientError, DeadlineExceeded
from google.rpc.error_details_pb2 import ErrorInfo # pylint: disable=no-name-in-module
import pytest
from requests.exceptions import Timeout
from homeassistant import config_entries
from homeassistant.components.google_generative_ai_conversation.config_flow import (
@ -33,6 +32,8 @@ from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
from tests.common import MockConfigEntry
@ -41,30 +42,37 @@ def mock_models():
"""Mock the model list API."""
model_20_flash = Mock(
display_name="Gemini 2.0 Flash",
supported_generation_methods=["generateContent"],
supported_actions=["generateContent"],
)
model_20_flash.name = "models/gemini-2.0-flash"
model_15_flash = Mock(
display_name="Gemini 1.5 Flash",
supported_generation_methods=["generateContent"],
supported_actions=["generateContent"],
)
model_15_flash.name = "models/gemini-1.5-flash-latest"
model_15_pro = Mock(
display_name="Gemini 1.5 Pro",
supported_generation_methods=["generateContent"],
supported_actions=["generateContent"],
)
model_15_pro.name = "models/gemini-1.5-pro-latest"
model_10_pro = Mock(
display_name="Gemini 1.0 Pro",
supported_generation_methods=["generateContent"],
supported_actions=["generateContent"],
)
model_10_pro.name = "models/gemini-pro"
async def models_pager():
yield model_20_flash
yield model_15_flash
yield model_15_pro
yield model_10_pro
with patch(
"homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
return_value=iter([model_20_flash, model_15_flash, model_15_pro, model_10_pro]),
"google.genai.models.AsyncModels.list",
return_value=models_pager(),
):
yield
@ -86,7 +94,7 @@ async def test_form(hass: HomeAssistant) -> None:
with (
patch(
"google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.list_models",
"google.genai.models.AsyncModels.list",
),
patch(
"homeassistant.components.google_generative_ai_conversation.async_setup_entry",
@ -170,7 +178,11 @@ async def test_options_switching(
expected_options,
) -> None:
"""Test the options form."""
hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
with patch("google.genai.models.AsyncModels.get"):
hass.config_entries.async_update_entry(
mock_config_entry, options=current_options
)
await hass.async_block_till_done()
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
)
@ -195,17 +207,15 @@ async def test_options_switching(
("side_effect", "error"),
[
(
ClientError("some error"),
CLIENT_ERROR_500,
"cannot_connect",
),
(
DeadlineExceeded("deadline exceeded"),
Timeout("deadline exceeded"),
"cannot_connect",
),
(
ClientError(
"invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID")
),
CLIENT_ERROR_API_KEY_INVALID,
"invalid_auth",
),
(Exception, "unknown"),
@ -217,12 +227,7 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
mock_client = AsyncMock()
mock_client.list_models.side_effect = side_effect
with patch(
"google.ai.generativelanguage_v1beta.ModelServiceAsyncClient",
return_value=mock_client,
):
with patch("google.genai.models.AsyncModels.list", side_effect=side_effect):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
@ -259,7 +264,7 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
with (
patch(
"google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.list_models",
"google.genai.models.AsyncModels.list",
),
patch(
"homeassistant.components.google_generative_ai_conversation.async_setup_entry",

View File

@ -1,12 +1,10 @@
"""Tests for the Google Generative AI Conversation integration conversation platform."""
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, Mock, patch
from freezegun import freeze_time
from google.ai.generativelanguage_v1beta.types.content import FunctionCall
from google.api_core.exceptions import GoogleAPIError
import google.generativeai.types as genai_types
from google.genai.types import FunctionCall
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
@ -22,6 +20,8 @@ from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm
from . import CLIENT_ERROR_500
from tests.common import MockConfigEntry
@ -51,7 +51,7 @@ async def test_function_call(
snapshot: SnapshotAssertion,
) -> None:
"""Test function calling."""
agent_id = mock_config_entry_with_assist.entry_id
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
mock_tool = AsyncMock()
@ -69,12 +69,12 @@ async def test_function_call(
mock_get_tools.return_value = [mock_tool]
with patch("google.generativeai.GenerativeModel") as mock_model:
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=None)
mock_chat.return_value = chat_response
mock_part = Mock()
mock_part.text = ""
mock_part.function_call = FunctionCall(
name="test_tool",
@ -92,7 +92,7 @@ async def test_function_call(
return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call
chat_response.parts = [mock_part]
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
result = await conversation.async_converse(
hass,
"Please call the test function",
@ -104,20 +104,28 @@ async def test_function_call(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
assert mock_tool_call == {
mock_tool_call = mock_create.mock_calls[2][2]["message"]
assert mock_tool_call.model_dump() == {
"parts": [
{
"code_execution_result": None,
"executable_code": None,
"file_data": None,
"function_call": None,
"function_response": {
"id": None,
"name": "test_tool",
"response": {
"result": "Test response",
},
},
"inline_data": None,
"text": None,
"thought": None,
"video_metadata": None,
},
],
"role": "",
"role": None,
}
mock_tool.async_call.assert_awaited_once_with(
@ -139,7 +147,7 @@ async def test_function_call(
device_id="test_device",
),
)
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
# Test conversation tracing
traces = trace.async_get_traces()
@ -170,7 +178,7 @@ async def test_function_call_without_parameters(
snapshot: SnapshotAssertion,
) -> None:
"""Test function calling without parameters."""
agent_id = mock_config_entry_with_assist.entry_id
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
mock_tool = AsyncMock()
@ -180,12 +188,12 @@ async def test_function_call_without_parameters(
mock_get_tools.return_value = [mock_tool]
with patch("google.generativeai.GenerativeModel") as mock_model:
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=None)
mock_chat.return_value = chat_response
mock_part = Mock()
mock_part.text = ""
mock_part.function_call = FunctionCall(name="test_tool", args={})
@ -197,7 +205,7 @@ async def test_function_call_without_parameters(
return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call
chat_response.parts = [mock_part]
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
result = await conversation.async_converse(
hass,
"Please call the test function",
@ -209,20 +217,28 @@ async def test_function_call_without_parameters(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
assert mock_tool_call == {
mock_tool_call = mock_create.mock_calls[2][2]["message"]
assert mock_tool_call.model_dump() == {
"parts": [
{
"code_execution_result": None,
"executable_code": None,
"file_data": None,
"function_call": None,
"function_response": {
"id": None,
"name": "test_tool",
"response": {
"result": "Test response",
},
},
"inline_data": None,
"text": None,
"thought": None,
"video_metadata": None,
},
],
"role": "",
"role": None,
}
mock_tool.async_call.assert_awaited_once_with(
@ -241,7 +257,7 @@ async def test_function_call_without_parameters(
device_id="test_device",
),
)
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
@patch(
@ -254,7 +270,7 @@ async def test_function_exception(
mock_config_entry_with_assist: MockConfigEntry,
) -> None:
"""Test exception in function calling."""
agent_id = mock_config_entry_with_assist.entry_id
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
mock_tool = AsyncMock()
@ -270,12 +286,12 @@ async def test_function_exception(
mock_get_tools.return_value = [mock_tool]
with patch("google.generativeai.GenerativeModel") as mock_model:
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=None)
mock_chat.return_value = chat_response
mock_part = Mock()
mock_part.text = ""
mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
@ -287,7 +303,7 @@ async def test_function_exception(
raise HomeAssistantError("Test tool exception")
mock_tool.async_call.side_effect = tool_call
chat_response.parts = [mock_part]
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
result = await conversation.async_converse(
hass,
"Please call the test function",
@ -299,21 +315,29 @@ async def test_function_exception(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
assert mock_tool_call == {
mock_tool_call = mock_create.mock_calls[2][2]["message"]
assert mock_tool_call.model_dump() == {
"parts": [
{
"code_execution_result": None,
"executable_code": None,
"file_data": None,
"function_call": None,
"function_response": {
"id": None,
"name": "test_tool",
"response": {
"error": "HomeAssistantError",
"error_text": "Test tool exception",
},
},
"inline_data": None,
"text": None,
"thought": None,
"video_metadata": None,
},
],
"role": "",
"role": None,
}
mock_tool.async_call.assert_awaited_once_with(
hass,
@ -338,18 +362,22 @@ async def test_error_handling(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test that client errors are caught."""
with patch("google.generativeai.GenerativeModel") as mock_model:
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
mock_chat.send_message_async.side_effect = GoogleAPIError("some error")
mock_create.return_value.send_message = mock_chat
mock_chat.side_effect = CLIENT_ERROR_500
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
hass,
"hello",
None,
Context(),
agent_id="conversation.google_generative_ai_conversation",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Sorry, I had a problem talking to Google Generative AI: some error"
"Sorry, I had a problem talking to Google Generative AI: 500 internal-error. {'message': 'Internal Server Error', 'status': 'internal-error'}"
)
@ -358,20 +386,24 @@ async def test_blocked_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test blocked response."""
with patch("google.generativeai.GenerativeModel") as mock_model:
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
mock_chat.send_message_async.side_effect = genai_types.StopCandidateException(
"finish_reason: SAFETY\n"
)
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=Mock(block_reason_message="SAFETY"))
mock_chat.return_value = chat_response
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
hass,
"hello",
None,
Context(),
agent_id="conversation.google_generative_ai_conversation",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"The message got blocked by your safety settings"
"The message got blocked due to content violations, reason: SAFETY"
)
@ -380,14 +412,18 @@ async def test_empty_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test empty response."""
with patch("google.generativeai.GenerativeModel") as mock_model:
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
chat_response.parts = []
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=None)
mock_chat.return_value = chat_response
chat_response.candidates = [Mock(content=Mock(parts=[]))]
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
hass,
"hello",
None,
Context(),
agent_id="conversation.google_generative_ai_conversation",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
@ -402,17 +438,19 @@ async def test_converse_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test handling ChatLog raising ConverseError."""
hass.config_entries.async_update_entry(
mock_config_entry,
options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"},
)
with patch("google.genai.models.AsyncModels.get"):
hass.config_entries.async_update_entry(
mock_config_entry,
options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"},
)
await hass.async_block_till_done()
result = await conversation.async_converse(
hass,
"hello",
None,
Context(),
agent_id=mock_config_entry.entry_id,
agent_id="conversation.google_generative_ai_conversation",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
@ -449,31 +487,39 @@ async def test_escape_decode() -> None:
@pytest.mark.parametrize(
("openapi", "protobuf"),
("openapi", "genai_schema"),
[
(
{"type": "string", "enum": ["a", "b", "c"]},
{"type_": "STRING", "enum": ["a", "b", "c"]},
{"type": "STRING", "enum": ["a", "b", "c"]},
),
(
{"type": "integer", "enum": [1, 2, 3]},
{"type_": "STRING", "enum": ["1", "2", "3"]},
{"type": "STRING", "enum": ["1", "2", "3"]},
),
(
{"anyOf": [{"type": "integer"}, {"type": "number"}]},
{"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]},
),
({"anyOf": [{"type": "integer"}, {"type": "number"}]}, {"type_": "INTEGER"}),
(
{
"anyOf": [
{"anyOf": [{"type": "integer"}, {"type": "number"}]},
{"anyOf": [{"type": "integer"}, {"type": "number"}]},
"any_of": [
{"any_of": [{"type": "integer"}, {"type": "number"}]},
{"any_of": [{"type": "integer"}, {"type": "number"}]},
]
},
{
"any_of": [
{"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]},
{"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]},
]
},
{"type_": "INTEGER"},
),
({"type": "string", "format": "lower"}, {"type_": "STRING"}),
({"type": "boolean", "format": "bool"}, {"type_": "BOOLEAN"}),
({"type": "string", "format": "lower"}, {"format": "lower", "type": "STRING"}),
({"type": "boolean", "format": "bool"}, {"format": "bool", "type": "BOOLEAN"}),
(
{"type": "number", "format": "percent"},
{"type_": "NUMBER", "format_": "percent"},
{"type": "NUMBER", "format": "percent"},
),
(
{
@ -482,25 +528,25 @@ async def test_escape_decode() -> None:
"required": [],
},
{
"type_": "OBJECT",
"properties": {"var": {"type_": "STRING"}},
"type": "OBJECT",
"properties": {"var": {"type": "STRING"}},
"required": [],
},
),
(
{"type": "object", "additionalProperties": True},
{
"type_": "OBJECT",
"properties": {"json": {"type_": "STRING"}},
"type": "OBJECT",
"properties": {"json": {"type": "STRING"}},
"required": [],
},
),
(
{"type": "array", "items": {"type": "string"}},
{"type_": "ARRAY", "items": {"type_": "STRING"}},
{"type": "ARRAY", "items": {"type": "STRING"}},
),
],
)
async def test_format_schema(openapi, protobuf) -> None:
async def test_format_schema(openapi, genai_schema) -> None:
"""Test _format_schema."""
assert _format_schema(openapi) == protobuf
assert _format_schema(openapi) == genai_schema

View File

@ -1,16 +1,17 @@
"""Tests for the Google Generative AI Conversation integration."""
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, Mock, patch
from google.api_core.exceptions import ClientError, DeadlineExceeded
from google.rpc.error_details_pb2 import ErrorInfo # pylint: disable=no-name-in-module
import pytest
from requests.exceptions import Timeout
from syrupy.assertion import SnapshotAssertion
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
from tests.common import MockConfigEntry
@ -24,12 +25,14 @@ async def test_generate_content_service_without_images(
"party for the latest version of Home Assistant!"
)
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_response = MagicMock()
mock_response.text = stubbed_generated_content
mock_model.return_value.generate_content_async = AsyncMock(
return_value=mock_response
)
with patch(
"google.genai.models.AsyncModels.generate_content",
return_value=Mock(
text=stubbed_generated_content,
prompt_feedback=None,
candidates=[Mock()],
),
) as mock_generate:
response = await hass.services.async_call(
"google_generative_ai_conversation",
"generate_content",
@ -41,7 +44,7 @@ async def test_generate_content_service_without_images(
assert response == {
"text": stubbed_generated_content,
}
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
@pytest.mark.usefixtures("mock_init_component")
@ -54,19 +57,21 @@ async def test_generate_content_service_with_image(
)
with (
patch("google.generativeai.GenerativeModel") as mock_model,
patch(
"homeassistant.components.google_generative_ai_conversation.Path.read_bytes",
"google.genai.models.AsyncModels.generate_content",
return_value=Mock(
text=stubbed_generated_content,
prompt_feedback=None,
candidates=[Mock()],
),
) as mock_generate,
patch(
"homeassistant.components.google_generative_ai_conversation.Image.open",
return_value=b"image bytes",
),
patch("pathlib.Path.exists", return_value=True),
patch.object(hass.config, "is_allowed_path", return_value=True),
):
mock_response = MagicMock()
mock_response.text = stubbed_generated_content
mock_model.return_value.generate_content_async = AsyncMock(
return_value=mock_response
)
response = await hass.services.async_call(
"google_generative_ai_conversation",
"generate_content",
@ -81,7 +86,7 @@ async def test_generate_content_service_with_image(
assert response == {
"text": stubbed_generated_content,
}
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
@pytest.mark.usefixtures("mock_init_component")
@ -90,20 +95,23 @@ async def test_generate_content_service_error(
mock_config_entry: MockConfigEntry,
) -> None:
"""Test generate content service handles errors."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_model.return_value.generate_content_async = AsyncMock(
side_effect=ClientError("reason")
with (
patch(
"google.genai.models.AsyncModels.generate_content",
side_effect=CLIENT_ERROR_500,
),
pytest.raises(
HomeAssistantError,
match="Error generating content: 500 internal-error. {'message': 'Internal Server Error', 'status': 'internal-error'}",
),
):
await hass.services.async_call(
"google_generative_ai_conversation",
"generate_content",
{"prompt": "write a story about an epic fail"},
blocking=True,
return_response=True,
)
with pytest.raises(
HomeAssistantError, match="Error generating content: None reason"
):
await hass.services.async_call(
"google_generative_ai_conversation",
"generate_content",
{"prompt": "write a story about an epic fail"},
blocking=True,
return_response=True,
)
@pytest.mark.usefixtures("mock_init_component")
@ -113,21 +121,22 @@ async def test_generate_content_response_has_empty_parts(
) -> None:
"""Test generate content service handles response with empty parts."""
with (
patch("google.generativeai.GenerativeModel") as mock_model,
patch(
"google.genai.models.AsyncModels.generate_content",
return_value=Mock(
prompt_feedback=None,
candidates=[Mock(content=Mock(parts=[]))],
),
),
pytest.raises(HomeAssistantError, match="Unknown error generating content"),
):
mock_response = MagicMock()
mock_response.parts = []
mock_model.return_value.generate_content_async = AsyncMock(
return_value=mock_response
await hass.services.async_call(
"google_generative_ai_conversation",
"generate_content",
{"prompt": "write a story about an epic fail"},
blocking=True,
return_response=True,
)
with pytest.raises(HomeAssistantError, match="Error generating content"):
await hass.services.async_call(
"google_generative_ai_conversation",
"generate_content",
{"prompt": "write a story about an epic fail"},
blocking=True,
return_response=True,
)
@pytest.mark.usefixtures("mock_init_component")
@ -211,19 +220,17 @@ async def test_generate_content_service_with_non_image(hass: HomeAssistant) -> N
("side_effect", "state", "reauth"),
[
(
ClientError("some error"),
CLIENT_ERROR_500,
ConfigEntryState.SETUP_ERROR,
False,
),
(
DeadlineExceeded("deadline exceeded"),
Timeout,
ConfigEntryState.SETUP_RETRY,
False,
),
(
ClientError(
"invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID")
),
CLIENT_ERROR_API_KEY_INVALID,
ConfigEntryState.SETUP_ERROR,
True,
),
@ -235,10 +242,7 @@ async def test_config_entry_error(
"""Test different configuration entry errors."""
mock_client = AsyncMock()
mock_client.get_model.side_effect = side_effect
with patch(
"google.ai.generativelanguage_v1beta.ModelServiceAsyncClient",
return_value=mock_client,
):
with patch("google.genai.models.AsyncModels.get", side_effect=side_effect):
assert not await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
assert mock_config_entry.state == state