From f6ad018f8fa3ccd4bf31df47388d3c26532ecf6d Mon Sep 17 00:00:00 2001 From: Denis Shulyaka Date: Sat, 3 Aug 2024 09:14:24 +0300 Subject: [PATCH] Change enum type to string for Google Generative AI Conversation (#123069) --- .../conversation.py | 14 ++++- .../test_conversation.py | 59 +++++++++++++++++++ 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index a5c911bb757..6b3e87f8181 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -89,9 +89,9 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]: key = "type_" val = val.upper() elif key == "format": - if (schema.get("type") == "string" and val != "enum") or ( - schema.get("type") not in ("number", "integer", "string") - ): + if schema.get("type") == "string" and val != "enum": + continue + if schema.get("type") not in ("number", "integer", "string"): continue key = "format_" elif key == "items": @@ -100,11 +100,19 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]: val = {k: _format_schema(v) for k, v in val.items()} result[key] = val + if result.get("enum") and result.get("type_") != "STRING": + # enum is only allowed for STRING type. This is safe as long as the schema + # contains vol.Coerce for the respective type, for example: + # vol.All(vol.Coerce(int), vol.In([1, 2, 3])) + result["type_"] = "STRING" + result["enum"] = [str(item) for item in result["enum"]] + if result.get("type_") == "OBJECT" and not result.get("properties"): # An object with undefined properties is not supported by Gemini API. # Fallback to JSON string. This will probably fail for most tools that want it, # but we don't have a better fallback strategy so far. result["properties"] = {"json": {"type_": "STRING"}} + result["required"] = [] return result diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index 41f96c7b0ac..df8472995a7 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -17,6 +17,7 @@ from homeassistant.components.google_generative_ai_conversation.const import ( ) from homeassistant.components.google_generative_ai_conversation.conversation import ( _escape_decode, + _format_schema, ) from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant @@ -622,3 +623,61 @@ async def test_escape_decode() -> None: "param2": "param2's value", "param3": {"param31": "Cheminée", "param32": "Cheminée"}, } + + +@pytest.mark.parametrize( + ("openapi", "protobuf"), + [ + ( + {"type": "string", "enum": ["a", "b", "c"]}, + {"type_": "STRING", "enum": ["a", "b", "c"]}, + ), + ( + {"type": "integer", "enum": [1, 2, 3]}, + {"type_": "STRING", "enum": ["1", "2", "3"]}, + ), + ({"anyOf": [{"type": "integer"}, {"type": "number"}]}, {"type_": "INTEGER"}), + ( + { + "anyOf": [ + {"anyOf": [{"type": "integer"}, {"type": "number"}]}, + {"anyOf": [{"type": "integer"}, {"type": "number"}]}, + ] + }, + {"type_": "INTEGER"}, + ), + ({"type": "string", "format": "lower"}, {"type_": "STRING"}), + ({"type": "boolean", "format": "bool"}, {"type_": "BOOLEAN"}), + ( + {"type": "number", "format": "percent"}, + {"type_": "NUMBER", "format_": "percent"}, + ), + ( + { + "type": "object", + "properties": {"var": {"type": "string"}}, + "required": [], + }, + { + "type_": "OBJECT", + "properties": {"var": {"type_": "STRING"}}, + "required": [], + }, + ), + ( + {"type": "object", "additionalProperties": True}, + { + "type_": "OBJECT", + "properties": {"json": {"type_": "STRING"}}, + "required": [], + }, + ), + ( + {"type": "array", "items": {"type": "string"}}, + {"type_": "ARRAY", "items": {"type_": "STRING"}}, + ), + ], +) +async def test_format_schema(openapi, protobuf) -> None: + """Test _format_schema.""" + assert _format_schema(openapi) == protobuf