From 0972b2951056822c9ce3d834e9b606902fd7bba5 Mon Sep 17 00:00:00 2001 From: tronikos Date: Sun, 26 May 2024 08:44:48 -0700 Subject: [PATCH] Add Google Generative AI reauth flow (#118096) * Add reauth flow * address comments --- .../__init__.py | 30 ++++-- .../config_flow.py | 97 +++++++++++++------ .../strings.json | 15 ++- .../conftest.py | 3 +- .../test_config_flow.py | 61 +++++++++++- .../test_init.py | 53 ++++++---- 6 files changed, 190 insertions(+), 69 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py index 563d7d341f9..969e6c7a369 100644 --- a/homeassistant/components/google_generative_ai_conversation/__init__.py +++ b/homeassistant/components/google_generative_ai_conversation/__init__.py @@ -2,12 +2,11 @@ from __future__ import annotations -from asyncio import timeout from functools import partial import mimetypes from pathlib import Path -from google.api_core.exceptions import ClientError +from google.api_core.exceptions import ClientError, DeadlineExceeded, GoogleAPICallError import google.generativeai as genai import google.generativeai.types as genai_types import voluptuous as vol @@ -20,11 +19,16 @@ from homeassistant.core import ( ServiceResponse, SupportsResponse, ) -from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError +from homeassistant.exceptions import ( + ConfigEntryAuthFailed, + ConfigEntryError, + ConfigEntryNotReady, + HomeAssistantError, +) from homeassistant.helpers import config_validation as cv from homeassistant.helpers.typing import ConfigType -from .const import CONF_PROMPT, DOMAIN, LOGGER +from .const import CONF_CHAT_MODEL, CONF_PROMPT, DOMAIN, RECOMMENDED_CHAT_MODEL SERVICE_GENERATE_CONTENT = "generate_content" CONF_IMAGE_FILENAME = "image_filename" @@ -101,13 +105,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: genai.configure(api_key=entry.data[CONF_API_KEY]) try: - async with timeout(5.0): - next(await hass.async_add_executor_job(partial(genai.list_models)), None) - except (ClientError, TimeoutError) as err: + await hass.async_add_executor_job( + partial( + genai.get_model, + entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), + request_options={"timeout": 5.0}, + ) + ) + except (GoogleAPICallError, ValueError) as err: if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID": - LOGGER.error("Invalid API key: %s", err) - return False - raise ConfigEntryNotReady(err) from err + raise ConfigEntryAuthFailed(err) from err + if isinstance(err, DeadlineExceeded): + raise ConfigEntryNotReady(err) from err + raise ConfigEntryError(err) from err await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) diff --git a/homeassistant/components/google_generative_ai_conversation/config_flow.py b/homeassistant/components/google_generative_ai_conversation/config_flow.py index b559888cc5f..ef700d289c7 100644 --- a/homeassistant/components/google_generative_ai_conversation/config_flow.py +++ b/homeassistant/components/google_generative_ai_conversation/config_flow.py @@ -2,12 +2,13 @@ from __future__ import annotations +from collections.abc import Mapping from functools import partial import logging from types import MappingProxyType from typing import Any -from google.api_core.exceptions import ClientError +from google.api_core.exceptions import ClientError, GoogleAPICallError import google.generativeai as genai import voluptuous as vol @@ -17,7 +18,7 @@ from homeassistant.config_entries import ( ConfigFlowResult, OptionsFlow, ) -from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API +from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME from homeassistant.core import HomeAssistant from homeassistant.helpers import llm from homeassistant.helpers.selector import ( @@ -54,7 +55,7 @@ from .const import ( _LOGGER = logging.getLogger(__name__) -STEP_USER_DATA_SCHEMA = vol.Schema( +STEP_API_DATA_SCHEMA = vol.Schema( { vol.Required(CONF_API_KEY): str, } @@ -73,7 +74,11 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user. """ genai.configure(api_key=data[CONF_API_KEY]) - await hass.async_add_executor_job(partial(genai.list_models)) + + def get_first_model(): + return next(genai.list_models(request_options={"timeout": 5.0}), None) + + await hass.async_add_executor_job(partial(get_first_model)) class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): @@ -81,36 +86,74 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): VERSION = 1 + def __init__(self) -> None: + """Initialize a new GoogleGenerativeAIConfigFlow.""" + self.reauth_entry: ConfigEntry | None = None + + async def async_step_api( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Handle the initial step.""" + errors: dict[str, str] = {} + if user_input is not None: + try: + await validate_input(self.hass, user_input) + except GoogleAPICallError as err: + if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID": + errors["base"] = "invalid_auth" + else: + errors["base"] = "cannot_connect" + except Exception: + _LOGGER.exception("Unexpected exception") + errors["base"] = "unknown" + else: + if self.reauth_entry: + return self.async_update_reload_and_abort( + self.reauth_entry, + data=user_input, + ) + return self.async_create_entry( + title="Google Generative AI", + data=user_input, + options=RECOMMENDED_OPTIONS, + ) + return self.async_show_form( + step_id="api", + data_schema=STEP_API_DATA_SCHEMA, + description_placeholders={ + "api_key_url": "https://aistudio.google.com/app/apikey" + }, + errors=errors, + ) + async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Handle the initial step.""" - if user_input is None: - return self.async_show_form( - step_id="user", data_schema=STEP_USER_DATA_SCHEMA - ) + return await self.async_step_api() - errors = {} - - try: - await validate_input(self.hass, user_input) - except ClientError as err: - if err.reason == "API_KEY_INVALID": - errors["base"] = "invalid_auth" - else: - errors["base"] = "cannot_connect" - except Exception: - _LOGGER.exception("Unexpected exception") - errors["base"] = "unknown" - else: - return self.async_create_entry( - title="Google Generative AI", - data=user_input, - options=RECOMMENDED_OPTIONS, - ) + async def async_step_reauth( + self, entry_data: Mapping[str, Any] + ) -> ConfigFlowResult: + """Handle configuration by re-auth.""" + self.reauth_entry = self.hass.config_entries.async_get_entry( + self.context["entry_id"] + ) + return await self.async_step_reauth_confirm() + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Dialog that informs the user that reauth is required.""" + if user_input is not None: + return await self.async_step_api() + assert self.reauth_entry return self.async_show_form( - step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors + step_id="reauth_confirm", + description_placeholders={ + CONF_NAME: self.reauth_entry.title, + CONF_API_KEY: self.reauth_entry.data.get(CONF_API_KEY, ""), + }, ) @staticmethod diff --git a/homeassistant/components/google_generative_ai_conversation/strings.json b/homeassistant/components/google_generative_ai_conversation/strings.json index 4c3ed29500c..9fea4805d38 100644 --- a/homeassistant/components/google_generative_ai_conversation/strings.json +++ b/homeassistant/components/google_generative_ai_conversation/strings.json @@ -1,17 +1,24 @@ { "config": { "step": { - "user": { + "api": { "data": { - "api_key": "[%key:common::config_flow::data::api_key%]", - "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" - } + "api_key": "[%key:common::config_flow::data::api_key%]" + }, + "description": "Get your API key from [here]({api_key_url})." + }, + "reauth_confirm": { + "title": "[%key:common::config_flow::title::reauth%]", + "description": "Your current API key: {api_key} is no longer valid. Please enter a new valid API key." } }, "error": { "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", "unknown": "[%key:common::config_flow::error::unknown%]" + }, + "abort": { + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" } }, "options": { diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py index 8ab8020428e..7c4aef75776 100644 --- a/tests/components/google_generative_ai_conversation/conftest.py +++ b/tests/components/google_generative_ai_conversation/conftest.py @@ -17,8 +17,7 @@ from tests.common import MockConfigEntry def mock_genai(): """Mock the genai call in async_setup_entry.""" with patch( - "homeassistant.components.google_generative_ai_conversation.genai.list_models", - return_value=iter([]), + "homeassistant.components.google_generative_ai_conversation.genai.get_model" ): yield diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py index 55350325eee..805fb9c3c74 100644 --- a/tests/components/google_generative_ai_conversation/test_config_flow.py +++ b/tests/components/google_generative_ai_conversation/test_config_flow.py @@ -2,7 +2,7 @@ from unittest.mock import Mock, patch -from google.api_core.exceptions import ClientError +from google.api_core.exceptions import ClientError, DeadlineExceeded from google.rpc.error_details_pb2 import ErrorInfo import pytest @@ -69,7 +69,7 @@ async def test_form(hass: HomeAssistant) -> None: DOMAIN, context={"source": config_entries.SOURCE_USER} ) assert result["type"] is FlowResultType.FORM - assert result["errors"] is None + assert not result["errors"] with ( patch( @@ -186,13 +186,16 @@ async def test_options_switching( ("side_effect", "error"), [ ( - ClientError(message="some error"), + ClientError("some error"), + "cannot_connect", + ), + ( + DeadlineExceeded("deadline exceeded"), "cannot_connect", ), ( ClientError( - message="invalid api key", - error_info=ErrorInfo(reason="API_KEY_INVALID"), + "invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID") ), "invalid_auth", ), @@ -218,3 +221,51 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None: assert result2["type"] is FlowResultType.FORM assert result2["errors"] == {"base": error} + + +async def test_reauth_flow(hass: HomeAssistant) -> None: + """Test the reauth flow.""" + hass.config.components.add("google_generative_ai_conversation") + mock_config_entry = MockConfigEntry( + domain=DOMAIN, state=config_entries.ConfigEntryState.LOADED, title="Gemini" + ) + mock_config_entry.add_to_hass(hass) + mock_config_entry.async_start_reauth(hass) + await hass.async_block_till_done() + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + result = flows[0] + assert result["step_id"] == "reauth_confirm" + assert result["context"]["source"] == "reauth" + assert result["context"]["title_placeholders"] == {"name": "Gemini"} + + result = await hass.config_entries.flow.async_configure(result["flow_id"], {}) + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "api" + assert "api_key" in result["data_schema"].schema + assert not result["errors"] + + with ( + patch( + "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models", + ), + patch( + "homeassistant.components.google_generative_ai_conversation.async_setup_entry", + return_value=True, + ) as mock_setup_entry, + patch( + "homeassistant.components.google_generative_ai_conversation.async_unload_entry", + return_value=True, + ) as mock_unload_entry, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], {"api_key": "1234"} + ) + await hass.async_block_till_done() + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert hass.config_entries.async_entries(DOMAIN)[0].data == {"api_key": "1234"} + assert len(mock_unload_entry.mock_calls) == 1 + assert len(mock_setup_entry.mock_calls) == 1 diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index a6a5fdf0b0e..44096e98469 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -2,7 +2,8 @@ from unittest.mock import AsyncMock, MagicMock, patch -from google.api_core.exceptions import ClientError +from google.api_core.exceptions import ClientError, DeadlineExceeded +from google.rpc.error_details_pb2 import ErrorInfo import pytest from syrupy.assertion import SnapshotAssertion @@ -220,29 +221,39 @@ async def test_generate_content_service_with_non_image( ) -async def test_config_entry_not_ready( - hass: HomeAssistant, mock_config_entry: MockConfigEntry +@pytest.mark.parametrize( + ("side_effect", "state", "reauth"), + [ + ( + ClientError("some error"), + ConfigEntryState.SETUP_ERROR, + False, + ), + ( + DeadlineExceeded("deadline exceeded"), + ConfigEntryState.SETUP_RETRY, + False, + ), + ( + ClientError( + "invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID") + ), + ConfigEntryState.SETUP_ERROR, + True, + ), + ], +) +async def test_config_entry_error( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, side_effect, state, reauth ) -> None: - """Test configuration entry not ready.""" + """Test different configuration entry errors.""" with patch( - "homeassistant.components.google_generative_ai_conversation.genai.list_models", - side_effect=ClientError("error"), + "homeassistant.components.google_generative_ai_conversation.genai.get_model", + side_effect=side_effect, ): mock_config_entry.add_to_hass(hass) await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() - assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY - - -async def test_config_entry_setup_error( - hass: HomeAssistant, mock_config_entry: MockConfigEntry -) -> None: - """Test configuration entry setup error.""" - with patch( - "homeassistant.components.google_generative_ai_conversation.genai.list_models", - side_effect=ClientError("error", error_info="API_KEY_INVALID"), - ): - mock_config_entry.add_to_hass(hass) - await hass.config_entries.async_setup(mock_config_entry.entry_id) - await hass.async_block_till_done() - assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR + assert mock_config_entry.state is state + mock_config_entry.async_get_active_flows(hass, {"reauth"}) + assert any(mock_config_entry.async_get_active_flows(hass, {"reauth"})) is reauth