Swap the Gemini SDK to the newly released Unified SDK (#138246)

* Swapped the old GenAI client with the newly realeased one

* Fixed the Generate Content Action, Config Flow loading and code cleanup

* Add a function to mask the issues with Tools which start with Hass

* Fix most tests

* google-genai==1.1.0

* fixes

* Fixed the remaining tests

* Adressed comments

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: tronikos <tronikos@users.noreply.github.com>
This commit is contained in:
Ivan Lopez Hernandez 2025-02-21 22:41:05 -08:00 committed by GitHub
parent baa3b15dbc
commit 3160b7baa0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 513 additions and 508 deletions

View File

@ -5,11 +5,10 @@ from __future__ import annotations
import mimetypes import mimetypes
from pathlib import Path from pathlib import Path
from google.ai import generativelanguage_v1beta from google import genai # type: ignore[attr-defined]
from google.api_core.client_options import ClientOptions from google.genai.errors import APIError, ClientError
from google.api_core.exceptions import ClientError, DeadlineExceeded, GoogleAPIError from PIL import Image
import google.generativeai as genai from requests.exceptions import Timeout
import google.generativeai.types as genai_types
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
@ -29,7 +28,13 @@ from homeassistant.exceptions import (
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import CONF_CHAT_MODEL, CONF_PROMPT, DOMAIN, RECOMMENDED_CHAT_MODEL from .const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
TIMEOUT_MILLIS,
)
SERVICE_GENERATE_CONTENT = "generate_content" SERVICE_GENERATE_CONTENT = "generate_content"
CONF_IMAGE_FILENAME = "image_filename" CONF_IMAGE_FILENAME = "image_filename"
@ -37,6 +42,8 @@ CONF_IMAGE_FILENAME = "image_filename"
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,) PLATFORMS = (Platform.CONVERSATION,)
type GoogleGenerativeAIConfigEntry = ConfigEntry[genai.Client]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Google Generative AI Conversation.""" """Set up Google Generative AI Conversation."""
@ -44,6 +51,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def generate_content(call: ServiceCall) -> ServiceResponse: async def generate_content(call: ServiceCall) -> ServiceResponse:
"""Generate content from text and optionally images.""" """Generate content from text and optionally images."""
prompt_parts = [call.data[CONF_PROMPT]] prompt_parts = [call.data[CONF_PROMPT]]
def append_images_to_prompt():
image_filenames = call.data[CONF_IMAGE_FILENAME] image_filenames = call.data[CONF_IMAGE_FILENAME]
for image_filename in image_filenames: for image_filename in image_filenames:
if not hass.config.is_allowed_path(image_filename): if not hass.config.is_allowed_path(image_filename):
@ -57,29 +66,32 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
mime_type, _ = mimetypes.guess_type(image_filename) mime_type, _ = mimetypes.guess_type(image_filename)
if mime_type is None or not mime_type.startswith("image"): if mime_type is None or not mime_type.startswith("image"):
raise HomeAssistantError(f"`{image_filename}` is not an image") raise HomeAssistantError(f"`{image_filename}` is not an image")
prompt_parts.append( prompt_parts.append(Image.open(image_filename))
{
"mime_type": mime_type,
"data": await hass.async_add_executor_job(
Path(image_filename).read_bytes
),
}
)
model = genai.GenerativeModel(model_name=RECOMMENDED_CHAT_MODEL) await hass.async_add_executor_job(append_images_to_prompt)
config_entry: GoogleGenerativeAIConfigEntry = hass.config_entries.async_entries(
DOMAIN
)[0]
client = config_entry.runtime_data
try: try:
response = await model.generate_content_async(prompt_parts) response = await client.aio.models.generate_content(
model=RECOMMENDED_CHAT_MODEL, contents=prompt_parts
)
except ( except (
GoogleAPIError, APIError,
ValueError, ValueError,
genai_types.BlockedPromptException,
genai_types.StopCandidateException,
) as err: ) as err:
raise HomeAssistantError(f"Error generating content: {err}") from err raise HomeAssistantError(f"Error generating content: {err}") from err
if not response.parts: if response.prompt_feedback:
raise HomeAssistantError("Error generating content") raise HomeAssistantError(
f"Error generating content due to content violations, reason: {response.prompt_feedback.block_reason_message}"
)
if not response.candidates[0].content.parts:
raise HomeAssistantError("Unknown error generating content")
return {"text": response.text} return {"text": response.text}
@ -100,30 +112,34 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(
hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
) -> bool:
"""Set up Google Generative AI Conversation from a config entry.""" """Set up Google Generative AI Conversation from a config entry."""
genai.configure(api_key=entry.data[CONF_API_KEY])
try: try:
client = generativelanguage_v1beta.ModelServiceAsyncClient( client = genai.Client(api_key=entry.data[CONF_API_KEY])
client_options=ClientOptions(api_key=entry.data[CONF_API_KEY]) await client.aio.models.get(
model=entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
) )
await client.get_model( except (APIError, Timeout) as err:
name=entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), timeout=5.0 if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err):
) raise ConfigEntryAuthFailed(err.message) from err
except (GoogleAPIError, ValueError) as err: if isinstance(err, Timeout):
if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID":
raise ConfigEntryAuthFailed(err) from err
if isinstance(err, DeadlineExceeded):
raise ConfigEntryNotReady(err) from err raise ConfigEntryNotReady(err) from err
raise ConfigEntryError(err) from err raise ConfigEntryError(err) from err
else:
entry.runtime_data = client
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(
hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
) -> bool:
"""Unload GoogleGenerativeAI.""" """Unload GoogleGenerativeAI."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS): if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
return False return False

View File

@ -3,15 +3,13 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
from functools import partial
import logging import logging
from types import MappingProxyType from types import MappingProxyType
from typing import Any from typing import Any
from google.ai import generativelanguage_v1beta from google import genai # type: ignore[attr-defined]
from google.api_core.client_options import ClientOptions from google.genai.errors import APIError, ClientError
from google.api_core.exceptions import ClientError, GoogleAPIError from requests.exceptions import Timeout
import google.generativeai as genai
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ( from homeassistant.config_entries import (
@ -53,6 +51,7 @@ from .const import (
RECOMMENDED_TEMPERATURE, RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K, RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P, RECOMMENDED_TOP_P,
TIMEOUT_MILLIS,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -70,15 +69,20 @@ RECOMMENDED_OPTIONS = {
} }
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: async def validate_input(data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect. """Validate the user input allows us to connect.
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user. Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
""" """
client = generativelanguage_v1beta.ModelServiceAsyncClient( client = genai.Client(api_key=data[CONF_API_KEY])
client_options=ClientOptions(api_key=data[CONF_API_KEY]) await client.aio.models.list(
config={
"http_options": {
"timeout": TIMEOUT_MILLIS,
},
"query_base": True,
}
) )
await client.list_models(timeout=5.0)
class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
@ -93,9 +97,9 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:
try: try:
await validate_input(self.hass, user_input) await validate_input(user_input)
except GoogleAPIError as err: except (APIError, Timeout) as err:
if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID": if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err):
errors["base"] = "invalid_auth" errors["base"] = "invalid_auth"
else: else:
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
@ -166,6 +170,7 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
self.last_rendered_recommended = config_entry.options.get( self.last_rendered_recommended = config_entry.options.get(
CONF_RECOMMENDED, False CONF_RECOMMENDED, False
) )
self._genai_client = config_entry.runtime_data
async def async_step_init( async def async_step_init(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
@ -188,7 +193,9 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API], CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
} }
schema = await google_generative_ai_config_option_schema(self.hass, options) schema = await google_generative_ai_config_option_schema(
self.hass, options, self._genai_client
)
return self.async_show_form( return self.async_show_form(
step_id="init", step_id="init",
data_schema=vol.Schema(schema), data_schema=vol.Schema(schema),
@ -198,6 +205,7 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
async def google_generative_ai_config_option_schema( async def google_generative_ai_config_option_schema(
hass: HomeAssistant, hass: HomeAssistant,
options: dict[str, Any] | MappingProxyType[str, Any], options: dict[str, Any] | MappingProxyType[str, Any],
genai_client: genai.Client,
) -> dict: ) -> dict:
"""Return a schema for Google Generative AI completion options.""" """Return a schema for Google Generative AI completion options."""
hass_apis: list[SelectOptionDict] = [ hass_apis: list[SelectOptionDict] = [
@ -236,18 +244,21 @@ async def google_generative_ai_config_option_schema(
if options.get(CONF_RECOMMENDED): if options.get(CONF_RECOMMENDED):
return schema return schema
api_models = await hass.async_add_executor_job(partial(genai.list_models)) api_models_pager = await genai_client.aio.models.list(config={"query_base": True})
api_models = [api_model async for api_model in api_models_pager]
models = [ models = [
SelectOptionDict( SelectOptionDict(
label=api_model.display_name, label=api_model.display_name,
value=api_model.name, value=api_model.name,
) )
for api_model in sorted(api_models, key=lambda x: x.display_name) for api_model in sorted(api_models, key=lambda x: x.display_name or "")
if ( if (
api_model.name != "models/gemini-1.0-pro" # duplicate of gemini-pro api_model.name != "models/gemini-1.0-pro" # duplicate of gemini-pro
and api_model.display_name
and api_model.name
and api_model.supported_actions
and "vision" not in api_model.name and "vision" not in api_model.name
and "generateContent" in api_model.supported_generation_methods and "generateContent" in api_model.supported_actions
) )
] ]

View File

@ -22,3 +22,5 @@ CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold" CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold" CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE" RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"
TIMEOUT_MILLIS = 10000

View File

@ -6,11 +6,18 @@ import codecs
from collections.abc import Callable from collections.abc import Callable
from typing import Any, Literal, cast from typing import Any, Literal, cast
from google.api_core.exceptions import GoogleAPIError from google.genai.errors import APIError
import google.generativeai as genai from google.genai.types import (
from google.generativeai import protos AutomaticFunctionCallingConfig,
import google.generativeai.types as genai_types Content,
from google.protobuf.json_format import MessageToDict FunctionDeclaration,
GenerateContentConfig,
HarmCategory,
Part,
SafetySetting,
Schema,
Tool,
)
from voluptuous_openapi import convert from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
@ -57,21 +64,40 @@ async def async_setup_entry(
SUPPORTED_SCHEMA_KEYS = { SUPPORTED_SCHEMA_KEYS = {
"type", "min_items",
"format", "example",
"description", "property_ordering",
"pattern",
"minimum",
"default",
"any_of",
"max_length",
"title",
"min_properties",
"min_length",
"max_items",
"maximum",
"nullable", "nullable",
"max_properties",
"type",
"description",
"enum", "enum",
"format",
"items", "items",
"properties", "properties",
"required", "required",
} }
def _format_schema(schema: dict[str, Any]) -> dict[str, Any]: def _camel_to_snake(name: str) -> str:
"""Format the schema to protobuf.""" """Convert camel case to snake case."""
if (subschemas := schema.get("anyOf")) or (subschemas := schema.get("allOf")): return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_")
for subschema in subschemas: # Gemini API does not support anyOf and allOf keys
def _format_schema(schema: dict[str, Any]) -> Schema:
"""Format the schema to be compatible with Gemini API."""
if subschemas := schema.get("allOf"):
for subschema in subschemas: # Gemini API does not support allOf keys
if "type" in subschema: # Fallback to first subschema with 'type' field if "type" in subschema: # Fallback to first subschema with 'type' field
return _format_schema(subschema) return _format_schema(subschema)
return _format_schema( return _format_schema(
@ -80,42 +106,38 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
result = {} result = {}
for key, val in schema.items(): for key, val in schema.items():
key = _camel_to_snake(key)
if key not in SUPPORTED_SCHEMA_KEYS: if key not in SUPPORTED_SCHEMA_KEYS:
continue continue
if key == "any_of":
val = [_format_schema(subschema) for subschema in val]
if key == "type": if key == "type":
key = "type_"
val = val.upper() val = val.upper()
elif key == "format": if key == "items":
if schema.get("type") == "string" and val != "enum":
continue
if schema.get("type") not in ("number", "integer", "string"):
continue
key = "format_"
elif key == "items":
val = _format_schema(val) val = _format_schema(val)
elif key == "properties": elif key == "properties":
val = {k: _format_schema(v) for k, v in val.items()} val = {k: _format_schema(v) for k, v in val.items()}
result[key] = val result[key] = val
if result.get("enum") and result.get("type_") != "STRING": if result.get("enum") and result.get("type") != "STRING":
# enum is only allowed for STRING type. This is safe as long as the schema # enum is only allowed for STRING type. This is safe as long as the schema
# contains vol.Coerce for the respective type, for example: # contains vol.Coerce for the respective type, for example:
# vol.All(vol.Coerce(int), vol.In([1, 2, 3])) # vol.All(vol.Coerce(int), vol.In([1, 2, 3]))
result["type_"] = "STRING" result["type"] = "STRING"
result["enum"] = [str(item) for item in result["enum"]] result["enum"] = [str(item) for item in result["enum"]]
if result.get("type_") == "OBJECT" and not result.get("properties"): if result.get("type") == "OBJECT" and not result.get("properties"):
# An object with undefined properties is not supported by Gemini API. # An object with undefined properties is not supported by Gemini API.
# Fallback to JSON string. This will probably fail for most tools that want it, # Fallback to JSON string. This will probably fail for most tools that want it,
# but we don't have a better fallback strategy so far. # but we don't have a better fallback strategy so far.
result["properties"] = {"json": {"type_": "STRING"}} result["properties"] = {"json": {"type": "STRING"}}
result["required"] = [] result["required"] = []
return result return cast(Schema, result)
def _format_tool( def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]: ) -> Tool:
"""Format tool specification.""" """Format tool specification."""
if tool.parameters.schema: if tool.parameters.schema:
@ -125,16 +147,14 @@ def _format_tool(
else: else:
parameters = None parameters = None
return protos.Tool( return Tool(
{ function_declarations=[
"function_declarations": [ FunctionDeclaration(
{ name=tool.name,
"name": tool.name, description=tool.description,
"description": tool.description, parameters=parameters,
"parameters": parameters, )
}
] ]
}
) )
@ -151,15 +171,13 @@ def _escape_decode(value: Any) -> Any:
def _create_google_tool_response_content( def _create_google_tool_response_content(
content: list[conversation.ToolResultContent], content: list[conversation.ToolResultContent],
) -> protos.Content: ) -> Content:
"""Create a Google tool response content.""" """Create a Google tool response content."""
return protos.Content( return Content(
parts=[ parts=[
protos.Part( Part.from_function_response(
function_response=protos.FunctionResponse(
name=tool_result.tool_name, response=tool_result.tool_result name=tool_result.tool_name, response=tool_result.tool_result
) )
)
for tool_result in content for tool_result in content
] ]
) )
@ -169,33 +187,36 @@ def _convert_content(
content: conversation.UserContent content: conversation.UserContent
| conversation.AssistantContent | conversation.AssistantContent
| conversation.SystemContent, | conversation.SystemContent,
) -> genai_types.ContentDict: ) -> Content:
"""Convert HA content to Google content.""" """Convert HA content to Google content."""
if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr] if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr]
role = "model" if content.role == "assistant" else content.role role = "model" if content.role == "assistant" else content.role
return {"role": role, "parts": content.content} return Content(
role=role,
parts=[
Part.from_text(text=content.content if content.content else ""),
],
)
# Handle the Assistant content with tool calls. # Handle the Assistant content with tool calls.
assert type(content) is conversation.AssistantContent assert type(content) is conversation.AssistantContent
parts = [] parts: list[Part] = []
if content.content: if content.content:
parts.append(protos.Part(text=content.content)) parts.append(Part.from_text(text=content.content))
if content.tool_calls: if content.tool_calls:
parts.extend( parts.extend(
[ [
protos.Part( Part.from_function_call(
function_call=protos.FunctionCall(
name=tool_call.tool_name, name=tool_call.tool_name,
args=_escape_decode(tool_call.tool_args), args=_escape_decode(tool_call.tool_args),
) )
)
for tool_call in content.tool_calls for tool_call in content.tool_calls
] ]
) )
return protos.Content({"role": "model", "parts": parts}) return Content(role="model", parts=parts)
class GoogleGenerativeAIConversationEntity( class GoogleGenerativeAIConversationEntity(
@ -209,6 +230,7 @@ class GoogleGenerativeAIConversationEntity(
def __init__(self, entry: ConfigEntry) -> None: def __init__(self, entry: ConfigEntry) -> None:
"""Initialize the agent.""" """Initialize the agent."""
self.entry = entry self.entry = entry
self._genai_client = entry.runtime_data
self._attr_unique_id = entry.entry_id self._attr_unique_id = entry.entry_id
self._attr_device_info = dr.DeviceInfo( self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)}, identifiers={(DOMAIN, entry.entry_id)},
@ -273,7 +295,7 @@ class GoogleGenerativeAIConversationEntity(
except conversation.ConverseError as err: except conversation.ConverseError as err:
return err.as_conversation_result() return err.as_conversation_result()
tools: list[dict[str, Any]] | None = None tools: list[Tool | Callable[..., Any]] | None = None
if chat_log.llm_api: if chat_log.llm_api:
tools = [ tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer) _format_tool(tool, chat_log.llm_api.custom_serializer)
@ -288,13 +310,22 @@ class GoogleGenerativeAIConversationEntity(
"gemini-1.0" not in model_name and "gemini-pro" not in model_name "gemini-1.0" not in model_name and "gemini-pro" not in model_name
) )
prompt = chat_log.content[0].content # type: ignore[union-attr] prompt_content = cast(
messages: list[genai_types.ContentDict] = [] conversation.SystemContent,
chat_log.content[0],
)
if prompt_content.content:
prompt = prompt_content.content
else:
raise HomeAssistantError("Invalid prompt content")
messages: list[Content] = []
# Google groups tool results, we do not. Group them before sending. # Google groups tool results, we do not. Group them before sending.
tool_results: list[conversation.ToolResultContent] = [] tool_results: list[conversation.ToolResultContent] = []
for chat_content in chat_log.content[1:]: for chat_content in chat_log.content[1:-1]:
if chat_content.role == "tool_result": if chat_content.role == "tool_result":
# mypy doesn't like picking a type based on checking shared property 'role' # mypy doesn't like picking a type based on checking shared property 'role'
tool_results.append(cast(conversation.ToolResultContent, chat_content)) tool_results.append(cast(conversation.ToolResultContent, chat_content))
@ -317,85 +348,93 @@ class GoogleGenerativeAIConversationEntity(
if tool_results: if tool_results:
messages.append(_create_google_tool_response_content(tool_results)) messages.append(_create_google_tool_response_content(tool_results))
generateContentConfig = GenerateContentConfig(
model = genai.GenerativeModel( temperature=self.entry.options.get(
model_name=model_name,
generation_config={
"temperature": self.entry.options.get(
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
), ),
"top_p": self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P), top_k=self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
"top_k": self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K), top_p=self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"max_output_tokens": self.entry.options.get( max_output_tokens=self.entry.options.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
), ),
}, safety_settings=[
safety_settings={ SafetySetting(
"HARASSMENT": self.entry.options.get( category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
CONF_HARASSMENT_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD threshold=self.entry.options.get(
),
"HATE": self.entry.options.get(
CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
), ),
"SEXUAL": self.entry.options.get(
CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
), ),
"DANGEROUS": self.entry.options.get( SafetySetting(
category=HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=self.entry.options.get(
CONF_HARASSMENT_BLOCK_THRESHOLD,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=self.entry.options.get(
CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
), ),
}, ),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=self.entry.options.get(
CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
],
tools=tools or None, tools=tools or None,
system_instruction=prompt if supports_system_instruction else None, system_instruction=prompt if supports_system_instruction else None,
automatic_function_calling=AutomaticFunctionCallingConfig(
disable=True, maximum_remote_calls=None
),
) )
if not supports_system_instruction: if not supports_system_instruction:
messages = [ messages = [
{"role": "user", "parts": prompt}, Content(role="user", parts=[Part.from_text(text=prompt)]),
{"role": "model", "parts": "Ok"}, Content(role="model", parts=[Part.from_text(text="Ok")]),
*messages, *messages,
] ]
chat = self._genai_client.aio.chats.create(
chat = model.start_chat(history=messages) model=model_name, history=messages, config=generateContentConfig
chat_request = user_input.text )
chat_request: str | Content = user_input.text
# To prevent infinite loops, we limit the number of iterations # To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS): for _iteration in range(MAX_TOOL_ITERATIONS):
try: try:
chat_response = await chat.send_message_async(chat_request) chat_response = await chat.send_message(message=chat_request)
except (
GoogleAPIError,
ValueError,
genai_types.BlockedPromptException,
genai_types.StopCandidateException,
) as err:
LOGGER.error("Error sending message: %s %s", type(err), err)
if isinstance( if chat_response.prompt_feedback:
err, genai_types.StopCandidateException raise HomeAssistantError(
) and "finish_reason: SAFETY\n" in str(err): f"The message got blocked due to content violations, reason: {chat_response.prompt_feedback.block_reason_message}"
error = "The message got blocked by your safety settings"
else:
error = (
f"Sorry, I had a problem talking to Google Generative AI: {err}"
) )
except (
APIError,
ValueError,
) as err:
LOGGER.error("Error sending message: %s %s", type(err), err)
error = f"Sorry, I had a problem talking to Google Generative AI: {err}"
raise HomeAssistantError(error) from err raise HomeAssistantError(error) from err
LOGGER.debug("Response: %s", chat_response.parts) response_parts = chat_response.candidates[0].content.parts
if not chat_response.parts: if not response_parts:
raise HomeAssistantError( raise HomeAssistantError(
"Sorry, I had a problem getting a response from Google Generative AI." "Sorry, I had a problem getting a response from Google Generative AI."
) )
content = " ".join( content = " ".join(
[part.text.strip() for part in chat_response.parts if part.text] [part.text.strip() for part in response_parts if part.text]
) )
tool_calls = [] tool_calls = []
for part in chat_response.parts: for part in response_parts:
if not part.function_call: if not part.function_call:
continue continue
tool_call = MessageToDict(part.function_call._pb) # noqa: SLF001 tool_call = part.function_call
tool_name = tool_call["name"] tool_name = tool_call.name
tool_args = _escape_decode(tool_call["args"]) tool_args = _escape_decode(tool_call.args)
tool_calls.append( tool_calls.append(
llm.ToolInput(tool_name=tool_name, tool_args=tool_args) llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
) )
@ -418,7 +457,7 @@ class GoogleGenerativeAIConversationEntity(
response = intent.IntentResponse(language=user_input.language) response = intent.IntentResponse(language=user_input.language)
response.async_set_speech( response.async_set_speech(
" ".join([part.text.strip() for part in chat_response.parts if part.text]) " ".join([part.text.strip() for part in response_parts if part.text])
) )
return conversation.ConversationResult( return conversation.ConversationResult(
response=response, conversation_id=chat_log.conversation_id response=response, conversation_id=chat_log.conversation_id

View File

@ -8,5 +8,5 @@
"documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation", "documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation",
"integration_type": "service", "integration_type": "service",
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"requirements": ["google-generativeai==0.8.2"] "requirements": ["google-genai==1.1.0"]
} }

2
requirements_all.txt generated
View File

@ -1033,7 +1033,7 @@ google-cloud-speech==2.27.0
google-cloud-texttospeech==2.17.2 google-cloud-texttospeech==2.17.2
# homeassistant.components.google_generative_ai_conversation # homeassistant.components.google_generative_ai_conversation
google-generativeai==0.8.2 google-genai==1.1.0
# homeassistant.components.nest # homeassistant.components.nest
google-nest-sdm==7.1.3 google-nest-sdm==7.1.3

View File

@ -883,7 +883,7 @@ google-cloud-speech==2.27.0
google-cloud-texttospeech==2.17.2 google-cloud-texttospeech==2.17.2
# homeassistant.components.google_generative_ai_conversation # homeassistant.components.google_generative_ai_conversation
google-generativeai==0.8.2 google-genai==1.1.0
# homeassistant.components.nest # homeassistant.components.nest
google-nest-sdm==7.1.3 google-nest-sdm==7.1.3

View File

@ -1 +1,31 @@
"""Tests for the Google Generative AI Conversation integration.""" """Tests for the Google Generative AI Conversation integration."""
from unittest.mock import Mock
from google.genai.errors import ClientError
import requests
CLIENT_ERROR_500 = ClientError(
500,
Mock(
__class__=requests.Response,
json=Mock(
return_value={
"message": "Internal Server Error",
"status": "internal-error",
}
),
),
)
CLIENT_ERROR_API_KEY_INVALID = ClientError(
400,
Mock(
__class__=requests.Response,
json=Mock(
return_value={
"message": "'reason': API_KEY_INVALID",
"status": "unauthorized",
}
),
),
)

View File

@ -1,7 +1,6 @@
"""Tests helpers.""" """Tests helpers."""
from collections.abc import Generator from unittest.mock import Mock, patch
from unittest.mock import patch
import pytest import pytest
@ -15,14 +14,7 @@ from tests.common import MockConfigEntry
@pytest.fixture @pytest.fixture
def mock_genai() -> Generator[None]: def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
"""Mock the genai call in async_setup_entry."""
with patch("google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.get_model"):
yield
@pytest.fixture
def mock_config_entry(hass: HomeAssistant, mock_genai: None) -> MockConfigEntry:
"""Mock a config entry.""" """Mock a config entry."""
entry = MockConfigEntry( entry = MockConfigEntry(
domain="google_generative_ai_conversation", domain="google_generative_ai_conversation",
@ -31,18 +23,21 @@ def mock_config_entry(hass: HomeAssistant, mock_genai: None) -> MockConfigEntry:
"api_key": "bla", "api_key": "bla",
}, },
) )
entry.runtime_data = Mock()
entry.add_to_hass(hass) entry.add_to_hass(hass)
return entry return entry
@pytest.fixture @pytest.fixture
def mock_config_entry_with_assist( async def mock_config_entry_with_assist(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Mock a config entry with assist.""" """Mock a config entry with assist."""
with patch("google.genai.models.AsyncModels.get"):
hass.config_entries.async_update_entry( hass.config_entries.async_update_entry(
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
) )
await hass.async_block_till_done()
return mock_config_entry return mock_config_entry
@ -51,7 +46,10 @@ async def mock_init_component(
hass: HomeAssistant, mock_config_entry: ConfigEntry hass: HomeAssistant, mock_config_entry: ConfigEntry
) -> None: ) -> None:
"""Initialize integration.""" """Initialize integration."""
assert await async_setup_component(hass, "google_generative_ai_conversation", {}) with patch("google.genai.models.AsyncModels.get"):
assert await async_setup_component(
hass, "google_generative_ai_conversation", {}
)
await hass.async_block_till_done() await hass.async_block_till_done()

View File

@ -6,106 +6,26 @@
tuple( tuple(
), ),
dict({ dict({
'generation_config': dict({ 'config': GenerateContentConfig(http_options=None, system_instruction="Current time is 05:00:00. Today's date is 2024-05-24.\nYou are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=150, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'param1': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.ARRAY: 'ARRAY'>, description='Test parameters', enum=None, format=None, items=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format='lower', items=None, properties=None, required=None), properties=None, required=None), 'param2': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=[Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.NUMBER: 'NUMBER'>, description=None, enum=None, format=None, items=None, properties=None, required=None), Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.INTEGER: 'INTEGER'>, description=None, enum=None, format=None, items=None, properties=None, required=None)], max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=None, description=None, enum=None, format=None, items=None, properties=None, required=None), 'param3': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'json': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format=None, items=None, properties=None, required=None)}, required=[])}, required=[]))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-2.0-flash',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
''',
'tools': list([
function_declarations {
name: "test_tool"
description: "Test function"
parameters {
type_: OBJECT
properties {
key: "param3"
value {
type_: OBJECT
properties {
key: "json"
value {
type_: STRING
}
}
}
}
properties {
key: "param2"
value {
type_: NUMBER
}
}
properties {
key: "param1"
value {
type_: ARRAY
description: "Test parameters"
items {
type_: STRING
}
}
}
}
}
,
]),
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([ 'history': list([
dict({
'parts': 'Please call the test function',
'role': 'user',
}),
]), ]),
'model': 'models/gemini-2.0-flash',
}), }),
), ),
tuple( tuple(
'().start_chat().send_message_async', '().send_message',
tuple( tuple(
'Please call the test function',
), ),
dict({ dict({
'message': 'Please call the test function',
}), }),
), ),
tuple( tuple(
'().start_chat().send_message_async', '().send_message',
tuple( tuple(
parts {
function_response {
name: "test_tool"
response {
fields {
key: "result"
value {
string_value: "Test response"
}
}
}
}
}
,
), ),
dict({ dict({
'message': Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None)], role=None),
}), }),
), ),
]) ])
@ -117,75 +37,26 @@
tuple( tuple(
), ),
dict({ dict({
'generation_config': dict({ 'config': GenerateContentConfig(http_options=None, system_instruction="Current time is 05:00:00. Today's date is 2024-05-24.\nYou are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=150, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=None)], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-2.0-flash',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'system_instruction': '''
Current time is 05:00:00. Today's date is 2024-05-24.
You are a voice assistant for Home Assistant.
Answer questions about the world truthfully.
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
''',
'tools': list([
function_declarations {
name: "test_tool"
description: "Test function"
}
,
]),
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([ 'history': list([
dict({
'parts': 'Please call the test function',
'role': 'user',
}),
]), ]),
'model': 'models/gemini-2.0-flash',
}), }),
), ),
tuple( tuple(
'().start_chat().send_message_async', '().send_message',
tuple( tuple(
'Please call the test function',
), ),
dict({ dict({
'message': 'Please call the test function',
}), }),
), ),
tuple( tuple(
'().start_chat().send_message_async', '().send_message',
tuple( tuple(
parts {
function_response {
name: "test_tool"
response {
fields {
key: "result"
value {
string_value: "Test response"
}
}
}
}
}
,
), ),
dict({ dict({
'message': Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None)], role=None),
}), }),
), ),
]) ])

View File

@ -6,21 +6,11 @@
tuple( tuple(
), ),
dict({ dict({
'model_name': 'models/gemini-2.0-flash', 'contents': list([
}),
),
tuple(
'().generate_content_async',
tuple(
list([
'Describe this image from my doorbell camera', 'Describe this image from my doorbell camera',
dict({ b'image bytes',
'data': b'image bytes',
'mime_type': 'image/jpeg',
}),
]), ]),
), 'model': 'models/gemini-2.0-flash',
dict({
}), }),
), ),
]) ])
@ -32,17 +22,10 @@
tuple( tuple(
), ),
dict({ dict({
'model_name': 'models/gemini-2.0-flash', 'contents': list([
}),
),
tuple(
'().generate_content_async',
tuple(
list([
'Write an opening speech for a Home Assistant release party', 'Write an opening speech for a Home Assistant release party',
]), ]),
), 'model': 'models/gemini-2.0-flash',
dict({
}), }),
), ),
]) ])

View File

@ -1,10 +1,9 @@
"""Test the Google Generative AI Conversation config flow.""" """Test the Google Generative AI Conversation config flow."""
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import Mock, patch
from google.api_core.exceptions import ClientError, DeadlineExceeded
from google.rpc.error_details_pb2 import ErrorInfo # pylint: disable=no-name-in-module
import pytest import pytest
from requests.exceptions import Timeout
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.google_generative_ai_conversation.config_flow import ( from homeassistant.components.google_generative_ai_conversation.config_flow import (
@ -33,6 +32,8 @@ from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -41,30 +42,37 @@ def mock_models():
"""Mock the model list API.""" """Mock the model list API."""
model_20_flash = Mock( model_20_flash = Mock(
display_name="Gemini 2.0 Flash", display_name="Gemini 2.0 Flash",
supported_generation_methods=["generateContent"], supported_actions=["generateContent"],
) )
model_20_flash.name = "models/gemini-2.0-flash" model_20_flash.name = "models/gemini-2.0-flash"
model_15_flash = Mock( model_15_flash = Mock(
display_name="Gemini 1.5 Flash", display_name="Gemini 1.5 Flash",
supported_generation_methods=["generateContent"], supported_actions=["generateContent"],
) )
model_15_flash.name = "models/gemini-1.5-flash-latest" model_15_flash.name = "models/gemini-1.5-flash-latest"
model_15_pro = Mock( model_15_pro = Mock(
display_name="Gemini 1.5 Pro", display_name="Gemini 1.5 Pro",
supported_generation_methods=["generateContent"], supported_actions=["generateContent"],
) )
model_15_pro.name = "models/gemini-1.5-pro-latest" model_15_pro.name = "models/gemini-1.5-pro-latest"
model_10_pro = Mock( model_10_pro = Mock(
display_name="Gemini 1.0 Pro", display_name="Gemini 1.0 Pro",
supported_generation_methods=["generateContent"], supported_actions=["generateContent"],
) )
model_10_pro.name = "models/gemini-pro" model_10_pro.name = "models/gemini-pro"
async def models_pager():
yield model_20_flash
yield model_15_flash
yield model_15_pro
yield model_10_pro
with patch( with patch(
"homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models", "google.genai.models.AsyncModels.list",
return_value=iter([model_20_flash, model_15_flash, model_15_pro, model_10_pro]), return_value=models_pager(),
): ):
yield yield
@ -86,7 +94,7 @@ async def test_form(hass: HomeAssistant) -> None:
with ( with (
patch( patch(
"google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.list_models", "google.genai.models.AsyncModels.list",
), ),
patch( patch(
"homeassistant.components.google_generative_ai_conversation.async_setup_entry", "homeassistant.components.google_generative_ai_conversation.async_setup_entry",
@ -170,7 +178,11 @@ async def test_options_switching(
expected_options, expected_options,
) -> None: ) -> None:
"""Test the options form.""" """Test the options form."""
hass.config_entries.async_update_entry(mock_config_entry, options=current_options) with patch("google.genai.models.AsyncModels.get"):
hass.config_entries.async_update_entry(
mock_config_entry, options=current_options
)
await hass.async_block_till_done()
options_flow = await hass.config_entries.options.async_init( options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id mock_config_entry.entry_id
) )
@ -195,17 +207,15 @@ async def test_options_switching(
("side_effect", "error"), ("side_effect", "error"),
[ [
( (
ClientError("some error"), CLIENT_ERROR_500,
"cannot_connect", "cannot_connect",
), ),
( (
DeadlineExceeded("deadline exceeded"), Timeout("deadline exceeded"),
"cannot_connect", "cannot_connect",
), ),
( (
ClientError( CLIENT_ERROR_API_KEY_INVALID,
"invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID")
),
"invalid_auth", "invalid_auth",
), ),
(Exception, "unknown"), (Exception, "unknown"),
@ -217,12 +227,7 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
DOMAIN, context={"source": config_entries.SOURCE_USER} DOMAIN, context={"source": config_entries.SOURCE_USER}
) )
mock_client = AsyncMock() with patch("google.genai.models.AsyncModels.list", side_effect=side_effect):
mock_client.list_models.side_effect = side_effect
with patch(
"google.ai.generativelanguage_v1beta.ModelServiceAsyncClient",
return_value=mock_client,
):
result2 = await hass.config_entries.flow.async_configure( result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
{ {
@ -259,7 +264,7 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
with ( with (
patch( patch(
"google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.list_models", "google.genai.models.AsyncModels.list",
), ),
patch( patch(
"homeassistant.components.google_generative_ai_conversation.async_setup_entry", "homeassistant.components.google_generative_ai_conversation.async_setup_entry",

View File

@ -1,12 +1,10 @@
"""Tests for the Google Generative AI Conversation integration conversation platform.""" """Tests for the Google Generative AI Conversation integration conversation platform."""
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, Mock, patch
from freezegun import freeze_time from freezegun import freeze_time
from google.ai.generativelanguage_v1beta.types.content import FunctionCall from google.genai.types import FunctionCall
from google.api_core.exceptions import GoogleAPIError
import google.generativeai.types as genai_types
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
import voluptuous as vol import voluptuous as vol
@ -22,6 +20,8 @@ from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm from homeassistant.helpers import intent, llm
from . import CLIENT_ERROR_500
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -51,7 +51,7 @@ async def test_function_call(
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test function calling.""" """Test function calling."""
agent_id = mock_config_entry_with_assist.entry_id agent_id = "conversation.google_generative_ai_conversation"
context = Context() context = Context()
mock_tool = AsyncMock() mock_tool = AsyncMock()
@ -69,12 +69,12 @@ async def test_function_call(
mock_get_tools.return_value = [mock_tool] mock_get_tools.return_value = [mock_tool]
with patch("google.generativeai.GenerativeModel") as mock_model: with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_create.return_value.send_message = mock_chat
chat_response = MagicMock() chat_response = Mock(prompt_feedback=None)
mock_chat.send_message_async.return_value = chat_response mock_chat.return_value = chat_response
mock_part = MagicMock() mock_part = Mock()
mock_part.text = "" mock_part.text = ""
mock_part.function_call = FunctionCall( mock_part.function_call = FunctionCall(
name="test_tool", name="test_tool",
@ -92,7 +92,7 @@ async def test_function_call(
return {"result": "Test response"} return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call mock_tool.async_call.side_effect = tool_call
chat_response.parts = [mock_part] chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
result = await conversation.async_converse( result = await conversation.async_converse(
hass, hass,
"Please call the test function", "Please call the test function",
@ -104,20 +104,28 @@ async def test_function_call(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0] mock_tool_call = mock_create.mock_calls[2][2]["message"]
mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call) assert mock_tool_call.model_dump() == {
assert mock_tool_call == {
"parts": [ "parts": [
{ {
"code_execution_result": None,
"executable_code": None,
"file_data": None,
"function_call": None,
"function_response": { "function_response": {
"id": None,
"name": "test_tool", "name": "test_tool",
"response": { "response": {
"result": "Test response", "result": "Test response",
}, },
}, },
"inline_data": None,
"text": None,
"thought": None,
"video_metadata": None,
}, },
], ],
"role": "", "role": None,
} }
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
@ -139,7 +147,7 @@ async def test_function_call(
device_id="test_device", device_id="test_device",
), ),
) )
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
# Test conversating tracing # Test conversating tracing
traces = trace.async_get_traces() traces = trace.async_get_traces()
@ -170,7 +178,7 @@ async def test_function_call_without_parameters(
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test function calling without parameters.""" """Test function calling without parameters."""
agent_id = mock_config_entry_with_assist.entry_id agent_id = "conversation.google_generative_ai_conversation"
context = Context() context = Context()
mock_tool = AsyncMock() mock_tool = AsyncMock()
@ -180,12 +188,12 @@ async def test_function_call_without_parameters(
mock_get_tools.return_value = [mock_tool] mock_get_tools.return_value = [mock_tool]
with patch("google.generativeai.GenerativeModel") as mock_model: with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_create.return_value.send_message = mock_chat
chat_response = MagicMock() chat_response = Mock(prompt_feedback=None)
mock_chat.send_message_async.return_value = chat_response mock_chat.return_value = chat_response
mock_part = MagicMock() mock_part = Mock()
mock_part.text = "" mock_part.text = ""
mock_part.function_call = FunctionCall(name="test_tool", args={}) mock_part.function_call = FunctionCall(name="test_tool", args={})
@ -197,7 +205,7 @@ async def test_function_call_without_parameters(
return {"result": "Test response"} return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call mock_tool.async_call.side_effect = tool_call
chat_response.parts = [mock_part] chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
result = await conversation.async_converse( result = await conversation.async_converse(
hass, hass,
"Please call the test function", "Please call the test function",
@ -209,20 +217,28 @@ async def test_function_call_without_parameters(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0] mock_tool_call = mock_create.mock_calls[2][2]["message"]
mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call) assert mock_tool_call.model_dump() == {
assert mock_tool_call == {
"parts": [ "parts": [
{ {
"code_execution_result": None,
"executable_code": None,
"file_data": None,
"function_call": None,
"function_response": { "function_response": {
"id": None,
"name": "test_tool", "name": "test_tool",
"response": { "response": {
"result": "Test response", "result": "Test response",
}, },
}, },
"inline_data": None,
"text": None,
"thought": None,
"video_metadata": None,
}, },
], ],
"role": "", "role": None,
} }
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
@ -241,7 +257,7 @@ async def test_function_call_without_parameters(
device_id="test_device", device_id="test_device",
), ),
) )
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
@patch( @patch(
@ -254,7 +270,7 @@ async def test_function_exception(
mock_config_entry_with_assist: MockConfigEntry, mock_config_entry_with_assist: MockConfigEntry,
) -> None: ) -> None:
"""Test exception in function calling.""" """Test exception in function calling."""
agent_id = mock_config_entry_with_assist.entry_id agent_id = "conversation.google_generative_ai_conversation"
context = Context() context = Context()
mock_tool = AsyncMock() mock_tool = AsyncMock()
@ -270,12 +286,12 @@ async def test_function_exception(
mock_get_tools.return_value = [mock_tool] mock_get_tools.return_value = [mock_tool]
with patch("google.generativeai.GenerativeModel") as mock_model: with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_create.return_value.send_message = mock_chat
chat_response = MagicMock() chat_response = Mock(prompt_feedback=None)
mock_chat.send_message_async.return_value = chat_response mock_chat.return_value = chat_response
mock_part = MagicMock() mock_part = Mock()
mock_part.text = "" mock_part.text = ""
mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1}) mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
@ -287,7 +303,7 @@ async def test_function_exception(
raise HomeAssistantError("Test tool exception") raise HomeAssistantError("Test tool exception")
mock_tool.async_call.side_effect = tool_call mock_tool.async_call.side_effect = tool_call
chat_response.parts = [mock_part] chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
result = await conversation.async_converse( result = await conversation.async_converse(
hass, hass,
"Please call the test function", "Please call the test function",
@ -299,21 +315,29 @@ async def test_function_exception(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0] mock_tool_call = mock_create.mock_calls[2][2]["message"]
mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call) assert mock_tool_call.model_dump() == {
assert mock_tool_call == {
"parts": [ "parts": [
{ {
"code_execution_result": None,
"executable_code": None,
"file_data": None,
"function_call": None,
"function_response": { "function_response": {
"id": None,
"name": "test_tool", "name": "test_tool",
"response": { "response": {
"error": "HomeAssistantError", "error": "HomeAssistantError",
"error_text": "Test tool exception", "error_text": "Test tool exception",
}, },
}, },
"inline_data": None,
"text": None,
"thought": None,
"video_metadata": None,
}, },
], ],
"role": "", "role": None,
} }
mock_tool.async_call.assert_awaited_once_with( mock_tool.async_call.assert_awaited_once_with(
hass, hass,
@ -338,18 +362,22 @@ async def test_error_handling(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None: ) -> None:
"""Test that client errors are caught.""" """Test that client errors are caught."""
with patch("google.generativeai.GenerativeModel") as mock_model: with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_create.return_value.send_message = mock_chat
mock_chat.send_message_async.side_effect = GoogleAPIError("some error") mock_chat.side_effect = CLIENT_ERROR_500
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id hass,
"hello",
None,
Context(),
agent_id="conversation.google_generative_ai_conversation",
) )
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == ( assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Sorry, I had a problem talking to Google Generative AI: some error" "Sorry, I had a problem talking to Google Generative AI: 500 internal-error. {'message': 'Internal Server Error', 'status': 'internal-error'}"
) )
@ -358,20 +386,24 @@ async def test_blocked_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None: ) -> None:
"""Test blocked response.""" """Test blocked response."""
with patch("google.generativeai.GenerativeModel") as mock_model: with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_create.return_value.send_message = mock_chat
mock_chat.send_message_async.side_effect = genai_types.StopCandidateException( chat_response = Mock(prompt_feedback=Mock(block_reason_message="SAFETY"))
"finish_reason: SAFETY\n" mock_chat.return_value = chat_response
)
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id hass,
"hello",
None,
Context(),
agent_id="conversation.google_generative_ai_conversation",
) )
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == ( assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"The message got blocked by your safety settings" "The message got blocked due to content violations, reason: SAFETY"
) )
@ -380,14 +412,18 @@ async def test_empty_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None: ) -> None:
"""Test empty response.""" """Test empty response."""
with patch("google.generativeai.GenerativeModel") as mock_model: with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_create.return_value.send_message = mock_chat
chat_response = MagicMock() chat_response = Mock(prompt_feedback=None)
mock_chat.send_message_async.return_value = chat_response mock_chat.return_value = chat_response
chat_response.parts = [] chat_response.candidates = [Mock(content=Mock(parts=[]))]
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id hass,
"hello",
None,
Context(),
agent_id="conversation.google_generative_ai_conversation",
) )
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result
@ -402,17 +438,19 @@ async def test_converse_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None: ) -> None:
"""Test handling ChatLog raising ConverseError.""" """Test handling ChatLog raising ConverseError."""
with patch("google.genai.models.AsyncModels.get"):
hass.config_entries.async_update_entry( hass.config_entries.async_update_entry(
mock_config_entry, mock_config_entry,
options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"}, options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"},
) )
await hass.async_block_till_done()
result = await conversation.async_converse( result = await conversation.async_converse(
hass, hass,
"hello", "hello",
None, None,
Context(), Context(),
agent_id=mock_config_entry.entry_id, agent_id="conversation.google_generative_ai_conversation",
) )
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result
@ -449,31 +487,39 @@ async def test_escape_decode() -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("openapi", "protobuf"), ("openapi", "genai_schema"),
[ [
( (
{"type": "string", "enum": ["a", "b", "c"]}, {"type": "string", "enum": ["a", "b", "c"]},
{"type_": "STRING", "enum": ["a", "b", "c"]}, {"type": "STRING", "enum": ["a", "b", "c"]},
), ),
( (
{"type": "integer", "enum": [1, 2, 3]}, {"type": "integer", "enum": [1, 2, 3]},
{"type_": "STRING", "enum": ["1", "2", "3"]}, {"type": "STRING", "enum": ["1", "2", "3"]},
),
(
{"anyOf": [{"type": "integer"}, {"type": "number"}]},
{"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]},
), ),
({"anyOf": [{"type": "integer"}, {"type": "number"}]}, {"type_": "INTEGER"}),
( (
{ {
"anyOf": [ "any_of": [
{"anyOf": [{"type": "integer"}, {"type": "number"}]}, {"any_of": [{"type": "integer"}, {"type": "number"}]},
{"anyOf": [{"type": "integer"}, {"type": "number"}]}, {"any_of": [{"type": "integer"}, {"type": "number"}]},
]
},
{
"any_of": [
{"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]},
{"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]},
] ]
}, },
{"type_": "INTEGER"},
), ),
({"type": "string", "format": "lower"}, {"type_": "STRING"}), ({"type": "string", "format": "lower"}, {"format": "lower", "type": "STRING"}),
({"type": "boolean", "format": "bool"}, {"type_": "BOOLEAN"}), ({"type": "boolean", "format": "bool"}, {"format": "bool", "type": "BOOLEAN"}),
( (
{"type": "number", "format": "percent"}, {"type": "number", "format": "percent"},
{"type_": "NUMBER", "format_": "percent"}, {"type": "NUMBER", "format": "percent"},
), ),
( (
{ {
@ -482,25 +528,25 @@ async def test_escape_decode() -> None:
"required": [], "required": [],
}, },
{ {
"type_": "OBJECT", "type": "OBJECT",
"properties": {"var": {"type_": "STRING"}}, "properties": {"var": {"type": "STRING"}},
"required": [], "required": [],
}, },
), ),
( (
{"type": "object", "additionalProperties": True}, {"type": "object", "additionalProperties": True},
{ {
"type_": "OBJECT", "type": "OBJECT",
"properties": {"json": {"type_": "STRING"}}, "properties": {"json": {"type": "STRING"}},
"required": [], "required": [],
}, },
), ),
( (
{"type": "array", "items": {"type": "string"}}, {"type": "array", "items": {"type": "string"}},
{"type_": "ARRAY", "items": {"type_": "STRING"}}, {"type": "ARRAY", "items": {"type": "STRING"}},
), ),
], ],
) )
async def test_format_schema(openapi, protobuf) -> None: async def test_format_schema(openapi, genai_schema) -> None:
"""Test _format_schema.""" """Test _format_schema."""
assert _format_schema(openapi) == protobuf assert _format_schema(openapi) == genai_schema

View File

@ -1,16 +1,17 @@
"""Tests for the Google Generative AI Conversation integration.""" """Tests for the Google Generative AI Conversation integration."""
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, Mock, patch
from google.api_core.exceptions import ClientError, DeadlineExceeded
from google.rpc.error_details_pb2 import ErrorInfo # pylint: disable=no-name-in-module
import pytest import pytest
from requests.exceptions import Timeout
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -24,12 +25,14 @@ async def test_generate_content_service_without_images(
"party for the latest version of Home Assistant!" "party for the latest version of Home Assistant!"
) )
with patch("google.generativeai.GenerativeModel") as mock_model: with patch(
mock_response = MagicMock() "google.genai.models.AsyncModels.generate_content",
mock_response.text = stubbed_generated_content return_value=Mock(
mock_model.return_value.generate_content_async = AsyncMock( text=stubbed_generated_content,
return_value=mock_response prompt_feedback=None,
) candidates=[Mock()],
),
) as mock_generate:
response = await hass.services.async_call( response = await hass.services.async_call(
"google_generative_ai_conversation", "google_generative_ai_conversation",
"generate_content", "generate_content",
@ -41,7 +44,7 @@ async def test_generate_content_service_without_images(
assert response == { assert response == {
"text": stubbed_generated_content, "text": stubbed_generated_content,
} }
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
@pytest.mark.usefixtures("mock_init_component") @pytest.mark.usefixtures("mock_init_component")
@ -54,19 +57,21 @@ async def test_generate_content_service_with_image(
) )
with ( with (
patch("google.generativeai.GenerativeModel") as mock_model,
patch( patch(
"homeassistant.components.google_generative_ai_conversation.Path.read_bytes", "google.genai.models.AsyncModels.generate_content",
return_value=Mock(
text=stubbed_generated_content,
prompt_feedback=None,
candidates=[Mock()],
),
) as mock_generate,
patch(
"homeassistant.components.google_generative_ai_conversation.Image.open",
return_value=b"image bytes", return_value=b"image bytes",
), ),
patch("pathlib.Path.exists", return_value=True), patch("pathlib.Path.exists", return_value=True),
patch.object(hass.config, "is_allowed_path", return_value=True), patch.object(hass.config, "is_allowed_path", return_value=True),
): ):
mock_response = MagicMock()
mock_response.text = stubbed_generated_content
mock_model.return_value.generate_content_async = AsyncMock(
return_value=mock_response
)
response = await hass.services.async_call( response = await hass.services.async_call(
"google_generative_ai_conversation", "google_generative_ai_conversation",
"generate_content", "generate_content",
@ -81,7 +86,7 @@ async def test_generate_content_service_with_image(
assert response == { assert response == {
"text": stubbed_generated_content, "text": stubbed_generated_content,
} }
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
@pytest.mark.usefixtures("mock_init_component") @pytest.mark.usefixtures("mock_init_component")
@ -90,12 +95,15 @@ async def test_generate_content_service_error(
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
) -> None: ) -> None:
"""Test generate content service handles errors.""" """Test generate content service handles errors."""
with patch("google.generativeai.GenerativeModel") as mock_model: with (
mock_model.return_value.generate_content_async = AsyncMock( patch(
side_effect=ClientError("reason") "google.genai.models.AsyncModels.generate_content",
) side_effect=CLIENT_ERROR_500,
with pytest.raises( ),
HomeAssistantError, match="Error generating content: None reason" pytest.raises(
HomeAssistantError,
match="Error generating content: 500 internal-error. {'message': 'Internal Server Error', 'status': 'internal-error'}",
),
): ):
await hass.services.async_call( await hass.services.async_call(
"google_generative_ai_conversation", "google_generative_ai_conversation",
@ -113,14 +121,15 @@ async def test_generate_content_response_has_empty_parts(
) -> None: ) -> None:
"""Test generate content service handles response with empty parts.""" """Test generate content service handles response with empty parts."""
with ( with (
patch("google.generativeai.GenerativeModel") as mock_model, patch(
"google.genai.models.AsyncModels.generate_content",
return_value=Mock(
prompt_feedback=None,
candidates=[Mock(content=Mock(parts=[]))],
),
),
pytest.raises(HomeAssistantError, match="Unknown error generating content"),
): ):
mock_response = MagicMock()
mock_response.parts = []
mock_model.return_value.generate_content_async = AsyncMock(
return_value=mock_response
)
with pytest.raises(HomeAssistantError, match="Error generating content"):
await hass.services.async_call( await hass.services.async_call(
"google_generative_ai_conversation", "google_generative_ai_conversation",
"generate_content", "generate_content",
@ -211,19 +220,17 @@ async def test_generate_content_service_with_non_image(hass: HomeAssistant) -> N
("side_effect", "state", "reauth"), ("side_effect", "state", "reauth"),
[ [
( (
ClientError("some error"), CLIENT_ERROR_500,
ConfigEntryState.SETUP_ERROR, ConfigEntryState.SETUP_ERROR,
False, False,
), ),
( (
DeadlineExceeded("deadline exceeded"), Timeout,
ConfigEntryState.SETUP_RETRY, ConfigEntryState.SETUP_RETRY,
False, False,
), ),
( (
ClientError( CLIENT_ERROR_API_KEY_INVALID,
"invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID")
),
ConfigEntryState.SETUP_ERROR, ConfigEntryState.SETUP_ERROR,
True, True,
), ),
@ -235,10 +242,7 @@ async def test_config_entry_error(
"""Test different configuration entry errors.""" """Test different configuration entry errors."""
mock_client = AsyncMock() mock_client = AsyncMock()
mock_client.get_model.side_effect = side_effect mock_client.get_model.side_effect = side_effect
with patch( with patch("google.genai.models.AsyncModels.get", side_effect=side_effect):
"google.ai.generativelanguage_v1beta.ModelServiceAsyncClient",
return_value=mock_client,
):
assert not await hass.config_entries.async_setup(mock_config_entry.entry_id) assert not await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_config_entry.state == state assert mock_config_entry.state == state