Update Google Generative AI to allow multiple LLM APIs (#143191)

This commit is contained in:
Allen Porter 2025-04-19 02:44:12 -07:00 committed by GitHub
parent 61e4be4456
commit 4483025856
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 17 deletions

View File

@ -183,10 +183,10 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
if user_input is not None: if user_input is not None:
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended: if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if user_input[CONF_LLM_HASS_API] == "none": if not user_input.get(CONF_LLM_HASS_API):
user_input.pop(CONF_LLM_HASS_API) user_input.pop(CONF_LLM_HASS_API, None)
if not ( if not (
user_input.get(CONF_LLM_HASS_API, "none") != "none" user_input.get(CONF_LLM_HASS_API)
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
): ):
# Don't allow to save options that enable the Google Seearch tool with an Assist API # Don't allow to save options that enable the Google Seearch tool with an Assist API
@ -213,18 +213,16 @@ async def google_generative_ai_config_option_schema(
) -> dict: ) -> dict:
"""Return a schema for Google Generative AI completion options.""" """Return a schema for Google Generative AI completion options."""
hass_apis: list[SelectOptionDict] = [ hass_apis: list[SelectOptionDict] = [
SelectOptionDict(
label="No control",
value="none",
)
]
hass_apis.extend(
SelectOptionDict( SelectOptionDict(
label=api.name, label=api.name,
value=api.id, value=api.id,
) )
for api in llm.async_get_apis(hass) for api in llm.async_get_apis(hass)
) ]
if (suggested_llm_apis := options.get(CONF_LLM_HASS_API)) and isinstance(
suggested_llm_apis, str
):
suggested_llm_apis = [suggested_llm_apis]
schema = { schema = {
vol.Optional( vol.Optional(
@ -237,9 +235,8 @@ async def google_generative_ai_config_option_schema(
): TemplateSelector(), ): TemplateSelector(),
vol.Optional( vol.Optional(
CONF_LLM_HASS_API, CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)}, description={"suggested_value": suggested_llm_apis},
default="none", ): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
vol.Required( vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False) CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool, ): bool,

View File

@ -125,7 +125,6 @@ def will_options_be_rendered_again(current_options, new_options) -> bool:
( (
{ {
CONF_RECOMMENDED: True, CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "none",
CONF_PROMPT: "bla", CONF_PROMPT: "bla",
}, },
{ {
@ -162,12 +161,12 @@ def will_options_be_rendered_again(current_options, new_options) -> bool:
}, },
{ {
CONF_RECOMMENDED: True, CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "assist", CONF_LLM_HASS_API: ["assist"],
CONF_PROMPT: "", CONF_PROMPT: "",
}, },
{ {
CONF_RECOMMENDED: True, CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "assist", CONF_LLM_HASS_API: ["assist"],
CONF_PROMPT: "", CONF_PROMPT: "",
}, },
None, None,
@ -235,7 +234,7 @@ def will_options_be_rendered_again(current_options, new_options) -> bool:
{ {
CONF_RECOMMENDED: False, CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate", CONF_PROMPT: "Speak like a pirate",
CONF_LLM_HASS_API: "assist", CONF_LLM_HASS_API: ["assist"],
CONF_TEMPERATURE: 0.3, CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P, CONF_TOP_P: RECOMMENDED_TOP_P,
@ -263,6 +262,24 @@ def will_options_be_rendered_again(current_options, new_options) -> bool:
}, },
{CONF_USE_GOOGLE_SEARCH_TOOL: "invalid_google_search_option"}, {CONF_USE_GOOGLE_SEARCH_TOOL: "invalid_google_search_option"},
), ),
(
{
CONF_RECOMMENDED: True,
CONF_PROMPT: "",
CONF_LLM_HASS_API: "assist",
},
{
CONF_RECOMMENDED: True,
CONF_PROMPT: "",
CONF_LLM_HASS_API: ["assist"],
},
{
CONF_RECOMMENDED: True,
CONF_PROMPT: "",
CONF_LLM_HASS_API: ["assist"],
},
None,
),
], ],
) )
@pytest.mark.usefixtures("mock_init_component") @pytest.mark.usefixtures("mock_init_component")