Change enum type to string for Google Generative AI Conversation (#123069)

This commit is contained in:
Denis Shulyaka 2024-08-03 09:14:24 +03:00 committed by GitHub
parent 8687c32c15
commit f6ad018f8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 70 additions and 3 deletions

View File

@ -89,9 +89,9 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
key = "type_"
val = val.upper()
elif key == "format":
if (schema.get("type") == "string" and val != "enum") or (
schema.get("type") not in ("number", "integer", "string")
):
if schema.get("type") == "string" and val != "enum":
continue
if schema.get("type") not in ("number", "integer", "string"):
continue
key = "format_"
elif key == "items":
@ -100,11 +100,19 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
val = {k: _format_schema(v) for k, v in val.items()}
result[key] = val
if result.get("enum") and result.get("type_") != "STRING":
# enum is only allowed for STRING type. This is safe as long as the schema
# contains vol.Coerce for the respective type, for example:
# vol.All(vol.Coerce(int), vol.In([1, 2, 3]))
result["type_"] = "STRING"
result["enum"] = [str(item) for item in result["enum"]]
if result.get("type_") == "OBJECT" and not result.get("properties"):
# An object with undefined properties is not supported by Gemini API.
# Fallback to JSON string. This will probably fail for most tools that want it,
# but we don't have a better fallback strategy so far.
result["properties"] = {"json": {"type_": "STRING"}}
result["required"] = []
return result

View File

@ -17,6 +17,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
)
from homeassistant.components.google_generative_ai_conversation.conversation import (
_escape_decode,
_format_schema,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
@ -622,3 +623,61 @@ async def test_escape_decode() -> None:
"param2": "param2's value",
"param3": {"param31": "Cheminée", "param32": "Cheminée"},
}
@pytest.mark.parametrize(
("openapi", "protobuf"),
[
(
{"type": "string", "enum": ["a", "b", "c"]},
{"type_": "STRING", "enum": ["a", "b", "c"]},
),
(
{"type": "integer", "enum": [1, 2, 3]},
{"type_": "STRING", "enum": ["1", "2", "3"]},
),
({"anyOf": [{"type": "integer"}, {"type": "number"}]}, {"type_": "INTEGER"}),
(
{
"anyOf": [
{"anyOf": [{"type": "integer"}, {"type": "number"}]},
{"anyOf": [{"type": "integer"}, {"type": "number"}]},
]
},
{"type_": "INTEGER"},
),
({"type": "string", "format": "lower"}, {"type_": "STRING"}),
({"type": "boolean", "format": "bool"}, {"type_": "BOOLEAN"}),
(
{"type": "number", "format": "percent"},
{"type_": "NUMBER", "format_": "percent"},
),
(
{
"type": "object",
"properties": {"var": {"type": "string"}},
"required": [],
},
{
"type_": "OBJECT",
"properties": {"var": {"type_": "STRING"}},
"required": [],
},
),
(
{"type": "object", "additionalProperties": True},
{
"type_": "OBJECT",
"properties": {"json": {"type_": "STRING"}},
"required": [],
},
),
(
{"type": "array", "items": {"type": "string"}},
{"type_": "ARRAY", "items": {"type_": "STRING"}},
),
],
)
async def test_format_schema(openapi, protobuf) -> None:
"""Test _format_schema."""
assert _format_schema(openapi) == protobuf