Swap the Gemini SDK to the newly released Unified SDK (#138246)
* Swapped the old GenAI client with the newly released one
* Fixed the Generate Content Action, Config Flow loading and code cleanup
* Add a function to mask the issues with Tools which start with Hass
* Fix most tests
* google-genai==1.1.0
* fixes
* Fixed the remaining tests
* Addressed comments

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: tronikos <tronikos@users.noreply.github.com>
This commit is contained in:
parent
baa3b15dbc
commit
3160b7baa0
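In practice the swap replaces module-level configuration with a single client object. A minimal before/after sketch of the pattern this commit applies throughout the integration (the model name and key are placeholders, assuming google-genai 1.1.0):

    # Old SDK (google-generativeai 0.8.2): module-level auth, per-model objects.
    #   import google.generativeai as genai
    #   genai.configure(api_key=api_key)
    #   model = genai.GenerativeModel(model_name="models/gemini-2.0-flash")
    #   response = await model.generate_content_async(["Hello"])

    # New unified SDK (google-genai 1.1.0): one client carries the credentials,
    # and async methods live under the .aio namespace.
    from google import genai

    client = genai.Client(api_key="YOUR_API_KEY")  # placeholder key


    async def ask(prompt: str) -> str:
        response = await client.aio.models.generate_content(
            model="models/gemini-2.0-flash", contents=[prompt]
        )
        return response.text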
@@ -5,11 +5,10 @@ from __future__ import annotations
 import mimetypes
 from pathlib import Path
 
-from google.ai import generativelanguage_v1beta
-from google.api_core.client_options import ClientOptions
-from google.api_core.exceptions import ClientError, DeadlineExceeded, GoogleAPIError
-import google.generativeai as genai
-import google.generativeai.types as genai_types
+from google import genai  # type: ignore[attr-defined]
+from google.genai.errors import APIError, ClientError
 from PIL import Image
+from requests.exceptions import Timeout
 import voluptuous as vol
 
 from homeassistant.config_entries import ConfigEntry
@@ -29,7 +28,13 @@ from homeassistant.exceptions import (
 from homeassistant.helpers import config_validation as cv
 from homeassistant.helpers.typing import ConfigType
 
-from .const import CONF_CHAT_MODEL, CONF_PROMPT, DOMAIN, RECOMMENDED_CHAT_MODEL
+from .const import (
+    CONF_CHAT_MODEL,
+    CONF_PROMPT,
+    DOMAIN,
+    RECOMMENDED_CHAT_MODEL,
+    TIMEOUT_MILLIS,
+)
 
 SERVICE_GENERATE_CONTENT = "generate_content"
 CONF_IMAGE_FILENAME = "image_filename"
@@ -37,6 +42,8 @@ CONF_IMAGE_FILENAME = "image_filename"
 CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
 PLATFORMS = (Platform.CONVERSATION,)
 
+type GoogleGenerativeAIConfigEntry = ConfigEntry[genai.Client]
+
 
 async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
     """Set up Google Generative AI Conversation."""
@@ -44,42 +51,47 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
     async def generate_content(call: ServiceCall) -> ServiceResponse:
         """Generate content from text and optionally images."""
         prompt_parts = [call.data[CONF_PROMPT]]
-        image_filenames = call.data[CONF_IMAGE_FILENAME]
-        for image_filename in image_filenames:
-            if not hass.config.is_allowed_path(image_filename):
-                raise HomeAssistantError(
-                    f"Cannot read `{image_filename}`, no access to path; "
-                    "`allowlist_external_dirs` may need to be adjusted in "
-                    "`configuration.yaml`"
-                )
-            if not Path(image_filename).exists():
-                raise HomeAssistantError(f"`{image_filename}` does not exist")
-            mime_type, _ = mimetypes.guess_type(image_filename)
-            if mime_type is None or not mime_type.startswith("image"):
-                raise HomeAssistantError(f"`{image_filename}` is not an image")
-            prompt_parts.append(
-                {
-                    "mime_type": mime_type,
-                    "data": await hass.async_add_executor_job(
-                        Path(image_filename).read_bytes
-                    ),
-                }
-            )
-
-        model = genai.GenerativeModel(model_name=RECOMMENDED_CHAT_MODEL)
+
+        def append_images_to_prompt():
+            image_filenames = call.data[CONF_IMAGE_FILENAME]
+            for image_filename in image_filenames:
+                if not hass.config.is_allowed_path(image_filename):
+                    raise HomeAssistantError(
+                        f"Cannot read `{image_filename}`, no access to path; "
+                        "`allowlist_external_dirs` may need to be adjusted in "
+                        "`configuration.yaml`"
+                    )
+                if not Path(image_filename).exists():
+                    raise HomeAssistantError(f"`{image_filename}` does not exist")
+                mime_type, _ = mimetypes.guess_type(image_filename)
+                if mime_type is None or not mime_type.startswith("image"):
+                    raise HomeAssistantError(f"`{image_filename}` is not an image")
+                prompt_parts.append(Image.open(image_filename))
+
+        await hass.async_add_executor_job(append_images_to_prompt)
+
+        config_entry: GoogleGenerativeAIConfigEntry = hass.config_entries.async_entries(
+            DOMAIN
+        )[0]
+        client = config_entry.runtime_data
 
         try:
-            response = await model.generate_content_async(prompt_parts)
+            response = await client.aio.models.generate_content(
+                model=RECOMMENDED_CHAT_MODEL, contents=prompt_parts
+            )
         except (
-            GoogleAPIError,
+            APIError,
             ValueError,
-            genai_types.BlockedPromptException,
-            genai_types.StopCandidateException,
         ) as err:
             raise HomeAssistantError(f"Error generating content: {err}") from err
 
-        if not response.parts:
-            raise HomeAssistantError("Error generating content")
+        if response.prompt_feedback:
+            raise HomeAssistantError(
+                f"Error generating content due to content violations, reason: {response.prompt_feedback.block_reason_message}"
+            )
+
+        if not response.candidates[0].content.parts:
+            raise HomeAssistantError("Unknown error generating content")
 
         return {"text": response.text}
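The service now feeds PIL images straight into `contents` instead of hand-built mime-type dicts, and pushes the blocking `Image.open` into an executor. A condensed sketch of the new call path (paths and model name are illustrative):

    import asyncio
    from pathlib import Path

    from google import genai
    from PIL import Image


    async def describe_image(client: genai.Client, image_path: Path) -> str:
        # Blocking file I/O happens off the event loop, as the service above does.
        image = await asyncio.get_running_loop().run_in_executor(
            None, Image.open, image_path
        )
        response = await client.aio.models.generate_content(
            model="models/gemini-2.0-flash",
            contents=["Describe this image", image],
        )
        if response.prompt_feedback:  # the request was blocked by safety filters
            raise RuntimeError(response.prompt_feedback.block_reason_message)
        return response.text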
@@ -100,30 +112,34 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
     return True
 
 
-async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
+async def async_setup_entry(
+    hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
+) -> bool:
     """Set up Google Generative AI Conversation from a config entry."""
-    genai.configure(api_key=entry.data[CONF_API_KEY])
-
     try:
-        client = generativelanguage_v1beta.ModelServiceAsyncClient(
-            client_options=ClientOptions(api_key=entry.data[CONF_API_KEY])
-        )
-        await client.get_model(
-            name=entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), timeout=5.0
+        client = genai.Client(api_key=entry.data[CONF_API_KEY])
+        await client.aio.models.get(
+            model=entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
+            config={"http_options": {"timeout": TIMEOUT_MILLIS}},
         )
-    except (GoogleAPIError, ValueError) as err:
-        if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID":
-            raise ConfigEntryAuthFailed(err) from err
-        if isinstance(err, DeadlineExceeded):
+    except (APIError, Timeout) as err:
+        if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err):
+            raise ConfigEntryAuthFailed(err.message) from err
+        if isinstance(err, Timeout):
            raise ConfigEntryNotReady(err) from err
        raise ConfigEntryError(err) from err
     else:
         entry.runtime_data = client
 
     await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
 
     return True
 
 
-async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
+async def async_unload_entry(
+    hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
+) -> bool:
     """Unload GoogleGenerativeAI."""
     if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
         return False
@@ -3,15 +3,13 @@
 from __future__ import annotations
 
 from collections.abc import Mapping
-from functools import partial
 import logging
 from types import MappingProxyType
 from typing import Any
 
-from google.ai import generativelanguage_v1beta
-from google.api_core.client_options import ClientOptions
-from google.api_core.exceptions import ClientError, GoogleAPIError
-import google.generativeai as genai
+from google import genai  # type: ignore[attr-defined]
+from google.genai.errors import APIError, ClientError
+from requests.exceptions import Timeout
 import voluptuous as vol
 
 from homeassistant.config_entries import (
@@ -53,6 +51,7 @@ from .const import (
     RECOMMENDED_TEMPERATURE,
     RECOMMENDED_TOP_K,
     RECOMMENDED_TOP_P,
+    TIMEOUT_MILLIS,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -70,15 +69,20 @@ RECOMMENDED_OPTIONS = {
 }
 
 
-async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
+async def validate_input(data: dict[str, Any]) -> None:
     """Validate the user input allows us to connect.
 
     Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
     """
-    client = generativelanguage_v1beta.ModelServiceAsyncClient(
-        client_options=ClientOptions(api_key=data[CONF_API_KEY])
-    )
-    await client.list_models(timeout=5.0)
+    client = genai.Client(api_key=data[CONF_API_KEY])
+    await client.aio.models.list(
+        config={
+            "http_options": {
+                "timeout": TIMEOUT_MILLIS,
+            },
+            "query_base": True,
+        }
+    )
 
 
 class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
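Validation now exercises the same endpoint the options flow uses: an async model listing with a request-level timeout. A small sketch of that call shape (the timeout value mirrors TIMEOUT_MILLIS above and is expressed in milliseconds):

    from google import genai


    async def list_model_names(api_key: str) -> list[str | None]:
        client = genai.Client(api_key=api_key)
        # "query_base": True also returns the base models.
        pager = await client.aio.models.list(
            config={"http_options": {"timeout": 10000}, "query_base": True}
        )
        return [model.name async for model in pager]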
@@ -93,9 +97,9 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
         errors: dict[str, str] = {}
         if user_input is not None:
             try:
-                await validate_input(self.hass, user_input)
-            except GoogleAPIError as err:
-                if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID":
+                await validate_input(user_input)
+            except (APIError, Timeout) as err:
+                if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err):
                     errors["base"] = "invalid_auth"
                 else:
                     errors["base"] = "cannot_connect"
@@ -166,6 +170,7 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
         self.last_rendered_recommended = config_entry.options.get(
             CONF_RECOMMENDED, False
         )
+        self._genai_client = config_entry.runtime_data
 
     async def async_step_init(
         self, user_input: dict[str, Any] | None = None
@@ -188,7 +193,9 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
                 CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
             }
 
-        schema = await google_generative_ai_config_option_schema(self.hass, options)
+        schema = await google_generative_ai_config_option_schema(
+            self.hass, options, self._genai_client
+        )
         return self.async_show_form(
             step_id="init",
             data_schema=vol.Schema(schema),
@@ -198,6 +205,7 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
 async def google_generative_ai_config_option_schema(
     hass: HomeAssistant,
     options: dict[str, Any] | MappingProxyType[str, Any],
+    genai_client: genai.Client,
 ) -> dict:
     """Return a schema for Google Generative AI completion options."""
     hass_apis: list[SelectOptionDict] = [
@@ -236,18 +244,21 @@ async def google_generative_ai_config_option_schema(
     if options.get(CONF_RECOMMENDED):
         return schema
 
-    api_models = await hass.async_add_executor_job(partial(genai.list_models))
-
+    api_models_pager = await genai_client.aio.models.list(config={"query_base": True})
+    api_models = [api_model async for api_model in api_models_pager]
     models = [
         SelectOptionDict(
             label=api_model.display_name,
             value=api_model.name,
         )
-        for api_model in sorted(api_models, key=lambda x: x.display_name)
+        for api_model in sorted(api_models, key=lambda x: x.display_name or "")
         if (
             api_model.name != "models/gemini-1.0-pro"  # duplicate of gemini-pro
+            and api_model.display_name
+            and api_model.name
+            and api_model.supported_actions
             and "vision" not in api_model.name
-            and "generateContent" in api_model.supported_generation_methods
+            and "generateContent" in api_model.supported_actions
         )
     ]
@@ -22,3 +22,5 @@ CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
 CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
 CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
 RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"
+
+TIMEOUT_MILLIS = 10000
@@ -6,11 +6,18 @@ import codecs
 from collections.abc import Callable
 from typing import Any, Literal, cast
 
-from google.api_core.exceptions import GoogleAPIError
-import google.generativeai as genai
-from google.generativeai import protos
-import google.generativeai.types as genai_types
-from google.protobuf.json_format import MessageToDict
+from google.genai.errors import APIError
+from google.genai.types import (
+    AutomaticFunctionCallingConfig,
+    Content,
+    FunctionDeclaration,
+    GenerateContentConfig,
+    HarmCategory,
+    Part,
+    SafetySetting,
+    Schema,
+    Tool,
+)
 from voluptuous_openapi import convert
 
 from homeassistant.components import assist_pipeline, conversation
@@ -57,21 +64,40 @@ async def async_setup_entry(
 
 
 SUPPORTED_SCHEMA_KEYS = {
-    "type",
-    "format",
-    "description",
+    "min_items",
+    "example",
+    "property_ordering",
+    "pattern",
+    "minimum",
+    "default",
+    "any_of",
+    "max_length",
+    "title",
+    "min_properties",
+    "min_length",
+    "max_items",
+    "maximum",
+    "nullable",
+    "max_properties",
+    "type",
+    "description",
     "enum",
+    "format",
     "items",
     "properties",
     "required",
 }
 
 
-def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
-    """Format the schema to protobuf."""
-    if (subschemas := schema.get("anyOf")) or (subschemas := schema.get("allOf")):
-        for subschema in subschemas:  # Gemini API does not support anyOf and allOf keys
+def _camel_to_snake(name: str) -> str:
+    """Convert camel case to snake case."""
+    return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_")
+
+
+def _format_schema(schema: dict[str, Any]) -> Schema:
+    """Format the schema to be compatible with Gemini API."""
+    if subschemas := schema.get("allOf"):
+        for subschema in subschemas:  # Gemini API does not support allOf keys
             if "type" in subschema:  # Fallback to first subschema with 'type' field
                 return _format_schema(subschema)
         return _format_schema(
@@ -80,42 +106,38 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
 
     result = {}
     for key, val in schema.items():
+        key = _camel_to_snake(key)
         if key not in SUPPORTED_SCHEMA_KEYS:
             continue
+        if key == "any_of":
+            val = [_format_schema(subschema) for subschema in val]
         if key == "type":
-            key = "type_"
             val = val.upper()
-        elif key == "format":
-            if schema.get("type") == "string" and val != "enum":
-                continue
-            if schema.get("type") not in ("number", "integer", "string"):
-                continue
-            key = "format_"
-        elif key == "items":
+        if key == "items":
             val = _format_schema(val)
         elif key == "properties":
             val = {k: _format_schema(v) for k, v in val.items()}
         result[key] = val
 
-    if result.get("enum") and result.get("type_") != "STRING":
+    if result.get("enum") and result.get("type") != "STRING":
         # enum is only allowed for STRING type. This is safe as long as the schema
         # contains vol.Coerce for the respective type, for example:
         # vol.All(vol.Coerce(int), vol.In([1, 2, 3]))
-        result["type_"] = "STRING"
+        result["type"] = "STRING"
         result["enum"] = [str(item) for item in result["enum"]]
 
-    if result.get("type_") == "OBJECT" and not result.get("properties"):
+    if result.get("type") == "OBJECT" and not result.get("properties"):
         # An object with undefined properties is not supported by Gemini API.
         # Fallback to JSON string. This will probably fail for most tools that want it,
         # but we don't have a better fallback strategy so far.
-        result["properties"] = {"json": {"type_": "STRING"}}
+        result["properties"] = {"json": {"type": "STRING"}}
         result["required"] = []
-    return result
+    return cast(Schema, result)
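The rewritten converter keeps plain dict keys (snake-cased via `_camel_to_snake`) instead of the protobuf-style `type_`/`format_` names the old SDK required. An illustrative input/output pair under that scheme (field names as voluptuous-openapi would emit them):

    # What voluptuous-openapi might emit for a tool parameter:
    openapi_schema = {
        "type": "object",
        "properties": {
            "brightness": {"type": "integer", "minimum": 0},
            "mode": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
        },
        "required": ["brightness"],
    }

    # _format_schema(openapi_schema) would produce, roughly:
    # {
    #     "type": "OBJECT",
    #     "properties": {
    #         "brightness": {"type": "INTEGER", "minimum": 0},
    #         "mode": {"any_of": [{"type": "STRING"}, {"type": "INTEGER"}]},
    #     },
    #     "required": ["brightness"],
    # }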
 def _format_tool(
     tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
-) -> dict[str, Any]:
+) -> Tool:
     """Format tool specification."""
 
     if tool.parameters.schema:
@@ -125,16 +147,14 @@ def _format_tool(
     else:
         parameters = None
 
-    return protos.Tool(
-        {
-            "function_declarations": [
-                {
-                    "name": tool.name,
-                    "description": tool.description,
-                    "parameters": parameters,
-                }
-            ]
-        }
+    return Tool(
+        function_declarations=[
+            FunctionDeclaration(
+                name=tool.name,
+                description=tool.description,
+                parameters=parameters,
+            )
+        ]
     )
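Spelled out with concrete values, the Tool built above looks like this (tool name and parameters are illustrative; `Schema` accepts the enum values as plain strings):

    from google.genai.types import FunctionDeclaration, Schema, Tool

    light_tool = Tool(
        function_declarations=[
            FunctionDeclaration(
                name="turn_on_light",
                description="Turn on a light by name",
                parameters=Schema(
                    type="OBJECT",
                    properties={"name": Schema(type="STRING")},
                    required=["name"],
                ),
            )
        ]
    )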
@@ -151,14 +171,12 @@ def _escape_decode(value: Any) -> Any:
 
 def _create_google_tool_response_content(
     content: list[conversation.ToolResultContent],
-) -> protos.Content:
+) -> Content:
     """Create a Google tool response content."""
-    return protos.Content(
+    return Content(
         parts=[
-            protos.Part(
-                function_response=protos.FunctionResponse(
-                    name=tool_result.tool_name, response=tool_result.tool_result
-                )
+            Part.from_function_response(
+                name=tool_result.tool_name, response=tool_result.tool_result
             )
             for tool_result in content
         ]
@@ -169,33 +187,36 @@ def _convert_content(
     content: conversation.UserContent
     | conversation.AssistantContent
     | conversation.SystemContent,
-) -> genai_types.ContentDict:
+) -> Content:
     """Convert HA content to Google content."""
     if content.role != "assistant" or not content.tool_calls:  # type: ignore[union-attr]
         role = "model" if content.role == "assistant" else content.role
-        return {"role": role, "parts": content.content}
+        return Content(
+            role=role,
+            parts=[
+                Part.from_text(text=content.content if content.content else ""),
+            ],
+        )
 
     # Handle the Assistant content with tool calls.
     assert type(content) is conversation.AssistantContent
-    parts = []
+    parts: list[Part] = []
 
     if content.content:
-        parts.append(protos.Part(text=content.content))
+        parts.append(Part.from_text(text=content.content))
 
     if content.tool_calls:
         parts.extend(
             [
-                protos.Part(
-                    function_call=protos.FunctionCall(
-                        name=tool_call.tool_name,
-                        args=_escape_decode(tool_call.tool_args),
-                    )
+                Part.from_function_call(
+                    name=tool_call.tool_name,
+                    args=_escape_decode(tool_call.tool_args),
                 )
                 for tool_call in content.tool_calls
             ]
         )
 
-    return protos.Content({"role": "model", "parts": parts})
+    return Content(role="model", parts=parts)
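For reference, an assistant turn that carries both text and a pending tool call converts to a single typed `Content` (names and args are illustrative):

    from google.genai.types import Content, Part

    assistant_turn = Content(
        role="model",
        parts=[
            Part.from_text(text="Turning on the kitchen light."),
            Part.from_function_call(name="turn_on_light", args={"name": "kitchen"}),
        ],
    )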
 class GoogleGenerativeAIConversationEntity(
@@ -209,6 +230,7 @@ class GoogleGenerativeAIConversationEntity(
     def __init__(self, entry: ConfigEntry) -> None:
         """Initialize the agent."""
         self.entry = entry
+        self._genai_client = entry.runtime_data
         self._attr_unique_id = entry.entry_id
         self._attr_device_info = dr.DeviceInfo(
             identifiers={(DOMAIN, entry.entry_id)},
@@ -273,7 +295,7 @@ class GoogleGenerativeAIConversationEntity(
         except conversation.ConverseError as err:
             return err.as_conversation_result()
 
-        tools: list[dict[str, Any]] | None = None
+        tools: list[Tool | Callable[..., Any]] | None = None
         if chat_log.llm_api:
             tools = [
                 _format_tool(tool, chat_log.llm_api.custom_serializer)
@@ -288,13 +310,22 @@ class GoogleGenerativeAIConversationEntity(
             "gemini-1.0" not in model_name and "gemini-pro" not in model_name
         )
 
-        prompt = chat_log.content[0].content  # type: ignore[union-attr]
-        messages: list[genai_types.ContentDict] = []
+        prompt_content = cast(
+            conversation.SystemContent,
+            chat_log.content[0],
+        )
+
+        if prompt_content.content:
+            prompt = prompt_content.content
+        else:
+            raise HomeAssistantError("Invalid prompt content")
+
+        messages: list[Content] = []
 
         # Google groups tool results, we do not. Group them before sending.
         tool_results: list[conversation.ToolResultContent] = []
 
-        for chat_content in chat_log.content[1:]:
+        for chat_content in chat_log.content[1:-1]:
             if chat_content.role == "tool_result":
                 # mypy doesn't like picking a type based on checking shared property 'role'
                 tool_results.append(cast(conversation.ToolResultContent, chat_content))
@@ -317,85 +348,93 @@ class GoogleGenerativeAIConversationEntity(
 
         if tool_results:
             messages.append(_create_google_tool_response_content(tool_results))
 
-        model = genai.GenerativeModel(
-            model_name=model_name,
-            generation_config={
-                "temperature": self.entry.options.get(
-                    CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
-                ),
-                "top_p": self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
-                "top_k": self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
-                "max_output_tokens": self.entry.options.get(
-                    CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
-                ),
-            },
-            safety_settings={
-                "HARASSMENT": self.entry.options.get(
-                    CONF_HARASSMENT_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
-                ),
-                "HATE": self.entry.options.get(
-                    CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
-                ),
-                "SEXUAL": self.entry.options.get(
-                    CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
-                ),
-                "DANGEROUS": self.entry.options.get(
-                    CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
-                ),
-            },
+        generateContentConfig = GenerateContentConfig(
+            temperature=self.entry.options.get(
+                CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
+            ),
+            top_k=self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
+            top_p=self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+            max_output_tokens=self.entry.options.get(
+                CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
+            ),
+            safety_settings=[
+                SafetySetting(
+                    category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                    threshold=self.entry.options.get(
+                        CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
+                    ),
+                ),
+                SafetySetting(
+                    category=HarmCategory.HARM_CATEGORY_HARASSMENT,
+                    threshold=self.entry.options.get(
+                        CONF_HARASSMENT_BLOCK_THRESHOLD,
+                        RECOMMENDED_HARM_BLOCK_THRESHOLD,
+                    ),
+                ),
+                SafetySetting(
+                    category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                    threshold=self.entry.options.get(
+                        CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
+                    ),
+                ),
+                SafetySetting(
+                    category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                    threshold=self.entry.options.get(
+                        CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
+                    ),
+                ),
+            ],
             tools=tools or None,
             system_instruction=prompt if supports_system_instruction else None,
+            automatic_function_calling=AutomaticFunctionCallingConfig(
+                disable=True, maximum_remote_calls=None
+            ),
         )
 
         if not supports_system_instruction:
             messages = [
-                {"role": "user", "parts": prompt},
-                {"role": "model", "parts": "Ok"},
+                Content(role="user", parts=[Part.from_text(text=prompt)]),
+                Content(role="model", parts=[Part.from_text(text="Ok")]),
                 *messages,
             ]
 
-        chat = model.start_chat(history=messages)
-        chat_request = user_input.text
+        chat = self._genai_client.aio.chats.create(
+            model=model_name, history=messages, config=generateContentConfig
+        )
+        chat_request: str | Content = user_input.text
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
             try:
-                chat_response = await chat.send_message_async(chat_request)
-            except (
-                GoogleAPIError,
-                ValueError,
-                genai_types.BlockedPromptException,
-                genai_types.StopCandidateException,
-            ) as err:
-                LOGGER.error("Error sending message: %s %s", type(err), err)
-
-                if isinstance(
-                    err, genai_types.StopCandidateException
-                ) and "finish_reason: SAFETY\n" in str(err):
-                    error = "The message got blocked by your safety settings"
-                else:
-                    error = (
-                        f"Sorry, I had a problem talking to Google Generative AI: {err}"
-                    )
+                chat_response = await chat.send_message(message=chat_request)
+
+                if chat_response.prompt_feedback:
+                    raise HomeAssistantError(
+                        f"The message got blocked due to content violations, reason: {chat_response.prompt_feedback.block_reason_message}"
+                    )
+
+            except (
+                APIError,
+                ValueError,
+            ) as err:
+                LOGGER.error("Error sending message: %s %s", type(err), err)
+                error = f"Sorry, I had a problem talking to Google Generative AI: {err}"
                 raise HomeAssistantError(error) from err
 
-            LOGGER.debug("Response: %s", chat_response.parts)
-            if not chat_response.parts:
+            response_parts = chat_response.candidates[0].content.parts
+            if not response_parts:
                 raise HomeAssistantError(
                     "Sorry, I had a problem getting a response from Google Generative AI."
                 )
             content = " ".join(
-                [part.text.strip() for part in chat_response.parts if part.text]
+                [part.text.strip() for part in response_parts if part.text]
             )
 
             tool_calls = []
-            for part in chat_response.parts:
+            for part in response_parts:
                 if not part.function_call:
                     continue
-                tool_call = MessageToDict(part.function_call._pb)  # noqa: SLF001
-                tool_name = tool_call["name"]
-                tool_args = _escape_decode(tool_call["args"])
+                tool_call = part.function_call
+                tool_name = tool_call.name
+                tool_args = _escape_decode(tool_call.args)
                 tool_calls.append(
                     llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
                 )
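The send/receive loop above boils down to the following shape: one chat object owns history and config, `send_message` replaces `send_message_async`, and tool calls come back as typed `Part` objects rather than protobuf messages. A condensed sketch (model and config values are illustrative):

    from google import genai
    from google.genai.types import Content, GenerateContentConfig


    async def one_turn(client: genai.Client, history: list[Content], text: str) -> str:
        chat = client.aio.chats.create(
            model="models/gemini-2.0-flash",
            history=history,
            config=GenerateContentConfig(temperature=1.0, max_output_tokens=150),
        )
        response = await chat.send_message(message=text)
        parts = response.candidates[0].content.parts
        # Typed access, no MessageToDict: part.function_call.name / .args
        for part in parts:
            if part.function_call:
                print(part.function_call.name, dict(part.function_call.args))
        return " ".join(part.text.strip() for part in parts if part.text)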
@@ -418,7 +457,7 @@ class GoogleGenerativeAIConversationEntity(
 
         response = intent.IntentResponse(language=user_input.language)
         response.async_set_speech(
-            " ".join([part.text.strip() for part in chat_response.parts if part.text])
+            " ".join([part.text.strip() for part in response_parts if part.text])
         )
         return conversation.ConversationResult(
             response=response, conversation_id=chat_log.conversation_id
@@ -8,5 +8,5 @@
   "documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation",
   "integration_type": "service",
   "iot_class": "cloud_polling",
-  "requirements": ["google-generativeai==0.8.2"]
+  "requirements": ["google-genai==1.1.0"]
 }
requirements_all.txt (generated)
@@ -1033,7 +1033,7 @@ google-cloud-speech==2.27.0
 google-cloud-texttospeech==2.17.2
 
 # homeassistant.components.google_generative_ai_conversation
-google-generativeai==0.8.2
+google-genai==1.1.0
 
 # homeassistant.components.nest
 google-nest-sdm==7.1.3

requirements_test_all.txt (generated)
@@ -883,7 +883,7 @@ google-cloud-speech==2.27.0
 google-cloud-texttospeech==2.17.2
 
 # homeassistant.components.google_generative_ai_conversation
-google-generativeai==0.8.2
+google-genai==1.1.0
 
 # homeassistant.components.nest
 google-nest-sdm==7.1.3
@@ -1 +1,31 @@
 """Tests for the Google Generative AI Conversation integration."""
+
+from unittest.mock import Mock
+
+from google.genai.errors import ClientError
+import requests
+
+CLIENT_ERROR_500 = ClientError(
+    500,
+    Mock(
+        __class__=requests.Response,
+        json=Mock(
+            return_value={
+                "message": "Internal Server Error",
+                "status": "internal-error",
+            }
+        ),
+    ),
+)
+CLIENT_ERROR_API_KEY_INVALID = ClientError(
+    400,
+    Mock(
+        __class__=requests.Response,
+        json=Mock(
+            return_value={
+                "message": "'reason': API_KEY_INVALID",
+                "status": "unauthorized",
+            }
+        ),
+    ),
+)
@@ -1,7 +1,6 @@
 """Tests helpers."""
 
 from collections.abc import Generator
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 
@@ -15,14 +14,7 @@ from tests.common import MockConfigEntry
 
 
 @pytest.fixture
-def mock_genai() -> Generator[None]:
-    """Mock the genai call in async_setup_entry."""
-    with patch("google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.get_model"):
-        yield
-
-
-@pytest.fixture
-def mock_config_entry(hass: HomeAssistant, mock_genai: None) -> MockConfigEntry:
+def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
     """Mock a config entry."""
     entry = MockConfigEntry(
         domain="google_generative_ai_conversation",
@@ -31,18 +23,21 @@ def mock_config_entry(hass: HomeAssistant, mock_genai: None) -> MockConfigEntry:
             "api_key": "bla",
         },
     )
+    entry.runtime_data = Mock()
     entry.add_to_hass(hass)
     return entry
 
 
 @pytest.fixture
-def mock_config_entry_with_assist(
+async def mock_config_entry_with_assist(
     hass: HomeAssistant, mock_config_entry: MockConfigEntry
 ) -> MockConfigEntry:
     """Mock a config entry with assist."""
-    hass.config_entries.async_update_entry(
-        mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
-    )
+    with patch("google.genai.models.AsyncModels.get"):
+        hass.config_entries.async_update_entry(
+            mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
+        )
+        await hass.async_block_till_done()
     return mock_config_entry
 
 
@@ -51,8 +46,11 @@ async def mock_init_component(
     hass: HomeAssistant, mock_config_entry: ConfigEntry
 ) -> None:
     """Initialize integration."""
-    assert await async_setup_component(hass, "google_generative_ai_conversation", {})
-    await hass.async_block_till_done()
+    with patch("google.genai.models.AsyncModels.get"):
+        assert await async_setup_component(
+            hass, "google_generative_ai_conversation", {}
+        )
+        await hass.async_block_till_done()
 
 
 @pytest.fixture(autouse=True)
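With no module-level hooks left to stub, the fixtures patch the SDK's async surface directly; any code path that validates the entry hits the mock instead of the network. The pattern in isolation (target strings as used above):

    from unittest.mock import patch

    # Patch the coroutine the integration awaits during setup/validation.
    with patch("google.genai.models.AsyncModels.get"):
        ...  # set up the config entry; client.aio.models.get() returns a Mock

    with patch("google.genai.models.AsyncModels.list"):
        ...  # config-flow validation; client.aio.models.list() is mocked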
@@ -6,106 +6,26 @@
    tuple(
    ),
    dict({
      'generation_config': dict({
        'max_output_tokens': 150,
        'temperature': 1.0,
        'top_k': 64,
        'top_p': 0.95,
      }),
      'model_name': 'models/gemini-2.0-flash',
      'safety_settings': dict({
        'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
        'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
        'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
        'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
      }),
      'system_instruction': '''
        Current time is 05:00:00. Today's date is 2024-05-24.
        You are a voice assistant for Home Assistant.
        Answer questions about the world truthfully.
        Answer in plain text. Keep it simple and to the point.
        Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
      ''',
      'tools': list([
        function_declarations {
          name: "test_tool"
          description: "Test function"
          parameters {
            type_: OBJECT
            properties {
              key: "param3"
              value {
                type_: OBJECT
                properties {
                  key: "json"
                  value {
                    type_: STRING
                  }
                }
              }
            }
            properties {
              key: "param2"
              value {
                type_: NUMBER
              }
            }
            properties {
              key: "param1"
              value {
                type_: ARRAY
                description: "Test parameters"
                items {
                  type_: STRING
                }
              }
            }
          }
        }
        ,
      ]),
    }),
  ),
  tuple(
    '().start_chat',
    tuple(
    ),
    dict({
'config': GenerateContentConfig(http_options=None, system_instruction="Current time is 05:00:00. Today's date is 2024-05-24.\nYou are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=150, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'param1': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.ARRAY: 'ARRAY'>, description='Test parameters', enum=None, format=None, items=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format='lower', items=None, properties=None, required=None), properties=None, required=None), 'param2': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=[Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.NUMBER: 'NUMBER'>, description=None, enum=None, format=None, items=None, properties=None, required=None), Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.INTEGER: 'INTEGER'>, description=None, enum=None, format=None, items=None, properties=None, required=None)], 
max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=None, description=None, enum=None, format=None, items=None, properties=None, required=None), 'param3': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'json': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format=None, items=None, properties=None, required=None)}, required=[])}, required=[]))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
      'history': list([
        dict({
          'parts': 'Please call the test function',
          'role': 'user',
        }),
      ]),
      'model': 'models/gemini-2.0-flash',
    }),
  ),
  tuple(
    '().start_chat().send_message_async',
    '().send_message',
    tuple(
      'Please call the test function',
    ),
    dict({
      'message': 'Please call the test function',
    }),
  ),
  tuple(
    '().start_chat().send_message_async',
    '().send_message',
    tuple(
      parts {
        function_response {
          name: "test_tool"
          response {
            fields {
              key: "result"
              value {
                string_value: "Test response"
              }
            }
          }
        }
      }
      ,
    ),
    dict({
      'message': Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None)], role=None),
    }),
  ),
])
@@ -117,75 +37,26 @@
  tuple(
    ),
    dict({
      'generation_config': dict({
        'max_output_tokens': 150,
        'temperature': 1.0,
        'top_k': 64,
        'top_p': 0.95,
      }),
      'model_name': 'models/gemini-2.0-flash',
      'safety_settings': dict({
        'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
        'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
        'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
        'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
      }),
      'system_instruction': '''
        Current time is 05:00:00. Today's date is 2024-05-24.
        You are a voice assistant for Home Assistant.
        Answer questions about the world truthfully.
        Answer in plain text. Keep it simple and to the point.
        Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
      ''',
      'tools': list([
        function_declarations {
          name: "test_tool"
          description: "Test function"
        }
        ,
      ]),
    }),
  ),
  tuple(
    '().start_chat',
    tuple(
    ),
    dict({
'config': GenerateContentConfig(http_options=None, system_instruction="Current time is 05:00:00. Today's date is 2024-05-24.\nYou are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=150, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=None)], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
      'history': list([
        dict({
          'parts': 'Please call the test function',
          'role': 'user',
        }),
      ]),
      'model': 'models/gemini-2.0-flash',
    }),
  ),
  tuple(
    '().start_chat().send_message_async',
    '().send_message',
    tuple(
      'Please call the test function',
    ),
    dict({
      'message': 'Please call the test function',
    }),
  ),
  tuple(
    '().start_chat().send_message_async',
    '().send_message',
    tuple(
      parts {
        function_response {
          name: "test_tool"
          response {
            fields {
              key: "result"
              value {
                string_value: "Test response"
              }
            }
          }
        }
      }
      ,
    ),
    dict({
      'message': Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None)], role=None),
    }),
  ),
])
@@ -6,21 +6,11 @@
    tuple(
    ),
    dict({
      'model_name': 'models/gemini-2.0-flash',
    }),
  ),
  tuple(
    '().generate_content_async',
    tuple(
      list([
    'contents': list([
        'Describe this image from my doorbell camera',
        dict({
          'data': b'image bytes',
          'mime_type': 'image/jpeg',
        }),
        b'image bytes',
      ]),
    ),
    dict({
      'model': 'models/gemini-2.0-flash',
    }),
  ),
])
@@ -32,17 +22,10 @@
    tuple(
    ),
    dict({
      'model_name': 'models/gemini-2.0-flash',
    }),
  ),
  tuple(
    '().generate_content_async',
    tuple(
      list([
    'contents': list([
        'Write an opening speech for a Home Assistant release party',
      ]),
    ),
    dict({
      'model': 'models/gemini-2.0-flash',
    }),
  ),
])
@@ -1,10 +1,9 @@
 """Test the Google Generative AI Conversation config flow."""
 
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import Mock, patch
 
-from google.api_core.exceptions import ClientError, DeadlineExceeded
-from google.rpc.error_details_pb2 import ErrorInfo  # pylint: disable=no-name-in-module
 import pytest
+from requests.exceptions import Timeout
 
 from homeassistant import config_entries
 from homeassistant.components.google_generative_ai_conversation.config_flow import (
@@ -33,6 +32,8 @@ from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
 from homeassistant.data_entry_flow import FlowResultType
 
+from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
+
 from tests.common import MockConfigEntry
@@ -41,30 +42,37 @@ def mock_models():
     """Mock the model list API."""
     model_20_flash = Mock(
         display_name="Gemini 2.0 Flash",
-        supported_generation_methods=["generateContent"],
+        supported_actions=["generateContent"],
     )
     model_20_flash.name = "models/gemini-2.0-flash"
 
     model_15_flash = Mock(
         display_name="Gemini 1.5 Flash",
-        supported_generation_methods=["generateContent"],
+        supported_actions=["generateContent"],
     )
     model_15_flash.name = "models/gemini-1.5-flash-latest"
 
     model_15_pro = Mock(
         display_name="Gemini 1.5 Pro",
-        supported_generation_methods=["generateContent"],
+        supported_actions=["generateContent"],
     )
     model_15_pro.name = "models/gemini-1.5-pro-latest"
 
     model_10_pro = Mock(
         display_name="Gemini 1.0 Pro",
-        supported_generation_methods=["generateContent"],
+        supported_actions=["generateContent"],
     )
     model_10_pro.name = "models/gemini-pro"
 
+    async def models_pager():
+        yield model_20_flash
+        yield model_15_flash
+        yield model_15_pro
+        yield model_10_pro
+
     with patch(
-        "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
-        return_value=iter([model_20_flash, model_15_flash, model_15_pro, model_10_pro]),
+        "google.genai.models.AsyncModels.list",
+        return_value=models_pager(),
     ):
         yield
@@ -86,7 +94,7 @@ async def test_form(hass: HomeAssistant) -> None:
 
     with (
         patch(
-            "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.list_models",
+            "google.genai.models.AsyncModels.list",
         ),
         patch(
             "homeassistant.components.google_generative_ai_conversation.async_setup_entry",
@@ -170,7 +178,11 @@ async def test_options_switching(
     expected_options,
 ) -> None:
     """Test the options form."""
-    hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
+    with patch("google.genai.models.AsyncModels.get"):
+        hass.config_entries.async_update_entry(
+            mock_config_entry, options=current_options
+        )
+        await hass.async_block_till_done()
     options_flow = await hass.config_entries.options.async_init(
         mock_config_entry.entry_id
     )
@@ -195,17 +207,15 @@
     ("side_effect", "error"),
     [
         (
-            ClientError("some error"),
+            CLIENT_ERROR_500,
             "cannot_connect",
         ),
         (
-            DeadlineExceeded("deadline exceeded"),
+            Timeout("deadline exceeded"),
             "cannot_connect",
         ),
         (
-            ClientError(
-                "invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID")
-            ),
+            CLIENT_ERROR_API_KEY_INVALID,
             "invalid_auth",
         ),
         (Exception, "unknown"),
@@ -217,12 +227,7 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
         DOMAIN, context={"source": config_entries.SOURCE_USER}
     )
 
-    mock_client = AsyncMock()
-    mock_client.list_models.side_effect = side_effect
-    with patch(
-        "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient",
-        return_value=mock_client,
-    ):
+    with patch("google.genai.models.AsyncModels.list", side_effect=side_effect):
         result2 = await hass.config_entries.flow.async_configure(
             result["flow_id"],
             {
@@ -259,7 +264,7 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
 
     with (
         patch(
-            "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.list_models",
+            "google.genai.models.AsyncModels.list",
         ),
         patch(
             "homeassistant.components.google_generative_ai_conversation.async_setup_entry",
@@ -1,12 +1,10 @@
 """Tests for the Google Generative AI Conversation integration conversation platform."""
 
 from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 from freezegun import freeze_time
-from google.ai.generativelanguage_v1beta.types.content import FunctionCall
-from google.api_core.exceptions import GoogleAPIError
-import google.generativeai.types as genai_types
+from google.genai.types import FunctionCall
 import pytest
 from syrupy.assertion import SnapshotAssertion
 import voluptuous as vol
@@ -22,6 +20,8 @@ from homeassistant.core import Context, HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import intent, llm
 
+from . import CLIENT_ERROR_500
+
 from tests.common import MockConfigEntry
@@ -51,7 +51,7 @@ async def test_function_call(
     snapshot: SnapshotAssertion,
 ) -> None:
     """Test function calling."""
-    agent_id = mock_config_entry_with_assist.entry_id
+    agent_id = "conversation.google_generative_ai_conversation"
     context = Context()
 
     mock_tool = AsyncMock()
@@ -69,12 +69,12 @@ async def test_function_call(
 
     mock_get_tools.return_value = [mock_tool]
 
-    with patch("google.generativeai.GenerativeModel") as mock_model:
+    with patch("google.genai.chats.AsyncChats.create") as mock_create:
         mock_chat = AsyncMock()
-        mock_model.return_value.start_chat.return_value = mock_chat
-        chat_response = MagicMock()
-        mock_chat.send_message_async.return_value = chat_response
-        mock_part = MagicMock()
+        mock_create.return_value.send_message = mock_chat
+        chat_response = Mock(prompt_feedback=None)
+        mock_chat.return_value = chat_response
+        mock_part = Mock()
         mock_part.text = ""
         mock_part.function_call = FunctionCall(
             name="test_tool",
@@ -92,7 +92,7 @@ async def test_function_call(
             return {"result": "Test response"}
 
         mock_tool.async_call.side_effect = tool_call
-        chat_response.parts = [mock_part]
+        chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
         result = await conversation.async_converse(
             hass,
             "Please call the test function",
@@ -104,20 +104,28 @@ async def test_function_call(
 
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
-    mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
-    mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
-    assert mock_tool_call == {
+    mock_tool_call = mock_create.mock_calls[2][2]["message"]
+    assert mock_tool_call.model_dump() == {
         "parts": [
             {
+                "code_execution_result": None,
+                "executable_code": None,
+                "file_data": None,
+                "function_call": None,
                 "function_response": {
+                    "id": None,
                     "name": "test_tool",
                     "response": {
                         "result": "Test response",
                     },
                 },
+                "inline_data": None,
+                "text": None,
+                "thought": None,
+                "video_metadata": None,
             },
         ],
-        "role": "",
+        "role": None,
     }
 
     mock_tool.async_call.assert_awaited_once_with(
@@ -139,7 +147,7 @@ async def test_function_call(
             device_id="test_device",
         ),
     )
-    assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
+    assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
 
     # Test conversation tracing
     traces = trace.async_get_traces()
@@ -170,7 +178,7 @@ async def test_function_call_without_parameters(
     snapshot: SnapshotAssertion,
 ) -> None:
     """Test function calling without parameters."""
-    agent_id = mock_config_entry_with_assist.entry_id
+    agent_id = "conversation.google_generative_ai_conversation"
     context = Context()
 
     mock_tool = AsyncMock()
@@ -180,12 +188,12 @@ async def test_function_call_without_parameters(
 
     mock_get_tools.return_value = [mock_tool]
 
-    with patch("google.generativeai.GenerativeModel") as mock_model:
+    with patch("google.genai.chats.AsyncChats.create") as mock_create:
         mock_chat = AsyncMock()
-        mock_model.return_value.start_chat.return_value = mock_chat
-        chat_response = MagicMock()
-        mock_chat.send_message_async.return_value = chat_response
-        mock_part = MagicMock()
+        mock_create.return_value.send_message = mock_chat
+        chat_response = Mock(prompt_feedback=None)
+        mock_chat.return_value = chat_response
+        mock_part = Mock()
         mock_part.text = ""
         mock_part.function_call = FunctionCall(name="test_tool", args={})
 
@@ -197,7 +205,7 @@ async def test_function_call_without_parameters(
             return {"result": "Test response"}
 
         mock_tool.async_call.side_effect = tool_call
-        chat_response.parts = [mock_part]
+        chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
         result = await conversation.async_converse(
             hass,
             "Please call the test function",
@@ -209,20 +217,28 @@ async def test_function_call_without_parameters(
 
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
-    mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
-    mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
-    assert mock_tool_call == {
+    mock_tool_call = mock_create.mock_calls[2][2]["message"]
+    assert mock_tool_call.model_dump() == {
         "parts": [
             {
+                "code_execution_result": None,
+                "executable_code": None,
+                "file_data": None,
+                "function_call": None,
                 "function_response": {
+                    "id": None,
                     "name": "test_tool",
                     "response": {
                         "result": "Test response",
                     },
                 },
+                "inline_data": None,
+                "text": None,
+                "thought": None,
+                "video_metadata": None,
             },
         ],
-        "role": "",
+        "role": None,
    }
 
     mock_tool.async_call.assert_awaited_once_with(
@@ -241,7 +257,7 @@ async def test_function_call_without_parameters(
             device_id="test_device",
         ),
     )
-    assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
+    assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
 
 
 @patch(
@@ -254,7 +270,7 @@ async def test_function_exception(
     mock_config_entry_with_assist: MockConfigEntry,
 ) -> None:
     """Test exception in function calling."""
-    agent_id = mock_config_entry_with_assist.entry_id
+    agent_id = "conversation.google_generative_ai_conversation"
     context = Context()
 
     mock_tool = AsyncMock()
@@ -270,12 +286,12 @@ async def test_function_exception(
 
     mock_get_tools.return_value = [mock_tool]
 
-    with patch("google.generativeai.GenerativeModel") as mock_model:
+    with patch("google.genai.chats.AsyncChats.create") as mock_create:
         mock_chat = AsyncMock()
-        mock_model.return_value.start_chat.return_value = mock_chat
-        chat_response = MagicMock()
-        mock_chat.send_message_async.return_value = chat_response
-        mock_part = MagicMock()
+        mock_create.return_value.send_message = mock_chat
+        chat_response = Mock(prompt_feedback=None)
+        mock_chat.return_value = chat_response
+        mock_part = Mock()
         mock_part.text = ""
         mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
 
@@ -287,7 +303,7 @@ async def test_function_exception(
            raise HomeAssistantError("Test tool exception")

        mock_tool.async_call.side_effect = tool_call
-        chat_response.parts = [mock_part]
+        chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
        result = await conversation.async_converse(
            hass,
            "Please call the test function",
@@ -299,21 +315,29 @@ async def test_function_exception(

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
-    mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
-    mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
-    assert mock_tool_call == {
+    mock_tool_call = mock_create.mock_calls[2][2]["message"]
+    assert mock_tool_call.model_dump() == {
        "parts": [
            {
+                "code_execution_result": None,
+                "executable_code": None,
+                "file_data": None,
+                "function_call": None,
                "function_response": {
+                    "id": None,
                    "name": "test_tool",
                    "response": {
                        "error": "HomeAssistantError",
                        "error_text": "Test tool exception",
                    },
                },
+                "inline_data": None,
+                "text": None,
+                "thought": None,
+                "video_metadata": None,
            },
        ],
-        "role": "",
+        "role": None,
    }
    mock_tool.async_call.assert_awaited_once_with(
        hass,
@@ -338,18 +362,22 @@ async def test_error_handling(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """Test that client errors are caught."""
-    with patch("google.generativeai.GenerativeModel") as mock_model:
+    with patch("google.genai.chats.AsyncChats.create") as mock_create:
        mock_chat = AsyncMock()
-        mock_model.return_value.start_chat.return_value = mock_chat
-        mock_chat.send_message_async.side_effect = GoogleAPIError("some error")
+        mock_create.return_value.send_message = mock_chat
+        mock_chat.side_effect = CLIENT_ERROR_500
        result = await conversation.async_converse(
-            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
+            hass,
+            "hello",
+            None,
+            Context(),
+            agent_id="conversation.google_generative_ai_conversation",
        )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.error_code == "unknown", result
    assert result.response.as_dict()["speech"]["plain"]["speech"] == (
-        "Sorry, I had a problem talking to Google Generative AI: some error"
+        "Sorry, I had a problem talking to Google Generative AI: 500 internal-error. {'message': 'Internal Server Error', 'status': 'internal-error'}"
    )
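Note: CLIENT_ERROR_500 is a shared test double imported from the test package (see the import hunk further down). Its construction depends on google.genai.errors constructor details that have shifted between releases, so rather than guess at the signature, a dependency-free stand-in (ours, not the suite's) shows the string shape the assertion above relies on:

class FakeAPIError(Exception):
    """Stand-in mirroring how google-genai errors stringify: '<code> <status>. <details>'."""

    def __init__(self, code: int, status: str, details: dict) -> None:
        super().__init__(f"{code} {status}. {details}")


client_error_500_like = FakeAPIError(
    500,
    "internal-error",
    {"message": "Internal Server Error", "status": "internal-error"},
)

# Matches the speech text asserted in test_error_handling above.
assert str(client_error_500_like) == (
    "500 internal-error. "
    "{'message': 'Internal Server Error', 'status': 'internal-error'}"
)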
@@ -358,20 +386,24 @@ async def test_blocked_response(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """Test blocked response."""
-    with patch("google.generativeai.GenerativeModel") as mock_model:
+    with patch("google.genai.chats.AsyncChats.create") as mock_create:
        mock_chat = AsyncMock()
-        mock_model.return_value.start_chat.return_value = mock_chat
-        mock_chat.send_message_async.side_effect = genai_types.StopCandidateException(
-            "finish_reason: SAFETY\n"
-        )
+        mock_create.return_value.send_message = mock_chat
+        chat_response = Mock(prompt_feedback=Mock(block_reason_message="SAFETY"))
+        mock_chat.return_value = chat_response

        result = await conversation.async_converse(
-            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
+            hass,
+            "hello",
+            None,
+            Context(),
+            agent_id="conversation.google_generative_ai_conversation",
        )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.error_code == "unknown", result
    assert result.response.as_dict()["speech"]["plain"]["speech"] == (
-        "The message got blocked by your safety settings"
+        "The message got blocked due to content violations, reason: SAFETY"
    )
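Note: the old SDK raised genai_types.StopCandidateException for safety blocks; the unified SDK instead returns a normal response whose prompt_feedback is populated, so the agent has to inspect it. A sketch of that check under our assumptions (not the integration's literal code):

from homeassistant.exceptions import HomeAssistantError


def raise_if_blocked(chat_response) -> None:
    """Turn a populated prompt_feedback into the error message the test asserts on."""
    if chat_response.prompt_feedback:
        raise HomeAssistantError(
            "The message got blocked due to content violations, reason: "
            f"{chat_response.prompt_feedback.block_reason_message}"
        )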
@@ -380,14 +412,18 @@ async def test_empty_response(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """Test empty response."""
-    with patch("google.generativeai.GenerativeModel") as mock_model:
+    with patch("google.genai.chats.AsyncChats.create") as mock_create:
        mock_chat = AsyncMock()
-        mock_model.return_value.start_chat.return_value = mock_chat
-        chat_response = MagicMock()
-        mock_chat.send_message_async.return_value = chat_response
-        chat_response.parts = []
+        mock_create.return_value.send_message = mock_chat
+        chat_response = Mock(prompt_feedback=None)
+        mock_chat.return_value = chat_response
+        chat_response.candidates = [Mock(content=Mock(parts=[]))]
        result = await conversation.async_converse(
-            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
+            hass,
+            "hello",
+            None,
+            Context(),
+            agent_id="conversation.google_generative_ai_conversation",
        )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
@@ -402,17 +438,19 @@ async def test_converse_error(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """Test handling ChatLog raising ConverseError."""
-    hass.config_entries.async_update_entry(
-        mock_config_entry,
-        options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"},
-    )
+    with patch("google.genai.models.AsyncModels.get"):
+        hass.config_entries.async_update_entry(
+            mock_config_entry,
+            options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"},
+        )
+        await hass.async_block_till_done()

    result = await conversation.async_converse(
        hass,
        "hello",
        None,
        Context(),
-        agent_id=mock_config_entry.entry_id,
+        agent_id="conversation.google_generative_ai_conversation",
    )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
@@ -449,31 +487,39 @@ async def test_escape_decode() -> None:


@pytest.mark.parametrize(
-    ("openapi", "protobuf"),
+    ("openapi", "genai_schema"),
    [
        (
            {"type": "string", "enum": ["a", "b", "c"]},
-            {"type_": "STRING", "enum": ["a", "b", "c"]},
+            {"type": "STRING", "enum": ["a", "b", "c"]},
        ),
        (
            {"type": "integer", "enum": [1, 2, 3]},
-            {"type_": "STRING", "enum": ["1", "2", "3"]},
+            {"type": "STRING", "enum": ["1", "2", "3"]},
        ),
+        (
+            {"anyOf": [{"type": "integer"}, {"type": "number"}]},
+            {"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]},
+        ),
-        ({"anyOf": [{"type": "integer"}, {"type": "number"}]}, {"type_": "INTEGER"}),
        (
            {
-                "anyOf": [
-                    {"anyOf": [{"type": "integer"}, {"type": "number"}]},
-                    {"anyOf": [{"type": "integer"}, {"type": "number"}]},
+                "any_of": [
+                    {"any_of": [{"type": "integer"}, {"type": "number"}]},
+                    {"any_of": [{"type": "integer"}, {"type": "number"}]},
                ]
            },
+            {
+                "any_of": [
+                    {"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]},
+                    {"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]},
+                ]
+            },
-            {"type_": "INTEGER"},
        ),
-        ({"type": "string", "format": "lower"}, {"type_": "STRING"}),
-        ({"type": "boolean", "format": "bool"}, {"type_": "BOOLEAN"}),
+        ({"type": "string", "format": "lower"}, {"format": "lower", "type": "STRING"}),
+        ({"type": "boolean", "format": "bool"}, {"format": "bool", "type": "BOOLEAN"}),
        (
            {"type": "number", "format": "percent"},
-            {"type_": "NUMBER", "format_": "percent"},
+            {"type": "NUMBER", "format": "percent"},
        ),
        (
            {
@@ -482,25 +528,25 @@ async def test_escape_decode() -> None:
                "required": [],
            },
            {
-                "type_": "OBJECT",
-                "properties": {"var": {"type_": "STRING"}},
+                "type": "OBJECT",
+                "properties": {"var": {"type": "STRING"}},
                "required": [],
            },
        ),
        (
            {"type": "object", "additionalProperties": True},
            {
-                "type_": "OBJECT",
-                "properties": {"json": {"type_": "STRING"}},
+                "type": "OBJECT",
+                "properties": {"json": {"type": "STRING"}},
                "required": [],
            },
        ),
        (
            {"type": "array", "items": {"type": "string"}},
-            {"type_": "ARRAY", "items": {"type_": "STRING"}},
+            {"type": "ARRAY", "items": {"type": "STRING"}},
        ),
    ],
)
-async def test_format_schema(openapi, protobuf) -> None:
+async def test_format_schema(openapi, genai_schema) -> None:
    """Test _format_schema."""
-    assert _format_schema(openapi) == protobuf
+    assert _format_schema(openapi) == genai_schema
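Note: the new expectations use the genai Schema field names directly (type, format, any_of) instead of the protobuf-era trailing underscores (type_, format_). A simplified re-implementation that satisfies the cases above, for illustration only; the real _format_schema lives in the integration's conversation.py and may differ in detail:

def format_schema_sketch(schema: dict) -> dict:
    """Convert an OpenAPI fragment to genai-Schema-style keys (simplified)."""
    result: dict = {}
    for key, val in schema.items():
        if key == "anyOf":
            key = "any_of"
        if key == "type":
            val = val.upper()
        elif key == "enum":
            val = [str(item) for item in val]
        elif key == "any_of":
            val = [format_schema_sketch(item) for item in val]
        elif key == "properties":
            val = {k: format_schema_sketch(v) for k, v in val.items()}
        elif key == "items":
            val = format_schema_sketch(val)
        result[key] = val
    if "enum" in result:
        # Gemini enums are string-typed, so any enum forces a STRING type.
        result["type"] = "STRING"
    if result.pop("additionalProperties", None):
        # Free-form objects are flattened to a single JSON-encoded string field.
        result["properties"] = {"json": {"type": "STRING"}}
        result["required"] = []
    return result


assert format_schema_sketch({"type": "integer", "enum": [1, 2, 3]}) == {
    "type": "STRING",
    "enum": ["1", "2", "3"],
}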
@@ -1,16 +1,17 @@
"""Tests for the Google Generative AI Conversation integration."""

-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, Mock, patch

-from google.api_core.exceptions import ClientError, DeadlineExceeded
-from google.rpc.error_details_pb2 import ErrorInfo  # pylint: disable=no-name-in-module
import pytest
+from requests.exceptions import Timeout
from syrupy.assertion import SnapshotAssertion

from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError

+from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
+
from tests.common import MockConfigEntry
@@ -24,12 +25,14 @@ async def test_generate_content_service_without_images(
        "party for the latest version of Home Assistant!"
    )

-    with patch("google.generativeai.GenerativeModel") as mock_model:
-        mock_response = MagicMock()
-        mock_response.text = stubbed_generated_content
-        mock_model.return_value.generate_content_async = AsyncMock(
-            return_value=mock_response
-        )
+    with patch(
+        "google.genai.models.AsyncModels.generate_content",
+        return_value=Mock(
+            text=stubbed_generated_content,
+            prompt_feedback=None,
+            candidates=[Mock()],
+        ),
+    ) as mock_generate:
        response = await hass.services.async_call(
            "google_generative_ai_conversation",
            "generate_content",
@@ -41,7 +44,7 @@ async def test_generate_content_service_without_images(
    assert response == {
        "text": stubbed_generated_content,
    }
-    assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
+    assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
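Note: in the unified SDK every async call funnels through google.genai.models.AsyncModels, so a single patch target intercepts calls from any Client instance; that is what lets these tests drop the GenerativeModel constructor mock. A self-contained illustration of the pattern (the model name and stub values are our assumptions):

import asyncio
from unittest.mock import Mock, patch

from google import genai


async def call_model() -> str:
    client = genai.Client(api_key="unused")
    response = await client.aio.models.generate_content(
        model="gemini-2.0-flash", contents="hello"
    )
    return response.text


with patch(
    "google.genai.models.AsyncModels.generate_content",
    return_value=Mock(text="stub", prompt_feedback=None, candidates=[Mock()]),
) as mock_generate:
    # patch() installs an AsyncMock because the target is a coroutine
    # function, so awaiting it yields the configured return_value.
    assert asyncio.run(call_model()) == "stub"

assert mock_generate.call_count == 1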


@pytest.mark.usefixtures("mock_init_component")
@@ -54,19 +57,21 @@ async def test_generate_content_service_with_image(
    )

    with (
-        patch("google.generativeai.GenerativeModel") as mock_model,
        patch(
-            "homeassistant.components.google_generative_ai_conversation.Path.read_bytes",
+            "google.genai.models.AsyncModels.generate_content",
+            return_value=Mock(
+                text=stubbed_generated_content,
+                prompt_feedback=None,
+                candidates=[Mock()],
+            ),
+        ) as mock_generate,
+        patch(
+            "homeassistant.components.google_generative_ai_conversation.Image.open",
            return_value=b"image bytes",
        ),
        patch("pathlib.Path.exists", return_value=True),
        patch.object(hass.config, "is_allowed_path", return_value=True),
    ):
-        mock_response = MagicMock()
-        mock_response.text = stubbed_generated_content
-        mock_model.return_value.generate_content_async = AsyncMock(
-            return_value=mock_response
-        )
        response = await hass.services.async_call(
            "google_generative_ai_conversation",
            "generate_content",
@@ -81,7 +86,7 @@ async def test_generate_content_service_with_image(
    assert response == {
        "text": stubbed_generated_content,
    }
-    assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
+    assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot


@pytest.mark.usefixtures("mock_init_component")
@@ -90,20 +95,23 @@ async def test_generate_content_service_error(
    mock_config_entry: MockConfigEntry,
) -> None:
    """Test generate content service handles errors."""
-    with patch("google.generativeai.GenerativeModel") as mock_model:
-        mock_model.return_value.generate_content_async = AsyncMock(
-            side_effect=ClientError("reason")
-        )
-        with pytest.raises(
-            HomeAssistantError, match="Error generating content: None reason"
-        ):
-            await hass.services.async_call(
-                "google_generative_ai_conversation",
-                "generate_content",
-                {"prompt": "write a story about an epic fail"},
-                blocking=True,
-                return_response=True,
-            )
+    with (
+        patch(
+            "google.genai.models.AsyncModels.generate_content",
+            side_effect=CLIENT_ERROR_500,
+        ),
+        pytest.raises(
+            HomeAssistantError,
+            match="Error generating content: 500 internal-error. {'message': 'Internal Server Error', 'status': 'internal-error'}",
+        ),
+    ):
+        await hass.services.async_call(
+            "google_generative_ai_conversation",
+            "generate_content",
+            {"prompt": "write a story about an epic fail"},
+            blocking=True,
+            return_response=True,
+        )


@pytest.mark.usefixtures("mock_init_component")
@@ -113,21 +121,22 @@ async def test_generate_content_response_has_empty_parts(
) -> None:
    """Test generate content service handles response with empty parts."""
    with (
-        patch("google.generativeai.GenerativeModel") as mock_model,
+        patch(
+            "google.genai.models.AsyncModels.generate_content",
+            return_value=Mock(
+                prompt_feedback=None,
+                candidates=[Mock(content=Mock(parts=[]))],
+            ),
+        ),
+        pytest.raises(HomeAssistantError, match="Unknown error generating content"),
    ):
-        mock_response = MagicMock()
-        mock_response.parts = []
-        mock_model.return_value.generate_content_async = AsyncMock(
-            return_value=mock_response
-        )
-        with pytest.raises(HomeAssistantError, match="Error generating content"):
-            await hass.services.async_call(
-                "google_generative_ai_conversation",
-                "generate_content",
-                {"prompt": "write a story about an epic fail"},
-                blocking=True,
-                return_response=True,
-            )
+        await hass.services.async_call(
+            "google_generative_ai_conversation",
+            "generate_content",
+            {"prompt": "write a story about an epic fail"},
+            blocking=True,
+            return_response=True,
+        )


@pytest.mark.usefixtures("mock_init_component")
@@ -211,19 +220,17 @@ async def test_generate_content_service_with_non_image(hass: HomeAssistant) -> N
    ("side_effect", "state", "reauth"),
    [
        (
-            ClientError("some error"),
+            CLIENT_ERROR_500,
            ConfigEntryState.SETUP_ERROR,
            False,
        ),
        (
-            DeadlineExceeded("deadline exceeded"),
+            Timeout,
            ConfigEntryState.SETUP_RETRY,
            False,
        ),
        (
-            ClientError(
-                "invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID")
-            ),
+            CLIENT_ERROR_API_KEY_INVALID,
            ConfigEntryState.SETUP_ERROR,
            True,
        ),
@@ -235,10 +242,7 @@ async def test_config_entry_error(
    """Test different configuration entry errors."""
-    mock_client = AsyncMock()
-    mock_client.get_model.side_effect = side_effect
-    with patch(
-        "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient",
-        return_value=mock_client,
-    ):
+    with patch("google.genai.models.AsyncModels.get", side_effect=side_effect):
        assert not await hass.config_entries.async_setup(mock_config_entry.entry_id)
        await hass.async_block_till_done()
        assert mock_config_entry.state == state
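Note: these cases mirror the availability check the integration now performs at setup: fetch the configured model via models.get, translate API errors into a setup error (or a reauth flow for an invalid key), and let a Timeout trigger a retry. A sketch of such a check, assuming google-genai 1.x; the helper and its default model are ours:

from google import genai
from google.genai.errors import APIError


async def model_is_reachable(api_key: str, model: str = "gemini-2.0-flash") -> bool:
    """Return True if the configured model can be fetched with this key."""
    client = genai.Client(api_key=api_key)
    try:
        await client.aio.models.get(model=model)
    except APIError:
        return False
    return True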