mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Properly handle escaped unicode characters passed to tools in Google Generative AI (#119117)
This commit is contained in:
parent
f07e7ec543
commit
f605c10f42
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import codecs
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from google.api_core.exceptions import GoogleAPICallError
|
from google.api_core.exceptions import GoogleAPICallError
|
||||||
@ -106,14 +107,14 @@ def _format_tool(tool: llm.Tool) -> dict[str, Any]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _adjust_value(value: Any) -> Any:
|
def _escape_decode(value: Any) -> Any:
|
||||||
"""Reverse unnecessary single quotes escaping."""
|
"""Recursively call codecs.escape_decode on all values."""
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
return value.replace("\\'", "'")
|
return codecs.escape_decode(bytes(value, "utf-8"))[0].decode("utf-8") # type: ignore[attr-defined]
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
return [_adjust_value(item) for item in value]
|
return [_escape_decode(item) for item in value]
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
return {k: _adjust_value(v) for k, v in value.items()}
|
return {k: _escape_decode(v) for k, v in value.items()}
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
@ -334,10 +335,7 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
for function_call in function_calls:
|
for function_call in function_calls:
|
||||||
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
|
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
|
||||||
tool_name = tool_call["name"]
|
tool_name = tool_call["name"]
|
||||||
tool_args = {
|
tool_args = _escape_decode(tool_call["args"])
|
||||||
key: _adjust_value(value)
|
|
||||||
for key, value in tool_call["args"].items()
|
|
||||||
}
|
|
||||||
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
|
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
|
||||||
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
|
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
|
||||||
try:
|
try:
|
||||||
|
@ -12,6 +12,9 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.conversation import trace
|
from homeassistant.components.conversation import trace
|
||||||
|
from homeassistant.components.google_generative_ai_conversation.conversation import (
|
||||||
|
_escape_decode,
|
||||||
|
)
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@ -504,3 +507,18 @@ async def test_conversation_agent(
|
|||||||
mock_config_entry.entry_id
|
mock_config_entry.entry_id
|
||||||
)
|
)
|
||||||
assert agent.supported_languages == "*"
|
assert agent.supported_languages == "*"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_escape_decode() -> None:
|
||||||
|
"""Test _escape_decode."""
|
||||||
|
assert _escape_decode(
|
||||||
|
{
|
||||||
|
"param1": ["test_value", "param1\\'s value"],
|
||||||
|
"param2": "param2\\'s value",
|
||||||
|
"param3": {"param31": "Cheminée", "param32": "Chemin\\303\\251e"},
|
||||||
|
}
|
||||||
|
) == {
|
||||||
|
"param1": ["test_value", "param1's value"],
|
||||||
|
"param2": "param2's value",
|
||||||
|
"param3": {"param31": "Cheminée", "param32": "Cheminée"},
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user