From 0c245f1976141a46df0b135411d4c0dd89241cf9 Mon Sep 17 00:00:00 2001 From: tronikos Date: Mon, 27 May 2024 20:49:16 -0700 Subject: [PATCH] Fix freezing on HA startup when there are multiple Google Generative AI config entries (#118282) * Fix freezing on HA startup when there are multiple Google Generative AI config entries * Add timeout to list_models --- .../google_generative_ai_conversation/__init__.py | 14 +++++++------- .../config_flow.py | 12 ++++++------ .../google_generative_ai_conversation/conftest.py | 11 +++-------- .../test_config_flow.py | 12 +++++++----- .../test_conversation.py | 7 +------ .../google_generative_ai_conversation/test_init.py | 13 +++++++------ 6 files changed, 31 insertions(+), 38 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py index 969e6c7a369..8a1197987e1 100644 --- a/homeassistant/components/google_generative_ai_conversation/__init__.py +++ b/homeassistant/components/google_generative_ai_conversation/__init__.py @@ -2,10 +2,11 @@ from __future__ import annotations -from functools import partial import mimetypes from pathlib import Path +from google.ai import generativelanguage_v1beta +from google.api_core.client_options import ClientOptions from google.api_core.exceptions import ClientError, DeadlineExceeded, GoogleAPICallError import google.generativeai as genai import google.generativeai.types as genai_types @@ -105,12 +106,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: genai.configure(api_key=entry.data[CONF_API_KEY]) try: - await hass.async_add_executor_job( - partial( - genai.get_model, - entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), - request_options={"timeout": 5.0}, - ) + client = generativelanguage_v1beta.ModelServiceAsyncClient( + client_options=ClientOptions(api_key=entry.data[CONF_API_KEY]) + ) + await client.get_model( + name=entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), timeout=5.0 ) except (GoogleAPICallError, ValueError) as err: if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID": diff --git a/homeassistant/components/google_generative_ai_conversation/config_flow.py b/homeassistant/components/google_generative_ai_conversation/config_flow.py index b373239665d..543deb926a0 100644 --- a/homeassistant/components/google_generative_ai_conversation/config_flow.py +++ b/homeassistant/components/google_generative_ai_conversation/config_flow.py @@ -8,6 +8,8 @@ import logging from types import MappingProxyType from typing import Any +from google.ai import generativelanguage_v1beta +from google.api_core.client_options import ClientOptions from google.api_core.exceptions import ClientError, GoogleAPICallError import google.generativeai as genai import voluptuous as vol @@ -72,12 +74,10 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user. """ - genai.configure(api_key=data[CONF_API_KEY]) - - def get_first_model(): - return next(genai.list_models(request_options={"timeout": 5.0}), None) - - await hass.async_add_executor_job(partial(get_first_model)) + client = generativelanguage_v1beta.ModelServiceAsyncClient( + client_options=ClientOptions(api_key=data[CONF_API_KEY]) + ) + await client.list_models(timeout=5.0) class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py index 7c4aef75776..1761516e4f5 100644 --- a/tests/components/google_generative_ai_conversation/conftest.py +++ b/tests/components/google_generative_ai_conversation/conftest.py @@ -16,9 +16,7 @@ from tests.common import MockConfigEntry @pytest.fixture def mock_genai(): """Mock the genai call in async_setup_entry.""" - with patch( - "homeassistant.components.google_generative_ai_conversation.genai.get_model" - ): + with patch("google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.get_model"): yield @@ -48,11 +46,8 @@ def mock_config_entry_with_assist(hass, mock_config_entry): @pytest.fixture async def mock_init_component(hass: HomeAssistant, mock_config_entry: ConfigEntry): """Initialize integration.""" - with patch("google.generativeai.get_model"): - assert await async_setup_component( - hass, "google_generative_ai_conversation", {} - ) - await hass.async_block_till_done() + assert await async_setup_component(hass, "google_generative_ai_conversation", {}) + await hass.async_block_till_done() @pytest.fixture(autouse=True) diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py index 77da95506fa..41b1dbeb32e 100644 --- a/tests/components/google_generative_ai_conversation/test_config_flow.py +++ b/tests/components/google_generative_ai_conversation/test_config_flow.py @@ -1,6 +1,6 @@ """Test the Google Generative AI Conversation config flow.""" -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch from google.api_core.exceptions import ClientError, DeadlineExceeded from google.rpc.error_details_pb2 import ErrorInfo @@ -74,7 +74,7 @@ async def test_form(hass: HomeAssistant) -> None: with ( patch( - "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models", + "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.list_models", ), patch( "homeassistant.components.google_generative_ai_conversation.async_setup_entry", @@ -205,9 +205,11 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None: DOMAIN, context={"source": config_entries.SOURCE_USER} ) + mock_client = AsyncMock() + mock_client.list_models.side_effect = side_effect with patch( - "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models", - side_effect=side_effect, + "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient", + return_value=mock_client, ): result2 = await hass.config_entries.flow.async_configure( result["flow_id"], @@ -245,7 +247,7 @@ async def test_reauth_flow(hass: HomeAssistant) -> None: with ( patch( - "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models", + "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient.list_models", ), patch( "homeassistant.components.google_generative_ai_conversation.async_setup_entry", diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index 284bd904b44..08e6e5c12fc 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -538,12 +538,7 @@ async def test_template_error( "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", }, ) - with ( - patch( - "google.generativeai.get_model", - ), - patch("google.generativeai.GenerativeModel"), - ): + with patch("google.generativeai.GenerativeModel"): await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() result = await conversation.async_converse( diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index 44096e98469..a3926338b20 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -247,13 +247,14 @@ async def test_config_entry_error( hass: HomeAssistant, mock_config_entry: MockConfigEntry, side_effect, state, reauth ) -> None: """Test different configuration entry errors.""" + mock_client = AsyncMock() + mock_client.get_model.side_effect = side_effect with patch( - "homeassistant.components.google_generative_ai_conversation.genai.get_model", - side_effect=side_effect, + "google.ai.generativelanguage_v1beta.ModelServiceAsyncClient", + return_value=mock_client, ): - mock_config_entry.add_to_hass(hass) - await hass.config_entries.async_setup(mock_config_entry.entry_id) + assert not await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() - assert mock_config_entry.state is state + assert mock_config_entry.state == state mock_config_entry.async_get_active_flows(hass, {"reauth"}) - assert any(mock_config_entry.async_get_active_flows(hass, {"reauth"})) is reauth + assert any(mock_config_entry.async_get_active_flows(hass, {"reauth"})) == reauth