diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py index a5c55c2099d..e9ab5cbdd3e 100644 --- a/homeassistant/components/google_generative_ai_conversation/__init__.py +++ b/homeassistant/components/google_generative_ai_conversation/__init__.py @@ -5,11 +5,10 @@ from __future__ import annotations import mimetypes from pathlib import Path -from google.ai import generativelanguage_v1beta -from google.api_core.client_options import ClientOptions -from google.api_core.exceptions import ClientError, DeadlineExceeded, GoogleAPIError -import google.generativeai as genai -import google.generativeai.types as genai_types +from google import genai # type: ignore[attr-defined] +from google.genai.errors import APIError, ClientError +from PIL import Image +from requests.exceptions import Timeout import voluptuous as vol from homeassistant.config_entries import ConfigEntry @@ -29,7 +28,13 @@ from homeassistant.exceptions import ( from homeassistant.helpers import config_validation as cv from homeassistant.helpers.typing import ConfigType -from .const import CONF_CHAT_MODEL, CONF_PROMPT, DOMAIN, RECOMMENDED_CHAT_MODEL +from .const import ( + CONF_CHAT_MODEL, + CONF_PROMPT, + DOMAIN, + RECOMMENDED_CHAT_MODEL, + TIMEOUT_MILLIS, +) SERVICE_GENERATE_CONTENT = "generate_content" CONF_IMAGE_FILENAME = "image_filename" @@ -37,6 +42,8 @@ CONF_IMAGE_FILENAME = "image_filename" CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) PLATFORMS = (Platform.CONVERSATION,) +type GoogleGenerativeAIConfigEntry = ConfigEntry[genai.Client] + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up Google Generative AI Conversation.""" @@ -44,42 +51,47 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def generate_content(call: ServiceCall) -> ServiceResponse: """Generate content from text and optionally images.""" prompt_parts = [call.data[CONF_PROMPT]] - image_filenames = call.data[CONF_IMAGE_FILENAME] - for image_filename in image_filenames: - if not hass.config.is_allowed_path(image_filename): - raise HomeAssistantError( - f"Cannot read `{image_filename}`, no access to path; " - "`allowlist_external_dirs` may need to be adjusted in " - "`configuration.yaml`" - ) - if not Path(image_filename).exists(): - raise HomeAssistantError(f"`{image_filename}` does not exist") - mime_type, _ = mimetypes.guess_type(image_filename) - if mime_type is None or not mime_type.startswith("image"): - raise HomeAssistantError(f"`{image_filename}` is not an image") - prompt_parts.append( - { - "mime_type": mime_type, - "data": await hass.async_add_executor_job( - Path(image_filename).read_bytes - ), - } - ) - model = genai.GenerativeModel(model_name=RECOMMENDED_CHAT_MODEL) + def append_images_to_prompt(): + image_filenames = call.data[CONF_IMAGE_FILENAME] + for image_filename in image_filenames: + if not hass.config.is_allowed_path(image_filename): + raise HomeAssistantError( + f"Cannot read `{image_filename}`, no access to path; " + "`allowlist_external_dirs` may need to be adjusted in " + "`configuration.yaml`" + ) + if not Path(image_filename).exists(): + raise HomeAssistantError(f"`{image_filename}` does not exist") + mime_type, _ = mimetypes.guess_type(image_filename) + if mime_type is None or not mime_type.startswith("image"): + raise HomeAssistantError(f"`{image_filename}` is not an image") + prompt_parts.append(Image.open(image_filename)) + + await hass.async_add_executor_job(append_images_to_prompt) + + config_entry: GoogleGenerativeAIConfigEntry = hass.config_entries.async_entries( + DOMAIN + )[0] + client = config_entry.runtime_data try: - response = await model.generate_content_async(prompt_parts) + response = await client.aio.models.generate_content( + model=RECOMMENDED_CHAT_MODEL, contents=prompt_parts + ) except ( - GoogleAPIError, + APIError, ValueError, - genai_types.BlockedPromptException, - genai_types.StopCandidateException, ) as err: raise HomeAssistantError(f"Error generating content: {err}") from err - if not response.parts: - raise HomeAssistantError("Error generating content") + if response.prompt_feedback: + raise HomeAssistantError( + f"Error generating content due to content violations, reason: {response.prompt_feedback.block_reason_message}" + ) + + if not response.candidates[0].content.parts: + raise HomeAssistantError("Unknown error generating content") return {"text": response.text} @@ -100,30 +112,34 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_setup_entry( + hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry +) -> bool: """Set up Google Generative AI Conversation from a config entry.""" - genai.configure(api_key=entry.data[CONF_API_KEY]) try: - client = generativelanguage_v1beta.ModelServiceAsyncClient( - client_options=ClientOptions(api_key=entry.data[CONF_API_KEY]) + client = genai.Client(api_key=entry.data[CONF_API_KEY]) + await client.aio.models.get( + model=entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), + config={"http_options": {"timeout": TIMEOUT_MILLIS}}, ) - await client.get_model( - name=entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), timeout=5.0 - ) - except (GoogleAPIError, ValueError) as err: - if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID": - raise ConfigEntryAuthFailed(err) from err - if isinstance(err, DeadlineExceeded): + except (APIError, Timeout) as err: + if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err): + raise ConfigEntryAuthFailed(err.message) from err + if isinstance(err, Timeout): raise ConfigEntryNotReady(err) from err raise ConfigEntryError(err) from err + else: + entry.runtime_data = client await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True -async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_unload_entry( + hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry +) -> bool: """Unload GoogleGenerativeAI.""" if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS): return False diff --git a/homeassistant/components/google_generative_ai_conversation/config_flow.py b/homeassistant/components/google_generative_ai_conversation/config_flow.py index 83eec25ed15..00a016143f4 100644 --- a/homeassistant/components/google_generative_ai_conversation/config_flow.py +++ b/homeassistant/components/google_generative_ai_conversation/config_flow.py @@ -3,15 +3,13 @@ from __future__ import annotations from collections.abc import Mapping -from functools import partial import logging from types import MappingProxyType from typing import Any -from google.ai import generativelanguage_v1beta -from google.api_core.client_options import ClientOptions -from google.api_core.exceptions import ClientError, GoogleAPIError -import google.generativeai as genai +from google import genai # type: ignore[attr-defined] +from google.genai.errors import APIError, ClientError +from requests.exceptions import Timeout import voluptuous as vol from homeassistant.config_entries import ( @@ -53,6 +51,7 @@ from .const import ( RECOMMENDED_TEMPERATURE, RECOMMENDED_TOP_K, RECOMMENDED_TOP_P, + TIMEOUT_MILLIS, ) _LOGGER = logging.getLogger(__name__) @@ -70,15 +69,20 @@ RECOMMENDED_OPTIONS = { } -async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: +async def validate_input(data: dict[str, Any]) -> None: """Validate the user input allows us to connect. Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user. """ - client = generativelanguage_v1beta.ModelServiceAsyncClient( - client_options=ClientOptions(api_key=data[CONF_API_KEY]) + client = genai.Client(api_key=data[CONF_API_KEY]) + await client.aio.models.list( + config={ + "http_options": { + "timeout": TIMEOUT_MILLIS, + }, + "query_base": True, + } ) - await client.list_models(timeout=5.0) class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): @@ -93,9 +97,9 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): errors: dict[str, str] = {} if user_input is not None: try: - await validate_input(self.hass, user_input) - except GoogleAPIError as err: - if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID": + await validate_input(user_input) + except (APIError, Timeout) as err: + if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err): errors["base"] = "invalid_auth" else: errors["base"] = "cannot_connect" @@ -166,6 +170,7 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow): self.last_rendered_recommended = config_entry.options.get( CONF_RECOMMENDED, False ) + self._genai_client = config_entry.runtime_data async def async_step_init( self, user_input: dict[str, Any] | None = None @@ -188,7 +193,9 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow): CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API], } - schema = await google_generative_ai_config_option_schema(self.hass, options) + schema = await google_generative_ai_config_option_schema( + self.hass, options, self._genai_client + ) return self.async_show_form( step_id="init", data_schema=vol.Schema(schema), @@ -198,6 +205,7 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow): async def google_generative_ai_config_option_schema( hass: HomeAssistant, options: dict[str, Any] | MappingProxyType[str, Any], + genai_client: genai.Client, ) -> dict: """Return a schema for Google Generative AI completion options.""" hass_apis: list[SelectOptionDict] = [ @@ -236,18 +244,21 @@ async def google_generative_ai_config_option_schema( if options.get(CONF_RECOMMENDED): return schema - api_models = await hass.async_add_executor_job(partial(genai.list_models)) - + api_models_pager = await genai_client.aio.models.list(config={"query_base": True}) + api_models = [api_model async for api_model in api_models_pager] models = [ SelectOptionDict( label=api_model.display_name, value=api_model.name, ) - for api_model in sorted(api_models, key=lambda x: x.display_name) + for api_model in sorted(api_models, key=lambda x: x.display_name or "") if ( api_model.name != "models/gemini-1.0-pro" # duplicate of gemini-pro + and api_model.display_name + and api_model.name + and api_model.supported_actions and "vision" not in api_model.name - and "generateContent" in api_model.supported_generation_methods + and "generateContent" in api_model.supported_actions ) ] diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py index 4d83b935528..35834f6e7f9 100644 --- a/homeassistant/components/google_generative_ai_conversation/const.py +++ b/homeassistant/components/google_generative_ai_conversation/const.py @@ -22,3 +22,5 @@ CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold" CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold" CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold" RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE" + +TIMEOUT_MILLIS = 10000 diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 4e0dc92f140..c99c4c07a7d 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -6,11 +6,18 @@ import codecs from collections.abc import Callable from typing import Any, Literal, cast -from google.api_core.exceptions import GoogleAPIError -import google.generativeai as genai -from google.generativeai import protos -import google.generativeai.types as genai_types -from google.protobuf.json_format import MessageToDict +from google.genai.errors import APIError +from google.genai.types import ( + AutomaticFunctionCallingConfig, + Content, + FunctionDeclaration, + GenerateContentConfig, + HarmCategory, + Part, + SafetySetting, + Schema, + Tool, +) from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation @@ -57,21 +64,40 @@ async def async_setup_entry( SUPPORTED_SCHEMA_KEYS = { - "type", - "format", - "description", + "min_items", + "example", + "property_ordering", + "pattern", + "minimum", + "default", + "any_of", + "max_length", + "title", + "min_properties", + "min_length", + "max_items", + "maximum", "nullable", + "max_properties", + "type", + "description", "enum", + "format", "items", "properties", "required", } -def _format_schema(schema: dict[str, Any]) -> dict[str, Any]: - """Format the schema to protobuf.""" - if (subschemas := schema.get("anyOf")) or (subschemas := schema.get("allOf")): - for subschema in subschemas: # Gemini API does not support anyOf and allOf keys +def _camel_to_snake(name: str) -> str: + """Convert camel case to snake case.""" + return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_") + + +def _format_schema(schema: dict[str, Any]) -> Schema: + """Format the schema to be compatible with Gemini API.""" + if subschemas := schema.get("allOf"): + for subschema in subschemas: # Gemini API does not support allOf keys if "type" in subschema: # Fallback to first subschema with 'type' field return _format_schema(subschema) return _format_schema( @@ -80,42 +106,38 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]: result = {} for key, val in schema.items(): + key = _camel_to_snake(key) if key not in SUPPORTED_SCHEMA_KEYS: continue + if key == "any_of": + val = [_format_schema(subschema) for subschema in val] if key == "type": - key = "type_" val = val.upper() - elif key == "format": - if schema.get("type") == "string" and val != "enum": - continue - if schema.get("type") not in ("number", "integer", "string"): - continue - key = "format_" - elif key == "items": + if key == "items": val = _format_schema(val) elif key == "properties": val = {k: _format_schema(v) for k, v in val.items()} result[key] = val - if result.get("enum") and result.get("type_") != "STRING": + if result.get("enum") and result.get("type") != "STRING": # enum is only allowed for STRING type. This is safe as long as the schema # contains vol.Coerce for the respective type, for example: # vol.All(vol.Coerce(int), vol.In([1, 2, 3])) - result["type_"] = "STRING" + result["type"] = "STRING" result["enum"] = [str(item) for item in result["enum"]] - if result.get("type_") == "OBJECT" and not result.get("properties"): + if result.get("type") == "OBJECT" and not result.get("properties"): # An object with undefined properties is not supported by Gemini API. # Fallback to JSON string. This will probably fail for most tools that want it, # but we don't have a better fallback strategy so far. - result["properties"] = {"json": {"type_": "STRING"}} + result["properties"] = {"json": {"type": "STRING"}} result["required"] = [] - return result + return cast(Schema, result) def _format_tool( tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None -) -> dict[str, Any]: +) -> Tool: """Format tool specification.""" if tool.parameters.schema: @@ -125,16 +147,14 @@ def _format_tool( else: parameters = None - return protos.Tool( - { - "function_declarations": [ - { - "name": tool.name, - "description": tool.description, - "parameters": parameters, - } - ] - } + return Tool( + function_declarations=[ + FunctionDeclaration( + name=tool.name, + description=tool.description, + parameters=parameters, + ) + ] ) @@ -151,14 +171,12 @@ def _escape_decode(value: Any) -> Any: def _create_google_tool_response_content( content: list[conversation.ToolResultContent], -) -> protos.Content: +) -> Content: """Create a Google tool response content.""" - return protos.Content( + return Content( parts=[ - protos.Part( - function_response=protos.FunctionResponse( - name=tool_result.tool_name, response=tool_result.tool_result - ) + Part.from_function_response( + name=tool_result.tool_name, response=tool_result.tool_result ) for tool_result in content ] @@ -169,33 +187,36 @@ def _convert_content( content: conversation.UserContent | conversation.AssistantContent | conversation.SystemContent, -) -> genai_types.ContentDict: +) -> Content: """Convert HA content to Google content.""" if content.role != "assistant" or not content.tool_calls: # type: ignore[union-attr] role = "model" if content.role == "assistant" else content.role - return {"role": role, "parts": content.content} + return Content( + role=role, + parts=[ + Part.from_text(text=content.content if content.content else ""), + ], + ) # Handle the Assistant content with tool calls. assert type(content) is conversation.AssistantContent - parts = [] + parts: list[Part] = [] if content.content: - parts.append(protos.Part(text=content.content)) + parts.append(Part.from_text(text=content.content)) if content.tool_calls: parts.extend( [ - protos.Part( - function_call=protos.FunctionCall( - name=tool_call.tool_name, - args=_escape_decode(tool_call.tool_args), - ) + Part.from_function_call( + name=tool_call.tool_name, + args=_escape_decode(tool_call.tool_args), ) for tool_call in content.tool_calls ] ) - return protos.Content({"role": "model", "parts": parts}) + return Content(role="model", parts=parts) class GoogleGenerativeAIConversationEntity( @@ -209,6 +230,7 @@ class GoogleGenerativeAIConversationEntity( def __init__(self, entry: ConfigEntry) -> None: """Initialize the agent.""" self.entry = entry + self._genai_client = entry.runtime_data self._attr_unique_id = entry.entry_id self._attr_device_info = dr.DeviceInfo( identifiers={(DOMAIN, entry.entry_id)}, @@ -273,7 +295,7 @@ class GoogleGenerativeAIConversationEntity( except conversation.ConverseError as err: return err.as_conversation_result() - tools: list[dict[str, Any]] | None = None + tools: list[Tool | Callable[..., Any]] | None = None if chat_log.llm_api: tools = [ _format_tool(tool, chat_log.llm_api.custom_serializer) @@ -288,13 +310,22 @@ class GoogleGenerativeAIConversationEntity( "gemini-1.0" not in model_name and "gemini-pro" not in model_name ) - prompt = chat_log.content[0].content # type: ignore[union-attr] - messages: list[genai_types.ContentDict] = [] + prompt_content = cast( + conversation.SystemContent, + chat_log.content[0], + ) + + if prompt_content.content: + prompt = prompt_content.content + else: + raise HomeAssistantError("Invalid prompt content") + + messages: list[Content] = [] # Google groups tool results, we do not. Group them before sending. tool_results: list[conversation.ToolResultContent] = [] - for chat_content in chat_log.content[1:]: + for chat_content in chat_log.content[1:-1]: if chat_content.role == "tool_result": # mypy doesn't like picking a type based on checking shared property 'role' tool_results.append(cast(conversation.ToolResultContent, chat_content)) @@ -317,85 +348,93 @@ class GoogleGenerativeAIConversationEntity( if tool_results: messages.append(_create_google_tool_response_content(tool_results)) - - model = genai.GenerativeModel( - model_name=model_name, - generation_config={ - "temperature": self.entry.options.get( - CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE + generateContentConfig = GenerateContentConfig( + temperature=self.entry.options.get( + CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE + ), + top_k=self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K), + top_p=self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P), + max_output_tokens=self.entry.options.get( + CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS + ), + safety_settings=[ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=self.entry.options.get( + CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD + ), ), - "top_p": self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P), - "top_k": self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K), - "max_output_tokens": self.entry.options.get( - CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=self.entry.options.get( + CONF_HARASSMENT_BLOCK_THRESHOLD, + RECOMMENDED_HARM_BLOCK_THRESHOLD, + ), ), - }, - safety_settings={ - "HARASSMENT": self.entry.options.get( - CONF_HARASSMENT_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=self.entry.options.get( + CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD + ), ), - "HATE": self.entry.options.get( - CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=self.entry.options.get( + CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD + ), ), - "SEXUAL": self.entry.options.get( - CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD - ), - "DANGEROUS": self.entry.options.get( - CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD - ), - }, + ], tools=tools or None, system_instruction=prompt if supports_system_instruction else None, + automatic_function_calling=AutomaticFunctionCallingConfig( + disable=True, maximum_remote_calls=None + ), ) if not supports_system_instruction: messages = [ - {"role": "user", "parts": prompt}, - {"role": "model", "parts": "Ok"}, + Content(role="user", parts=[Part.from_text(text=prompt)]), + Content(role="model", parts=[Part.from_text(text="Ok")]), *messages, ] - - chat = model.start_chat(history=messages) - chat_request = user_input.text + chat = self._genai_client.aio.chats.create( + model=model_name, history=messages, config=generateContentConfig + ) + chat_request: str | Content = user_input.text # To prevent infinite loops, we limit the number of iterations for _iteration in range(MAX_TOOL_ITERATIONS): try: - chat_response = await chat.send_message_async(chat_request) - except ( - GoogleAPIError, - ValueError, - genai_types.BlockedPromptException, - genai_types.StopCandidateException, - ) as err: - LOGGER.error("Error sending message: %s %s", type(err), err) + chat_response = await chat.send_message(message=chat_request) - if isinstance( - err, genai_types.StopCandidateException - ) and "finish_reason: SAFETY\n" in str(err): - error = "The message got blocked by your safety settings" - else: - error = ( - f"Sorry, I had a problem talking to Google Generative AI: {err}" + if chat_response.prompt_feedback: + raise HomeAssistantError( + f"The message got blocked due to content violations, reason: {chat_response.prompt_feedback.block_reason_message}" ) + except ( + APIError, + ValueError, + ) as err: + LOGGER.error("Error sending message: %s %s", type(err), err) + error = f"Sorry, I had a problem talking to Google Generative AI: {err}" raise HomeAssistantError(error) from err - LOGGER.debug("Response: %s", chat_response.parts) - if not chat_response.parts: + response_parts = chat_response.candidates[0].content.parts + if not response_parts: raise HomeAssistantError( "Sorry, I had a problem getting a response from Google Generative AI." ) content = " ".join( - [part.text.strip() for part in chat_response.parts if part.text] + [part.text.strip() for part in response_parts if part.text] ) tool_calls = [] - for part in chat_response.parts: + for part in response_parts: if not part.function_call: continue - tool_call = MessageToDict(part.function_call._pb) # noqa: SLF001 - tool_name = tool_call["name"] - tool_args = _escape_decode(tool_call["args"]) + tool_call = part.function_call + tool_name = tool_call.name + tool_args = _escape_decode(tool_call.args) tool_calls.append( llm.ToolInput(tool_name=tool_name, tool_args=tool_args) ) @@ -418,7 +457,7 @@ class GoogleGenerativeAIConversationEntity( response = intent.IntentResponse(language=user_input.language) response.async_set_speech( - " ".join([part.text.strip() for part in chat_response.parts if part.text]) + " ".join([part.text.strip() for part in response_parts if part.text]) ) return conversation.ConversationResult( response=response, conversation_id=chat_log.conversation_id diff --git a/homeassistant/components/google_generative_ai_conversation/manifest.json b/homeassistant/components/google_generative_ai_conversation/manifest.json index 7b687b7da6f..cc381532c6f 100644 --- a/homeassistant/components/google_generative_ai_conversation/manifest.json +++ b/homeassistant/components/google_generative_ai_conversation/manifest.json @@ -8,5 +8,5 @@ "documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation", "integration_type": "service", "iot_class": "cloud_polling", - "requirements": ["google-generativeai==0.8.2"] + "requirements": ["google-genai==1.1.0"] } diff --git a/requirements_all.txt b/requirements_all.txt index 4ccd6d25719..6b754d8bf59 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1033,7 +1033,7 @@ google-cloud-speech==2.27.0 google-cloud-texttospeech==2.17.2 # homeassistant.components.google_generative_ai_conversation -google-generativeai==0.8.2 +google-genai==1.1.0 # homeassistant.components.nest google-nest-sdm==7.1.3 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index e42a970c2a0..a7b8120c991 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -883,7 +883,7 @@ google-cloud-speech==2.27.0 google-cloud-texttospeech==2.17.2 # homeassistant.components.google_generative_ai_conversation -google-generativeai==0.8.2 +google-genai==1.1.0 # homeassistant.components.nest google-nest-sdm==7.1.3 diff --git a/tests/components/google_generative_ai_conversation/__init__.py b/tests/components/google_generative_ai_conversation/__init__.py index 8f789d9737e..6e2d37b035b 100644 --- a/tests/components/google_generative_ai_conversation/__init__.py +++ b/tests/components/google_generative_ai_conversation/__init__.py @@ -1 +1,31 @@ """Tests for the Google Generative AI Conversation integration.""" + +from unittest.mock import Mock + +from google.genai.errors import ClientError +import requests + +CLIENT_ERROR_500 = ClientError( + 500, + Mock( + __class__=requests.Response, + json=Mock( + return_value={ + "message": "Internal Server Error", + "status": "internal-error", + } + ), + ), +) +CLIENT_ERROR_API_KEY_INVALID = ClientError( + 400, + Mock( + __class__=requests.Response, + json=Mock( + return_value={ + "message": "'reason': API_KEY_INVALID", + "status": "unauthorized", + } + ), + ), +) diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py index 28c21a9b791..2bc81b10ce4 100644 --- a/tests/components/google_generative_ai_conversation/conftest.py +++ b/tests/components/google_generative_ai_conversation/conftest.py @@ -1,7 +1,6 @@ """Tests helpers.""" -from collections.abc import Generator -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest @@ -15,14 +14,7 @@ from tests.common import MockConfigEntry @pytest.fixture -def mock_genai() -> Generator[None]: - """Mock the genai call in async_setup_entry.""" - with patch("google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.get_model"): - yield - - -@pytest.fixture -def mock_config_entry(hass: HomeAssistant, mock_genai: None) -> MockConfigEntry: +def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: """Mock a config entry.""" entry = MockConfigEntry( domain="google_generative_ai_conversation", @@ -31,18 +23,21 @@ def mock_config_entry(hass: HomeAssistant, mock_genai: None) -> MockConfigEntry: "api_key": "bla", }, ) + entry.runtime_data = Mock() entry.add_to_hass(hass) return entry @pytest.fixture -def mock_config_entry_with_assist( +async def mock_config_entry_with_assist( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> MockConfigEntry: """Mock a config entry with assist.""" - hass.config_entries.async_update_entry( - mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} - ) + with patch("google.genai.models.AsyncModels.get"): + hass.config_entries.async_update_entry( + mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} + ) + await hass.async_block_till_done() return mock_config_entry @@ -51,8 +46,11 @@ async def mock_init_component( hass: HomeAssistant, mock_config_entry: ConfigEntry ) -> None: """Initialize integration.""" - assert await async_setup_component(hass, "google_generative_ai_conversation", {}) - await hass.async_block_till_done() + with patch("google.genai.models.AsyncModels.get"): + assert await async_setup_component( + hass, "google_generative_ai_conversation", {} + ) + await hass.async_block_till_done() @pytest.fixture(autouse=True) diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr index 1fe02ac2536..7c9bb896bd3 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -6,106 +6,26 @@ tuple( ), dict({ - 'generation_config': dict({ - 'max_output_tokens': 150, - 'temperature': 1.0, - 'top_k': 64, - 'top_p': 0.95, - }), - 'model_name': 'models/gemini-2.0-flash', - 'safety_settings': dict({ - 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', - 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', - 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', - 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', - }), - 'system_instruction': ''' - Current time is 05:00:00. Today's date is 2024-05-24. - You are a voice assistant for Home Assistant. - Answer questions about the world truthfully. - Answer in plain text. Keep it simple and to the point. - Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant. - ''', - 'tools': list([ - function_declarations { - name: "test_tool" - description: "Test function" - parameters { - type_: OBJECT - properties { - key: "param3" - value { - type_: OBJECT - properties { - key: "json" - value { - type_: STRING - } - } - } - } - properties { - key: "param2" - value { - type_: NUMBER - } - } - properties { - key: "param1" - value { - type_: ARRAY - description: "Test parameters" - items { - type_: STRING - } - } - } - } - } - , - ]), - }), - ), - tuple( - '().start_chat', - tuple( - ), - dict({ + 'config': GenerateContentConfig(http_options=None, system_instruction="Current time is 05:00:00. Today's date is 2024-05-24.\nYou are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=150, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=, description=None, enum=None, format=None, items=None, properties={'param1': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=, description='Test parameters', enum=None, format=None, items=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=, description=None, enum=None, format='lower', items=None, properties=None, required=None), properties=None, required=None), 'param2': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=[Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=, description=None, enum=None, format=None, items=None, properties=None, required=None), Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=, description=None, enum=None, format=None, items=None, properties=None, required=None)], max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=None, description=None, enum=None, format=None, items=None, properties=None, required=None), 'param3': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=, description=None, enum=None, format=None, items=None, properties={'json': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=, description=None, enum=None, format=None, items=None, properties=None, required=None)}, required=[])}, required=[]))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None), 'history': list([ - dict({ - 'parts': 'Please call the test function', - 'role': 'user', - }), ]), + 'model': 'models/gemini-2.0-flash', }), ), tuple( - '().start_chat().send_message_async', + '().send_message', tuple( - 'Please call the test function', ), dict({ + 'message': 'Please call the test function', }), ), tuple( - '().start_chat().send_message_async', + '().send_message', tuple( - parts { - function_response { - name: "test_tool" - response { - fields { - key: "result" - value { - string_value: "Test response" - } - } - } - } - } - , ), dict({ + 'message': Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None)], role=None), }), ), ]) @@ -117,75 +37,26 @@ tuple( ), dict({ - 'generation_config': dict({ - 'max_output_tokens': 150, - 'temperature': 1.0, - 'top_k': 64, - 'top_p': 0.95, - }), - 'model_name': 'models/gemini-2.0-flash', - 'safety_settings': dict({ - 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', - 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', - 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', - 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', - }), - 'system_instruction': ''' - Current time is 05:00:00. Today's date is 2024-05-24. - You are a voice assistant for Home Assistant. - Answer questions about the world truthfully. - Answer in plain text. Keep it simple and to the point. - Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant. - ''', - 'tools': list([ - function_declarations { - name: "test_tool" - description: "Test function" - } - , - ]), - }), - ), - tuple( - '().start_chat', - tuple( - ), - dict({ + 'config': GenerateContentConfig(http_options=None, system_instruction="Current time is 05:00:00. Today's date is 2024-05-24.\nYou are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=150, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=), SafetySetting(method=None, category=, threshold=)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=None)], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None), 'history': list([ - dict({ - 'parts': 'Please call the test function', - 'role': 'user', - }), ]), + 'model': 'models/gemini-2.0-flash', }), ), tuple( - '().start_chat().send_message_async', + '().send_message', tuple( - 'Please call the test function', ), dict({ + 'message': 'Please call the test function', }), ), tuple( - '().start_chat().send_message_async', + '().send_message', tuple( - parts { - function_response { - name: "test_tool" - response { - fields { - key: "result" - value { - string_value: "Test response" - } - } - } - } - } - , ), dict({ + 'message': Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None)], role=None), }), ), ]) diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr index c9e02a6d009..e2d93611ea6 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr @@ -6,21 +6,11 @@ tuple( ), dict({ - 'model_name': 'models/gemini-2.0-flash', - }), - ), - tuple( - '().generate_content_async', - tuple( - list([ + 'contents': list([ 'Describe this image from my doorbell camera', - dict({ - 'data': b'image bytes', - 'mime_type': 'image/jpeg', - }), + b'image bytes', ]), - ), - dict({ + 'model': 'models/gemini-2.0-flash', }), ), ]) @@ -32,17 +22,10 @@ tuple( ), dict({ - 'model_name': 'models/gemini-2.0-flash', - }), - ), - tuple( - '().generate_content_async', - tuple( - list([ + 'contents': list([ 'Write an opening speech for a Home Assistant release party', ]), - ), - dict({ + 'model': 'models/gemini-2.0-flash', }), ), ]) diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py index ee5291196c3..30c9d6c46e6 100644 --- a/tests/components/google_generative_ai_conversation/test_config_flow.py +++ b/tests/components/google_generative_ai_conversation/test_config_flow.py @@ -1,10 +1,9 @@ """Test the Google Generative AI Conversation config flow.""" -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import Mock, patch -from google.api_core.exceptions import ClientError, DeadlineExceeded -from google.rpc.error_details_pb2 import ErrorInfo # pylint: disable=no-name-in-module import pytest +from requests.exceptions import Timeout from homeassistant import config_entries from homeassistant.components.google_generative_ai_conversation.config_flow import ( @@ -33,6 +32,8 @@ from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType +from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID + from tests.common import MockConfigEntry @@ -41,30 +42,37 @@ def mock_models(): """Mock the model list API.""" model_20_flash = Mock( display_name="Gemini 2.0 Flash", - supported_generation_methods=["generateContent"], + supported_actions=["generateContent"], ) model_20_flash.name = "models/gemini-2.0-flash" model_15_flash = Mock( display_name="Gemini 1.5 Flash", - supported_generation_methods=["generateContent"], + supported_actions=["generateContent"], ) model_15_flash.name = "models/gemini-1.5-flash-latest" model_15_pro = Mock( display_name="Gemini 1.5 Pro", - supported_generation_methods=["generateContent"], + supported_actions=["generateContent"], ) model_15_pro.name = "models/gemini-1.5-pro-latest" model_10_pro = Mock( display_name="Gemini 1.0 Pro", - supported_generation_methods=["generateContent"], + supported_actions=["generateContent"], ) model_10_pro.name = "models/gemini-pro" + + async def models_pager(): + yield model_20_flash + yield model_15_flash + yield model_15_pro + yield model_10_pro + with patch( - "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models", - return_value=iter([model_20_flash, model_15_flash, model_15_pro, model_10_pro]), + "google.genai.models.AsyncModels.list", + return_value=models_pager(), ): yield @@ -86,7 +94,7 @@ async def test_form(hass: HomeAssistant) -> None: with ( patch( - "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.list_models", + "google.genai.models.AsyncModels.list", ), patch( "homeassistant.components.google_generative_ai_conversation.async_setup_entry", @@ -170,7 +178,11 @@ async def test_options_switching( expected_options, ) -> None: """Test the options form.""" - hass.config_entries.async_update_entry(mock_config_entry, options=current_options) + with patch("google.genai.models.AsyncModels.get"): + hass.config_entries.async_update_entry( + mock_config_entry, options=current_options + ) + await hass.async_block_till_done() options_flow = await hass.config_entries.options.async_init( mock_config_entry.entry_id ) @@ -195,17 +207,15 @@ async def test_options_switching( ("side_effect", "error"), [ ( - ClientError("some error"), + CLIENT_ERROR_500, "cannot_connect", ), ( - DeadlineExceeded("deadline exceeded"), + Timeout("deadline exceeded"), "cannot_connect", ), ( - ClientError( - "invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID") - ), + CLIENT_ERROR_API_KEY_INVALID, "invalid_auth", ), (Exception, "unknown"), @@ -217,12 +227,7 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None: DOMAIN, context={"source": config_entries.SOURCE_USER} ) - mock_client = AsyncMock() - mock_client.list_models.side_effect = side_effect - with patch( - "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient", - return_value=mock_client, - ): + with patch("google.genai.models.AsyncModels.list", side_effect=side_effect): result2 = await hass.config_entries.flow.async_configure( result["flow_id"], { @@ -259,7 +264,7 @@ async def test_reauth_flow(hass: HomeAssistant) -> None: with ( patch( - "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.list_models", + "google.genai.models.AsyncModels.list", ), patch( "homeassistant.components.google_generative_ai_conversation.async_setup_entry", diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index 9b255666a67..229ee0b323e 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -1,12 +1,10 @@ """Tests for the Google Generative AI Conversation integration conversation platform.""" from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, Mock, patch from freezegun import freeze_time -from google.ai.generativelanguage_v1beta.types.content import FunctionCall -from google.api_core.exceptions import GoogleAPIError -import google.generativeai.types as genai_types +from google.genai.types import FunctionCall import pytest from syrupy.assertion import SnapshotAssertion import voluptuous as vol @@ -22,6 +20,8 @@ from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import intent, llm +from . import CLIENT_ERROR_500 + from tests.common import MockConfigEntry @@ -51,7 +51,7 @@ async def test_function_call( snapshot: SnapshotAssertion, ) -> None: """Test function calling.""" - agent_id = mock_config_entry_with_assist.entry_id + agent_id = "conversation.google_generative_ai_conversation" context = Context() mock_tool = AsyncMock() @@ -69,12 +69,12 @@ async def test_function_call( mock_get_tools.return_value = [mock_tool] - with patch("google.generativeai.GenerativeModel") as mock_model: + with patch("google.genai.chats.AsyncChats.create") as mock_create: mock_chat = AsyncMock() - mock_model.return_value.start_chat.return_value = mock_chat - chat_response = MagicMock() - mock_chat.send_message_async.return_value = chat_response - mock_part = MagicMock() + mock_create.return_value.send_message = mock_chat + chat_response = Mock(prompt_feedback=None) + mock_chat.return_value = chat_response + mock_part = Mock() mock_part.text = "" mock_part.function_call = FunctionCall( name="test_tool", @@ -92,7 +92,7 @@ async def test_function_call( return {"result": "Test response"} mock_tool.async_call.side_effect = tool_call - chat_response.parts = [mock_part] + chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))] result = await conversation.async_converse( hass, "Please call the test function", @@ -104,20 +104,28 @@ async def test_function_call( assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" - mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0] - mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call) - assert mock_tool_call == { + mock_tool_call = mock_create.mock_calls[2][2]["message"] + assert mock_tool_call.model_dump() == { "parts": [ { + "code_execution_result": None, + "executable_code": None, + "file_data": None, + "function_call": None, "function_response": { + "id": None, "name": "test_tool", "response": { "result": "Test response", }, }, + "inline_data": None, + "text": None, + "thought": None, + "video_metadata": None, }, ], - "role": "", + "role": None, } mock_tool.async_call.assert_awaited_once_with( @@ -139,7 +147,7 @@ async def test_function_call( device_id="test_device", ), ) - assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot + assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot # Test conversating tracing traces = trace.async_get_traces() @@ -170,7 +178,7 @@ async def test_function_call_without_parameters( snapshot: SnapshotAssertion, ) -> None: """Test function calling without parameters.""" - agent_id = mock_config_entry_with_assist.entry_id + agent_id = "conversation.google_generative_ai_conversation" context = Context() mock_tool = AsyncMock() @@ -180,12 +188,12 @@ async def test_function_call_without_parameters( mock_get_tools.return_value = [mock_tool] - with patch("google.generativeai.GenerativeModel") as mock_model: + with patch("google.genai.chats.AsyncChats.create") as mock_create: mock_chat = AsyncMock() - mock_model.return_value.start_chat.return_value = mock_chat - chat_response = MagicMock() - mock_chat.send_message_async.return_value = chat_response - mock_part = MagicMock() + mock_create.return_value.send_message = mock_chat + chat_response = Mock(prompt_feedback=None) + mock_chat.return_value = chat_response + mock_part = Mock() mock_part.text = "" mock_part.function_call = FunctionCall(name="test_tool", args={}) @@ -197,7 +205,7 @@ async def test_function_call_without_parameters( return {"result": "Test response"} mock_tool.async_call.side_effect = tool_call - chat_response.parts = [mock_part] + chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))] result = await conversation.async_converse( hass, "Please call the test function", @@ -209,20 +217,28 @@ async def test_function_call_without_parameters( assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" - mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0] - mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call) - assert mock_tool_call == { + mock_tool_call = mock_create.mock_calls[2][2]["message"] + assert mock_tool_call.model_dump() == { "parts": [ { + "code_execution_result": None, + "executable_code": None, + "file_data": None, + "function_call": None, "function_response": { + "id": None, "name": "test_tool", "response": { "result": "Test response", }, }, + "inline_data": None, + "text": None, + "thought": None, + "video_metadata": None, }, ], - "role": "", + "role": None, } mock_tool.async_call.assert_awaited_once_with( @@ -241,7 +257,7 @@ async def test_function_call_without_parameters( device_id="test_device", ), ) - assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot + assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot @patch( @@ -254,7 +270,7 @@ async def test_function_exception( mock_config_entry_with_assist: MockConfigEntry, ) -> None: """Test exception in function calling.""" - agent_id = mock_config_entry_with_assist.entry_id + agent_id = "conversation.google_generative_ai_conversation" context = Context() mock_tool = AsyncMock() @@ -270,12 +286,12 @@ async def test_function_exception( mock_get_tools.return_value = [mock_tool] - with patch("google.generativeai.GenerativeModel") as mock_model: + with patch("google.genai.chats.AsyncChats.create") as mock_create: mock_chat = AsyncMock() - mock_model.return_value.start_chat.return_value = mock_chat - chat_response = MagicMock() - mock_chat.send_message_async.return_value = chat_response - mock_part = MagicMock() + mock_create.return_value.send_message = mock_chat + chat_response = Mock(prompt_feedback=None) + mock_chat.return_value = chat_response + mock_part = Mock() mock_part.text = "" mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1}) @@ -287,7 +303,7 @@ async def test_function_exception( raise HomeAssistantError("Test tool exception") mock_tool.async_call.side_effect = tool_call - chat_response.parts = [mock_part] + chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))] result = await conversation.async_converse( hass, "Please call the test function", @@ -299,21 +315,29 @@ async def test_function_exception( assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" - mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0] - mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call) - assert mock_tool_call == { + mock_tool_call = mock_create.mock_calls[2][2]["message"] + assert mock_tool_call.model_dump() == { "parts": [ { + "code_execution_result": None, + "executable_code": None, + "file_data": None, + "function_call": None, "function_response": { + "id": None, "name": "test_tool", "response": { "error": "HomeAssistantError", "error_text": "Test tool exception", }, }, + "inline_data": None, + "text": None, + "thought": None, + "video_metadata": None, }, ], - "role": "", + "role": None, } mock_tool.async_call.assert_awaited_once_with( hass, @@ -338,18 +362,22 @@ async def test_error_handling( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> None: """Test that client errors are caught.""" - with patch("google.generativeai.GenerativeModel") as mock_model: + with patch("google.genai.chats.AsyncChats.create") as mock_create: mock_chat = AsyncMock() - mock_model.return_value.start_chat.return_value = mock_chat - mock_chat.send_message_async.side_effect = GoogleAPIError("some error") + mock_create.return_value.send_message = mock_chat + mock_chat.side_effect = CLIENT_ERROR_500 result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + hass, + "hello", + None, + Context(), + agent_id="conversation.google_generative_ai_conversation", ) assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result assert result.response.as_dict()["speech"]["plain"]["speech"] == ( - "Sorry, I had a problem talking to Google Generative AI: some error" + "Sorry, I had a problem talking to Google Generative AI: 500 internal-error. {'message': 'Internal Server Error', 'status': 'internal-error'}" ) @@ -358,20 +386,24 @@ async def test_blocked_response( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> None: """Test blocked response.""" - with patch("google.generativeai.GenerativeModel") as mock_model: + with patch("google.genai.chats.AsyncChats.create") as mock_create: mock_chat = AsyncMock() - mock_model.return_value.start_chat.return_value = mock_chat - mock_chat.send_message_async.side_effect = genai_types.StopCandidateException( - "finish_reason: SAFETY\n" - ) + mock_create.return_value.send_message = mock_chat + chat_response = Mock(prompt_feedback=Mock(block_reason_message="SAFETY")) + mock_chat.return_value = chat_response + result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + hass, + "hello", + None, + Context(), + agent_id="conversation.google_generative_ai_conversation", ) assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result assert result.response.as_dict()["speech"]["plain"]["speech"] == ( - "The message got blocked by your safety settings" + "The message got blocked due to content violations, reason: SAFETY" ) @@ -380,14 +412,18 @@ async def test_empty_response( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> None: """Test empty response.""" - with patch("google.generativeai.GenerativeModel") as mock_model: + with patch("google.genai.chats.AsyncChats.create") as mock_create: mock_chat = AsyncMock() - mock_model.return_value.start_chat.return_value = mock_chat - chat_response = MagicMock() - mock_chat.send_message_async.return_value = chat_response - chat_response.parts = [] + mock_create.return_value.send_message = mock_chat + chat_response = Mock(prompt_feedback=None) + mock_chat.return_value = chat_response + chat_response.candidates = [Mock(content=Mock(parts=[]))] result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + hass, + "hello", + None, + Context(), + agent_id="conversation.google_generative_ai_conversation", ) assert result.response.response_type == intent.IntentResponseType.ERROR, result @@ -402,17 +438,19 @@ async def test_converse_error( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> None: """Test handling ChatLog raising ConverseError.""" - hass.config_entries.async_update_entry( - mock_config_entry, - options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"}, - ) + with patch("google.genai.models.AsyncModels.get"): + hass.config_entries.async_update_entry( + mock_config_entry, + options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"}, + ) + await hass.async_block_till_done() result = await conversation.async_converse( hass, "hello", None, Context(), - agent_id=mock_config_entry.entry_id, + agent_id="conversation.google_generative_ai_conversation", ) assert result.response.response_type == intent.IntentResponseType.ERROR, result @@ -449,31 +487,39 @@ async def test_escape_decode() -> None: @pytest.mark.parametrize( - ("openapi", "protobuf"), + ("openapi", "genai_schema"), [ ( {"type": "string", "enum": ["a", "b", "c"]}, - {"type_": "STRING", "enum": ["a", "b", "c"]}, + {"type": "STRING", "enum": ["a", "b", "c"]}, ), ( {"type": "integer", "enum": [1, 2, 3]}, - {"type_": "STRING", "enum": ["1", "2", "3"]}, + {"type": "STRING", "enum": ["1", "2", "3"]}, + ), + ( + {"anyOf": [{"type": "integer"}, {"type": "number"}]}, + {"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]}, ), - ({"anyOf": [{"type": "integer"}, {"type": "number"}]}, {"type_": "INTEGER"}), ( { - "anyOf": [ - {"anyOf": [{"type": "integer"}, {"type": "number"}]}, - {"anyOf": [{"type": "integer"}, {"type": "number"}]}, + "any_of": [ + {"any_of": [{"type": "integer"}, {"type": "number"}]}, + {"any_of": [{"type": "integer"}, {"type": "number"}]}, + ] + }, + { + "any_of": [ + {"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]}, + {"any_of": [{"type": "INTEGER"}, {"type": "NUMBER"}]}, ] }, - {"type_": "INTEGER"}, ), - ({"type": "string", "format": "lower"}, {"type_": "STRING"}), - ({"type": "boolean", "format": "bool"}, {"type_": "BOOLEAN"}), + ({"type": "string", "format": "lower"}, {"format": "lower", "type": "STRING"}), + ({"type": "boolean", "format": "bool"}, {"format": "bool", "type": "BOOLEAN"}), ( {"type": "number", "format": "percent"}, - {"type_": "NUMBER", "format_": "percent"}, + {"type": "NUMBER", "format": "percent"}, ), ( { @@ -482,25 +528,25 @@ async def test_escape_decode() -> None: "required": [], }, { - "type_": "OBJECT", - "properties": {"var": {"type_": "STRING"}}, + "type": "OBJECT", + "properties": {"var": {"type": "STRING"}}, "required": [], }, ), ( {"type": "object", "additionalProperties": True}, { - "type_": "OBJECT", - "properties": {"json": {"type_": "STRING"}}, + "type": "OBJECT", + "properties": {"json": {"type": "STRING"}}, "required": [], }, ), ( {"type": "array", "items": {"type": "string"}}, - {"type_": "ARRAY", "items": {"type_": "STRING"}}, + {"type": "ARRAY", "items": {"type": "STRING"}}, ), ], ) -async def test_format_schema(openapi, protobuf) -> None: +async def test_format_schema(openapi, genai_schema) -> None: """Test _format_schema.""" - assert _format_schema(openapi) == protobuf + assert _format_schema(openapi) == genai_schema diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index 4875323d094..f2e3ac10733 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -1,16 +1,17 @@ """Tests for the Google Generative AI Conversation integration.""" -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, Mock, patch -from google.api_core.exceptions import ClientError, DeadlineExceeded -from google.rpc.error_details_pb2 import ErrorInfo # pylint: disable=no-name-in-module import pytest +from requests.exceptions import Timeout from syrupy.assertion import SnapshotAssertion from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError +from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID + from tests.common import MockConfigEntry @@ -24,12 +25,14 @@ async def test_generate_content_service_without_images( "party for the latest version of Home Assistant!" ) - with patch("google.generativeai.GenerativeModel") as mock_model: - mock_response = MagicMock() - mock_response.text = stubbed_generated_content - mock_model.return_value.generate_content_async = AsyncMock( - return_value=mock_response - ) + with patch( + "google.genai.models.AsyncModels.generate_content", + return_value=Mock( + text=stubbed_generated_content, + prompt_feedback=None, + candidates=[Mock()], + ), + ) as mock_generate: response = await hass.services.async_call( "google_generative_ai_conversation", "generate_content", @@ -41,7 +44,7 @@ async def test_generate_content_service_without_images( assert response == { "text": stubbed_generated_content, } - assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot + assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot @pytest.mark.usefixtures("mock_init_component") @@ -54,19 +57,21 @@ async def test_generate_content_service_with_image( ) with ( - patch("google.generativeai.GenerativeModel") as mock_model, patch( - "homeassistant.components.google_generative_ai_conversation.Path.read_bytes", + "google.genai.models.AsyncModels.generate_content", + return_value=Mock( + text=stubbed_generated_content, + prompt_feedback=None, + candidates=[Mock()], + ), + ) as mock_generate, + patch( + "homeassistant.components.google_generative_ai_conversation.Image.open", return_value=b"image bytes", ), patch("pathlib.Path.exists", return_value=True), patch.object(hass.config, "is_allowed_path", return_value=True), ): - mock_response = MagicMock() - mock_response.text = stubbed_generated_content - mock_model.return_value.generate_content_async = AsyncMock( - return_value=mock_response - ) response = await hass.services.async_call( "google_generative_ai_conversation", "generate_content", @@ -81,7 +86,7 @@ async def test_generate_content_service_with_image( assert response == { "text": stubbed_generated_content, } - assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot + assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot @pytest.mark.usefixtures("mock_init_component") @@ -90,20 +95,23 @@ async def test_generate_content_service_error( mock_config_entry: MockConfigEntry, ) -> None: """Test generate content service handles errors.""" - with patch("google.generativeai.GenerativeModel") as mock_model: - mock_model.return_value.generate_content_async = AsyncMock( - side_effect=ClientError("reason") + with ( + patch( + "google.genai.models.AsyncModels.generate_content", + side_effect=CLIENT_ERROR_500, + ), + pytest.raises( + HomeAssistantError, + match="Error generating content: 500 internal-error. {'message': 'Internal Server Error', 'status': 'internal-error'}", + ), + ): + await hass.services.async_call( + "google_generative_ai_conversation", + "generate_content", + {"prompt": "write a story about an epic fail"}, + blocking=True, + return_response=True, ) - with pytest.raises( - HomeAssistantError, match="Error generating content: None reason" - ): - await hass.services.async_call( - "google_generative_ai_conversation", - "generate_content", - {"prompt": "write a story about an epic fail"}, - blocking=True, - return_response=True, - ) @pytest.mark.usefixtures("mock_init_component") @@ -113,21 +121,22 @@ async def test_generate_content_response_has_empty_parts( ) -> None: """Test generate content service handles response with empty parts.""" with ( - patch("google.generativeai.GenerativeModel") as mock_model, + patch( + "google.genai.models.AsyncModels.generate_content", + return_value=Mock( + prompt_feedback=None, + candidates=[Mock(content=Mock(parts=[]))], + ), + ), + pytest.raises(HomeAssistantError, match="Unknown error generating content"), ): - mock_response = MagicMock() - mock_response.parts = [] - mock_model.return_value.generate_content_async = AsyncMock( - return_value=mock_response + await hass.services.async_call( + "google_generative_ai_conversation", + "generate_content", + {"prompt": "write a story about an epic fail"}, + blocking=True, + return_response=True, ) - with pytest.raises(HomeAssistantError, match="Error generating content"): - await hass.services.async_call( - "google_generative_ai_conversation", - "generate_content", - {"prompt": "write a story about an epic fail"}, - blocking=True, - return_response=True, - ) @pytest.mark.usefixtures("mock_init_component") @@ -211,19 +220,17 @@ async def test_generate_content_service_with_non_image(hass: HomeAssistant) -> N ("side_effect", "state", "reauth"), [ ( - ClientError("some error"), + CLIENT_ERROR_500, ConfigEntryState.SETUP_ERROR, False, ), ( - DeadlineExceeded("deadline exceeded"), + Timeout, ConfigEntryState.SETUP_RETRY, False, ), ( - ClientError( - "invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID") - ), + CLIENT_ERROR_API_KEY_INVALID, ConfigEntryState.SETUP_ERROR, True, ), @@ -235,10 +242,7 @@ async def test_config_entry_error( """Test different configuration entry errors.""" mock_client = AsyncMock() mock_client.get_model.side_effect = side_effect - with patch( - "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient", - return_value=mock_client, - ): + with patch("google.genai.models.AsyncModels.get", side_effect=side_effect): assert not await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() assert mock_config_entry.state == state