Allow an LLM to see script response values (#131683)

This commit is contained in:
Paulus Schoutsen 2024-11-27 00:51:21 -05:00 committed by GitHub
parent 46fe3dcbf1
commit 7e03100af2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 30 deletions

View File

@ -22,15 +22,13 @@ from homeassistant.components.conversation import (
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
from homeassistant.components.homeassistant import async_should_expose
from homeassistant.components.intent import async_device_supports_timers
from homeassistant.components.script import ATTR_VARIABLES, DOMAIN as SCRIPT_DOMAIN
from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
from homeassistant.components.weather import INTENT_GET_WEATHER
from homeassistant.const import (
ATTR_DOMAIN,
ATTR_ENTITY_ID,
ATTR_SERVICE,
EVENT_HOMEASSISTANT_CLOSE,
EVENT_SERVICE_REMOVED,
SERVICE_TURN_ON,
)
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
from homeassistant.exceptions import HomeAssistantError
@ -416,9 +414,7 @@ class AssistAPI(API):
):
continue
script_tool = ScriptTool(self.hass, state.entity_id)
if script_tool.parameters.schema:
tools.append(script_tool)
tools.append(ScriptTool(self.hass, state.entity_id))
return tools
@ -702,10 +698,9 @@ class ScriptTool(Tool):
script_entity_id: str,
) -> None:
"""Init the class."""
self.name = split_entity_id(script_entity_id)[1]
self._object_id = self.name = split_entity_id(script_entity_id)[1]
if self.name[0].isdigit():
self.name = "_" + self.name
self._entity_id = script_entity_id
self.description, self.parameters = _get_cached_script_parameters(
hass, script_entity_id
@ -745,14 +740,13 @@ class ScriptTool(Tool):
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
tool_input.tool_args[field] = floor
await hass.services.async_call(
result = await hass.services.async_call(
SCRIPT_DOMAIN,
SERVICE_TURN_ON,
{
ATTR_ENTITY_ID: self._entity_id,
ATTR_VARIABLES: tool_input.tool_args,
},
self._object_id,
tool_input.tool_args,
context=llm_context.context,
blocking=True,
return_response=True,
)
return {"success": True}
return {"success": True, "result": result}

View File

@ -656,7 +656,10 @@ async def test_script_tool(
"script": {
"test_script": {
"description": "This is a test script",
"sequence": [],
"sequence": [
{"variables": {"result": {"drinks": 2}}},
{"stop": True, "response_variable": "result"},
],
"fields": {
"beer": {"description": "Number of beers", "required": True},
"wine": {"selector": {"number": {"min": 0, "max": 3}}},
@ -692,7 +695,7 @@ async def test_script_tool(
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
assert len(tools) == 1
assert len(tools) == 2
tool = tools[0]
assert tool.name == "test_script"
@ -719,6 +722,7 @@ async def test_script_tool(
"script_with_no_fields": ("This is another test script", vol.Schema({})),
}
# Test script with response
tool_input = llm.ToolInput(
tool_name="test_script",
tool_args={
@ -731,26 +735,56 @@ async def test_script_tool(
},
)
with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call:
with patch(
"homeassistant.core.ServiceRegistry.async_call",
side_effect=hass.services.async_call,
) as mock_service_call:
response = await api.async_call_tool(tool_input)
mock_service_call.assert_awaited_once_with(
"script",
"turn_on",
"test_script",
{
"entity_id": "script.test_script",
"variables": {
"beer": "3",
"wine": 0,
"where": area.id,
"area_list": [area.id],
"floor": floor.floor_id,
"floor_list": [floor.floor_id],
},
"beer": "3",
"wine": 0,
"where": area.id,
"area_list": [area.id],
"floor": floor.floor_id,
"floor_list": [floor.floor_id],
},
context=context,
blocking=True,
return_response=True,
)
assert response == {"success": True}
assert response == {
"success": True,
"result": {"drinks": 2},
}
# Test script with no response
tool_input = llm.ToolInput(
tool_name="script_with_no_fields",
tool_args={},
)
with patch(
"homeassistant.core.ServiceRegistry.async_call",
side_effect=hass.services.async_call,
) as mock_service_call:
response = await api.async_call_tool(tool_input)
mock_service_call.assert_awaited_once_with(
"script",
"script_with_no_fields",
{},
context=context,
blocking=True,
return_response=True,
)
assert response == {
"success": True,
"result": {},
}
# Test reload script with new parameters
config = {
@ -782,7 +816,7 @@ async def test_script_tool(
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
assert len(tools) == 1
assert len(tools) == 2
tool = tools[0]
assert tool.name == "test_script"