mirror of
https://github.com/home-assistant/core.git
synced 2025-11-11 12:00:52 +00:00
Increase AI Task default tokens for Google Gemini
This commit is contained in:
@@ -15,7 +15,13 @@ from homeassistant.exceptions import HomeAssistantError
|
|||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
from homeassistant.util.json import json_loads
|
from homeassistant.util.json import json_loads
|
||||||
|
|
||||||
from .const import CONF_CHAT_MODEL, CONF_RECOMMENDED, LOGGER, RECOMMENDED_IMAGE_MODEL
|
from .const import (
|
||||||
|
CONF_CHAT_MODEL,
|
||||||
|
CONF_RECOMMENDED,
|
||||||
|
LOGGER,
|
||||||
|
RECOMMENDED_A_TASK_MAX_TOKENS,
|
||||||
|
RECOMMENDED_IMAGE_MODEL,
|
||||||
|
)
|
||||||
from .entity import (
|
from .entity import (
|
||||||
ERROR_GETTING_RESPONSE,
|
ERROR_GETTING_RESPONSE,
|
||||||
GoogleGenerativeAILLMBaseEntity,
|
GoogleGenerativeAILLMBaseEntity,
|
||||||
@@ -73,7 +79,9 @@ class GoogleGenerativeAITaskEntity(
|
|||||||
chat_log: conversation.ChatLog,
|
chat_log: conversation.ChatLog,
|
||||||
) -> ai_task.GenDataTaskResult:
|
) -> ai_task.GenDataTaskResult:
|
||||||
"""Handle a generate data task."""
|
"""Handle a generate data task."""
|
||||||
await self._async_handle_chat_log(chat_log, task.structure)
|
await self._async_handle_chat_log(
|
||||||
|
chat_log, task.structure, default_max_tokens=RECOMMENDED_A_TASK_MAX_TOKENS
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||||
LOGGER.error(
|
LOGGER.error(
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ CONF_TOP_K = "top_k"
|
|||||||
RECOMMENDED_TOP_K = 64
|
RECOMMENDED_TOP_K = 64
|
||||||
CONF_MAX_TOKENS = "max_tokens"
|
CONF_MAX_TOKENS = "max_tokens"
|
||||||
RECOMMENDED_MAX_TOKENS = 3000
|
RECOMMENDED_MAX_TOKENS = 3000
|
||||||
|
# Input 5000, output 19400 = 0.05 USD
|
||||||
|
RECOMMENDED_A_TASK_MAX_TOKENS = 19400
|
||||||
CONF_HARASSMENT_BLOCK_THRESHOLD = "harassment_block_threshold"
|
CONF_HARASSMENT_BLOCK_THRESHOLD = "harassment_block_threshold"
|
||||||
CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
|
CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
|
||||||
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
|
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
|
||||||
|
|||||||
@@ -472,6 +472,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
|||||||
self,
|
self,
|
||||||
chat_log: conversation.ChatLog,
|
chat_log: conversation.ChatLog,
|
||||||
structure: vol.Schema | None = None,
|
structure: vol.Schema | None = None,
|
||||||
|
default_max_tokens: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generate an answer for the chat log."""
|
"""Generate an answer for the chat log."""
|
||||||
options = self.subentry.data
|
options = self.subentry.data
|
||||||
@@ -618,7 +619,9 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
|||||||
if not chat_log.unresponded_tool_results:
|
if not chat_log.unresponded_tool_results:
|
||||||
break
|
break
|
||||||
|
|
||||||
def create_generate_content_config(self) -> GenerateContentConfig:
|
def create_generate_content_config(
|
||||||
|
self, default_max_tokens: int | None = None
|
||||||
|
) -> GenerateContentConfig:
|
||||||
"""Create the GenerateContentConfig for the LLM."""
|
"""Create the GenerateContentConfig for the LLM."""
|
||||||
options = self.subentry.data
|
options = self.subentry.data
|
||||||
model = options.get(CONF_CHAT_MODEL, self.default_model)
|
model = options.get(CONF_CHAT_MODEL, self.default_model)
|
||||||
@@ -632,7 +635,9 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
|
|||||||
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||||
top_k=options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
|
top_k=options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
|
||||||
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||||
max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
max_output_tokens=options.get(
|
||||||
|
CONF_MAX_TOKENS, default_max_tokens or RECOMMENDED_MAX_TOKENS
|
||||||
|
),
|
||||||
safety_settings=[
|
safety_settings=[
|
||||||
SafetySetting(
|
SafetySetting(
|
||||||
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||||
|
|||||||
Reference in New Issue
Block a user