From 620487fe7507bf6bf20f2a013fedee65476939d1 Mon Sep 17 00:00:00 2001 From: tronikos Date: Fri, 24 May 2024 18:48:39 -0700 Subject: [PATCH] Add Google Generative AI safety settings (#117679) * Add safety settings * snapshot-update * DROPDOWN * fix test * rename const * Update const.py * Update strings.json --- .../config_flow.py | 55 +++++++++++++++++++ .../const.py | 5 ++ .../conversation.py | 19 +++++++ .../strings.json | 6 +- .../snapshots/test_conversation.ambr | 24 ++++++++ .../test_config_flow.py | 9 +++ 6 files changed, 117 insertions(+), 1 deletion(-) diff --git a/homeassistant/components/google_generative_ai_conversation/config_flow.py b/homeassistant/components/google_generative_ai_conversation/config_flow.py index 3845d7f4e92..4fff5bff655 100644 --- a/homeassistant/components/google_generative_ai_conversation/config_flow.py +++ b/homeassistant/components/google_generative_ai_conversation/config_flow.py @@ -32,15 +32,20 @@ from homeassistant.helpers.selector import ( from .const import ( CONF_CHAT_MODEL, + CONF_DANGEROUS_BLOCK_THRESHOLD, + CONF_HARASSMENT_BLOCK_THRESHOLD, + CONF_HATE_BLOCK_THRESHOLD, CONF_MAX_TOKENS, CONF_PROMPT, CONF_RECOMMENDED, + CONF_SEXUAL_BLOCK_THRESHOLD, CONF_TEMPERATURE, CONF_TOP_K, CONF_TOP_P, DEFAULT_PROMPT, DOMAIN, RECOMMENDED_CHAT_MODEL, + RECOMMENDED_HARM_BLOCK_THRESHOLD, RECOMMENDED_MAX_TOKENS, RECOMMENDED_TEMPERATURE, RECOMMENDED_TOP_K, @@ -207,6 +212,30 @@ async def google_generative_ai_config_option_schema( ) ] + harm_block_thresholds: list[SelectOptionDict] = [ + SelectOptionDict( + label="Block none", + value="BLOCK_NONE", + ), + SelectOptionDict( + label="Block few", + value="BLOCK_ONLY_HIGH", + ), + SelectOptionDict( + label="Block some", + value="BLOCK_MEDIUM_AND_ABOVE", + ), + SelectOptionDict( + label="Block most", + value="BLOCK_LOW_AND_ABOVE", + ), + ] + harm_block_thresholds_selector = SelectSelector( + SelectSelectorConfig( + mode=SelectSelectorMode.DROPDOWN, options=harm_block_thresholds + ) + ) + schema.update( { vol.Optional( @@ -236,6 +265,32 @@ async def google_generative_ai_config_option_schema( description={"suggested_value": options.get(CONF_MAX_TOKENS)}, default=RECOMMENDED_MAX_TOKENS, ): int, + vol.Optional( + CONF_HARASSMENT_BLOCK_THRESHOLD, + description={ + "suggested_value": options.get(CONF_HARASSMENT_BLOCK_THRESHOLD) + }, + default=RECOMMENDED_HARM_BLOCK_THRESHOLD, + ): harm_block_thresholds_selector, + vol.Optional( + CONF_HATE_BLOCK_THRESHOLD, + description={"suggested_value": options.get(CONF_HATE_BLOCK_THRESHOLD)}, + default=RECOMMENDED_HARM_BLOCK_THRESHOLD, + ): harm_block_thresholds_selector, + vol.Optional( + CONF_SEXUAL_BLOCK_THRESHOLD, + description={ + "suggested_value": options.get(CONF_SEXUAL_BLOCK_THRESHOLD) + }, + default=RECOMMENDED_HARM_BLOCK_THRESHOLD, + ): harm_block_thresholds_selector, + vol.Optional( + CONF_DANGEROUS_BLOCK_THRESHOLD, + description={ + "suggested_value": options.get(CONF_DANGEROUS_BLOCK_THRESHOLD) + }, + default=RECOMMENDED_HARM_BLOCK_THRESHOLD, + ): harm_block_thresholds_selector, } ) return schema diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py index 9a16a31abd7..549883d4fb9 100644 --- a/homeassistant/components/google_generative_ai_conversation/const.py +++ b/homeassistant/components/google_generative_ai_conversation/const.py @@ -18,3 +18,8 @@ CONF_TOP_K = "top_k" RECOMMENDED_TOP_K = 64 CONF_MAX_TOKENS = "max_tokens" RECOMMENDED_MAX_TOKENS = 150 +CONF_HARASSMENT_BLOCK_THRESHOLD = "harassment_block_threshold" +CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold" +CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold" +CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold" +RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_LOW_AND_ABOVE" diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 2bc79ac8dde..f743c991759 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -22,8 +22,12 @@ from homeassistant.util import ulid from .const import ( CONF_CHAT_MODEL, + CONF_DANGEROUS_BLOCK_THRESHOLD, + CONF_HARASSMENT_BLOCK_THRESHOLD, + CONF_HATE_BLOCK_THRESHOLD, CONF_MAX_TOKENS, CONF_PROMPT, + CONF_SEXUAL_BLOCK_THRESHOLD, CONF_TEMPERATURE, CONF_TOP_K, CONF_TOP_P, @@ -31,6 +35,7 @@ from .const import ( DOMAIN, LOGGER, RECOMMENDED_CHAT_MODEL, + RECOMMENDED_HARM_BLOCK_THRESHOLD, RECOMMENDED_MAX_TOKENS, RECOMMENDED_TEMPERATURE, RECOMMENDED_TOP_K, @@ -168,6 +173,20 @@ class GoogleGenerativeAIConversationEntity( CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS ), }, + safety_settings={ + "HARASSMENT": self.entry.options.get( + CONF_HARASSMENT_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD + ), + "HATE": self.entry.options.get( + CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD + ), + "SEXUAL": self.entry.options.get( + CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD + ), + "DANGEROUS": self.entry.options.get( + CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD + ), + }, tools=tools or None, ) diff --git a/homeassistant/components/google_generative_ai_conversation/strings.json b/homeassistant/components/google_generative_ai_conversation/strings.json index f35561a6aa6..4c3ed29500c 100644 --- a/homeassistant/components/google_generative_ai_conversation/strings.json +++ b/homeassistant/components/google_generative_ai_conversation/strings.json @@ -25,7 +25,11 @@ "top_p": "Top P", "top_k": "Top K", "max_tokens": "Maximum tokens to return in response", - "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" + "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", + "harassment_block_threshold": "Negative or harmful comments targeting identity and/or protected attributes", + "hate_block_threshold": "Content that is rude, disrespectful, or profane", + "sexual_block_threshold": "Contains references to sexual acts or other lewd content", + "dangerous_block_threshold": "Promotes, facilitates, or encourages harmful acts" }, "data_description": { "prompt": "Instruct how the LLM should respond. This can be a template." diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr index fe44c6a1608..ebc918bbf31 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -13,6 +13,12 @@ 'top_p': 0.95, }), 'model_name': 'models/gemini-1.5-flash-latest', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', + 'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', + 'HATE': 'BLOCK_LOW_AND_ABOVE', + 'SEXUAL': 'BLOCK_LOW_AND_ABOVE', + }), 'tools': None, }), ), @@ -57,6 +63,12 @@ 'top_p': 0.95, }), 'model_name': 'models/gemini-1.5-flash-latest', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', + 'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', + 'HATE': 'BLOCK_LOW_AND_ABOVE', + 'SEXUAL': 'BLOCK_LOW_AND_ABOVE', + }), 'tools': None, }), ), @@ -101,6 +113,12 @@ 'top_p': 0.95, }), 'model_name': 'models/gemini-1.5-flash-latest', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', + 'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', + 'HATE': 'BLOCK_LOW_AND_ABOVE', + 'SEXUAL': 'BLOCK_LOW_AND_ABOVE', + }), 'tools': None, }), ), @@ -148,6 +166,12 @@ 'top_p': 0.95, }), 'model_name': 'models/gemini-1.5-flash-latest', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', + 'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', + 'HATE': 'BLOCK_LOW_AND_ABOVE', + 'SEXUAL': 'BLOCK_LOW_AND_ABOVE', + }), 'tools': None, }), ), diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py index ba21407343e..6426386243c 100644 --- a/tests/components/google_generative_ai_conversation/test_config_flow.py +++ b/tests/components/google_generative_ai_conversation/test_config_flow.py @@ -9,14 +9,19 @@ import pytest from homeassistant import config_entries from homeassistant.components.google_generative_ai_conversation.const import ( CONF_CHAT_MODEL, + CONF_DANGEROUS_BLOCK_THRESHOLD, + CONF_HARASSMENT_BLOCK_THRESHOLD, + CONF_HATE_BLOCK_THRESHOLD, CONF_MAX_TOKENS, CONF_PROMPT, CONF_RECOMMENDED, + CONF_SEXUAL_BLOCK_THRESHOLD, CONF_TEMPERATURE, CONF_TOP_K, CONF_TOP_P, DOMAIN, RECOMMENDED_CHAT_MODEL, + RECOMMENDED_HARM_BLOCK_THRESHOLD, RECOMMENDED_MAX_TOKENS, RECOMMENDED_TOP_K, RECOMMENDED_TOP_P, @@ -116,6 +121,10 @@ async def test_form(hass: HomeAssistant) -> None: CONF_TOP_P: RECOMMENDED_TOP_P, CONF_TOP_K: RECOMMENDED_TOP_K, CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, + CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, }, ), (