Add Google Generative AI reauth flow (#118096)

* Add reauth flow

* address comments
This commit is contained in:
tronikos 2024-05-26 08:44:48 -07:00 committed by GitHub
parent b85cf36a68
commit 0972b29510
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 190 additions and 69 deletions

View File

@ -2,12 +2,11 @@
from __future__ import annotations from __future__ import annotations
from asyncio import timeout
from functools import partial from functools import partial
import mimetypes import mimetypes
from pathlib import Path from pathlib import Path
from google.api_core.exceptions import ClientError from google.api_core.exceptions import ClientError, DeadlineExceeded, GoogleAPICallError
import google.generativeai as genai import google.generativeai as genai
import google.generativeai.types as genai_types import google.generativeai.types as genai_types
import voluptuous as vol import voluptuous as vol
@ -20,11 +19,16 @@ from homeassistant.core import (
ServiceResponse, ServiceResponse,
SupportsResponse, SupportsResponse,
) )
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError from homeassistant.exceptions import (
ConfigEntryAuthFailed,
ConfigEntryError,
ConfigEntryNotReady,
HomeAssistantError,
)
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import CONF_PROMPT, DOMAIN, LOGGER from .const import CONF_CHAT_MODEL, CONF_PROMPT, DOMAIN, RECOMMENDED_CHAT_MODEL
SERVICE_GENERATE_CONTENT = "generate_content" SERVICE_GENERATE_CONTENT = "generate_content"
CONF_IMAGE_FILENAME = "image_filename" CONF_IMAGE_FILENAME = "image_filename"
@ -101,13 +105,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
genai.configure(api_key=entry.data[CONF_API_KEY]) genai.configure(api_key=entry.data[CONF_API_KEY])
try: try:
async with timeout(5.0): await hass.async_add_executor_job(
next(await hass.async_add_executor_job(partial(genai.list_models)), None) partial(
except (ClientError, TimeoutError) as err: genai.get_model,
entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
request_options={"timeout": 5.0},
)
)
except (GoogleAPICallError, ValueError) as err:
if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID": if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID":
LOGGER.error("Invalid API key: %s", err) raise ConfigEntryAuthFailed(err) from err
return False if isinstance(err, DeadlineExceeded):
raise ConfigEntryNotReady(err) from err raise ConfigEntryNotReady(err) from err
raise ConfigEntryError(err) from err
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

View File

@ -2,12 +2,13 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping
from functools import partial from functools import partial
import logging import logging
from types import MappingProxyType from types import MappingProxyType
from typing import Any from typing import Any
from google.api_core.exceptions import ClientError from google.api_core.exceptions import ClientError, GoogleAPICallError
import google.generativeai as genai import google.generativeai as genai
import voluptuous as vol import voluptuous as vol
@ -17,7 +18,7 @@ from homeassistant.config_entries import (
ConfigFlowResult, ConfigFlowResult,
OptionsFlow, OptionsFlow,
) )
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm from homeassistant.helpers import llm
from homeassistant.helpers.selector import ( from homeassistant.helpers.selector import (
@ -54,7 +55,7 @@ from .const import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
STEP_USER_DATA_SCHEMA = vol.Schema( STEP_API_DATA_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_API_KEY): str, vol.Required(CONF_API_KEY): str,
} }
@ -73,7 +74,11 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user. Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
""" """
genai.configure(api_key=data[CONF_API_KEY]) genai.configure(api_key=data[CONF_API_KEY])
await hass.async_add_executor_job(partial(genai.list_models))
def get_first_model():
return next(genai.list_models(request_options={"timeout": 5.0}), None)
await hass.async_add_executor_job(partial(get_first_model))
class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
@ -81,36 +86,74 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
def __init__(self) -> None:
"""Initialize a new GoogleGenerativeAIConfigFlow."""
self.reauth_entry: ConfigEntry | None = None
async def async_step_api(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the initial step."""
errors: dict[str, str] = {}
if user_input is not None:
try:
await validate_input(self.hass, user_input)
except GoogleAPICallError as err:
if isinstance(err, ClientError) and err.reason == "API_KEY_INVALID":
errors["base"] = "invalid_auth"
else:
errors["base"] = "cannot_connect"
except Exception:
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
else:
if self.reauth_entry:
return self.async_update_reload_and_abort(
self.reauth_entry,
data=user_input,
)
return self.async_create_entry(
title="Google Generative AI",
data=user_input,
options=RECOMMENDED_OPTIONS,
)
return self.async_show_form(
step_id="api",
data_schema=STEP_API_DATA_SCHEMA,
description_placeholders={
"api_key_url": "https://aistudio.google.com/app/apikey"
},
errors=errors,
)
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle the initial step.""" """Handle the initial step."""
if user_input is None: return await self.async_step_api()
return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA
)
errors = {} async def async_step_reauth(
self, entry_data: Mapping[str, Any]
try: ) -> ConfigFlowResult:
await validate_input(self.hass, user_input) """Handle configuration by re-auth."""
except ClientError as err: self.reauth_entry = self.hass.config_entries.async_get_entry(
if err.reason == "API_KEY_INVALID": self.context["entry_id"]
errors["base"] = "invalid_auth" )
else: return await self.async_step_reauth_confirm()
errors["base"] = "cannot_connect"
except Exception:
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
else:
return self.async_create_entry(
title="Google Generative AI",
data=user_input,
options=RECOMMENDED_OPTIONS,
)
async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Dialog that informs the user that reauth is required."""
if user_input is not None:
return await self.async_step_api()
assert self.reauth_entry
return self.async_show_form( return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors step_id="reauth_confirm",
description_placeholders={
CONF_NAME: self.reauth_entry.title,
CONF_API_KEY: self.reauth_entry.data.get(CONF_API_KEY, ""),
},
) )
@staticmethod @staticmethod

View File

@ -1,17 +1,24 @@
{ {
"config": { "config": {
"step": { "step": {
"user": { "api": {
"data": { "data": {
"api_key": "[%key:common::config_flow::data::api_key%]", "api_key": "[%key:common::config_flow::data::api_key%]"
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" },
} "description": "Get your API key from [here]({api_key_url})."
},
"reauth_confirm": {
"title": "[%key:common::config_flow::title::reauth%]",
"description": "Your current API key: {api_key} is no longer valid. Please enter a new valid API key."
} }
}, },
"error": { "error": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"unknown": "[%key:common::config_flow::error::unknown%]" "unknown": "[%key:common::config_flow::error::unknown%]"
},
"abort": {
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
} }
}, },
"options": { "options": {

View File

@ -17,8 +17,7 @@ from tests.common import MockConfigEntry
def mock_genai(): def mock_genai():
"""Mock the genai call in async_setup_entry.""" """Mock the genai call in async_setup_entry."""
with patch( with patch(
"homeassistant.components.google_generative_ai_conversation.genai.list_models", "homeassistant.components.google_generative_ai_conversation.genai.get_model"
return_value=iter([]),
): ):
yield yield

View File

@ -2,7 +2,7 @@
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from google.api_core.exceptions import ClientError from google.api_core.exceptions import ClientError, DeadlineExceeded
from google.rpc.error_details_pb2 import ErrorInfo from google.rpc.error_details_pb2 import ErrorInfo
import pytest import pytest
@ -69,7 +69,7 @@ async def test_form(hass: HomeAssistant) -> None:
DOMAIN, context={"source": config_entries.SOURCE_USER} DOMAIN, context={"source": config_entries.SOURCE_USER}
) )
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["errors"] is None assert not result["errors"]
with ( with (
patch( patch(
@ -186,13 +186,16 @@ async def test_options_switching(
("side_effect", "error"), ("side_effect", "error"),
[ [
( (
ClientError(message="some error"), ClientError("some error"),
"cannot_connect",
),
(
DeadlineExceeded("deadline exceeded"),
"cannot_connect", "cannot_connect",
), ),
( (
ClientError( ClientError(
message="invalid api key", "invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID")
error_info=ErrorInfo(reason="API_KEY_INVALID"),
), ),
"invalid_auth", "invalid_auth",
), ),
@ -218,3 +221,51 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
assert result2["type"] is FlowResultType.FORM assert result2["type"] is FlowResultType.FORM
assert result2["errors"] == {"base": error} assert result2["errors"] == {"base": error}
async def test_reauth_flow(hass: HomeAssistant) -> None:
"""Test the reauth flow."""
hass.config.components.add("google_generative_ai_conversation")
mock_config_entry = MockConfigEntry(
domain=DOMAIN, state=config_entries.ConfigEntryState.LOADED, title="Gemini"
)
mock_config_entry.add_to_hass(hass)
mock_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
result = flows[0]
assert result["step_id"] == "reauth_confirm"
assert result["context"]["source"] == "reauth"
assert result["context"]["title_placeholders"] == {"name": "Gemini"}
result = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "api"
assert "api_key" in result["data_schema"].schema
assert not result["errors"]
with (
patch(
"homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
),
patch(
"homeassistant.components.google_generative_ai_conversation.async_setup_entry",
return_value=True,
) as mock_setup_entry,
patch(
"homeassistant.components.google_generative_ai_conversation.async_unload_entry",
return_value=True,
) as mock_unload_entry,
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"api_key": "1234"}
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "reauth_successful"
assert hass.config_entries.async_entries(DOMAIN)[0].data == {"api_key": "1234"}
assert len(mock_unload_entry.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1

View File

@ -2,7 +2,8 @@
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from google.api_core.exceptions import ClientError from google.api_core.exceptions import ClientError, DeadlineExceeded
from google.rpc.error_details_pb2 import ErrorInfo
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
@ -220,29 +221,39 @@ async def test_generate_content_service_with_non_image(
) )
async def test_config_entry_not_ready( @pytest.mark.parametrize(
hass: HomeAssistant, mock_config_entry: MockConfigEntry ("side_effect", "state", "reauth"),
[
(
ClientError("some error"),
ConfigEntryState.SETUP_ERROR,
False,
),
(
DeadlineExceeded("deadline exceeded"),
ConfigEntryState.SETUP_RETRY,
False,
),
(
ClientError(
"invalid api key", error_info=ErrorInfo(reason="API_KEY_INVALID")
),
ConfigEntryState.SETUP_ERROR,
True,
),
],
)
async def test_config_entry_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, side_effect, state, reauth
) -> None: ) -> None:
"""Test configuration entry not ready.""" """Test different configuration entry errors."""
with patch( with patch(
"homeassistant.components.google_generative_ai_conversation.genai.list_models", "homeassistant.components.google_generative_ai_conversation.genai.get_model",
side_effect=ClientError("error"), side_effect=side_effect,
): ):
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY assert mock_config_entry.state is state
mock_config_entry.async_get_active_flows(hass, {"reauth"})
assert any(mock_config_entry.async_get_active_flows(hass, {"reauth"})) is reauth
async def test_config_entry_setup_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test configuration entry setup error."""
with patch(
"homeassistant.components.google_generative_ai_conversation.genai.list_models",
side_effect=ClientError("error", error_info="API_KEY_INVALID"),
):
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR