Fix Google AI not using correct config options after subentries migration (#147493)

This commit is contained in:
tronikos 2025-06-25 02:25:01 -07:00 committed by GitHub
parent 0bbb168862
commit f897a728f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 44 deletions

View File

@ -35,7 +35,6 @@ from homeassistant.helpers.issue_registry import IssueSeverity, async_create_iss
from homeassistant.helpers.typing import ConfigType
from .const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
DOMAIN,
FILE_POLLING_INTERVAL_SECONDS,
@ -190,7 +189,7 @@ async def async_setup_entry(
client = await hass.async_add_executor_job(_init_client)
await client.aio.models.get(
model=entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
model=RECOMMENDED_CHAT_MODEL,
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
)
except (APIError, Timeout) as err:

View File

@ -337,7 +337,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
tools = tools or []
tools.append(Tool(google_search=GoogleSearch()))
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
model_name = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
# Avoid INVALID_ARGUMENT Developer instruction is not enabled for <model>
supports_system_instruction = (
"gemma" not in model_name
@ -389,47 +389,13 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
if tool_results:
messages.append(_create_google_tool_response_content(tool_results))
generateContentConfig = GenerateContentConfig(
temperature=self.entry.options.get(
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
),
top_k=self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
top_p=self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
max_output_tokens=self.entry.options.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
safety_settings=[
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=self.entry.options.get(
CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=self.entry.options.get(
CONF_HARASSMENT_BLOCK_THRESHOLD,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=self.entry.options.get(
CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=self.entry.options.get(
CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
],
tools=tools or None,
system_instruction=prompt if supports_system_instruction else None,
automatic_function_calling=AutomaticFunctionCallingConfig(
disable=True, maximum_remote_calls=None
),
generateContentConfig = self.create_generate_content_config()
generateContentConfig.tools = tools or None
generateContentConfig.system_instruction = (
prompt if supports_system_instruction else None
)
generateContentConfig.automatic_function_calling = (
AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None)
)
if not supports_system_instruction:
@ -472,3 +438,40 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
if not chat_log.unresponded_tool_results:
break
def create_generate_content_config(self) -> GenerateContentConfig:
"""Create the GenerateContentConfig for the LLM."""
options = self.subentry.data
return GenerateContentConfig(
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
top_k=options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
safety_settings=[
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=options.get(
CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=options.get(
CONF_HARASSMENT_BLOCK_THRESHOLD,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=options.get(
CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=options.get(
CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
),
),
],
)