Allow an LLM to see script response values (#131683)

This commit is contained in:
Paulus Schoutsen 2024-11-27 00:51:21 -05:00 committed by GitHub
parent 46fe3dcbf1
commit 7e03100af2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 30 deletions

View File

@ -22,15 +22,13 @@ from homeassistant.components.conversation import (
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
from homeassistant.components.homeassistant import async_should_expose from homeassistant.components.homeassistant import async_should_expose
from homeassistant.components.intent import async_device_supports_timers from homeassistant.components.intent import async_device_supports_timers
from homeassistant.components.script import ATTR_VARIABLES, DOMAIN as SCRIPT_DOMAIN from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
from homeassistant.components.weather import INTENT_GET_WEATHER from homeassistant.components.weather import INTENT_GET_WEATHER
from homeassistant.const import ( from homeassistant.const import (
ATTR_DOMAIN, ATTR_DOMAIN,
ATTR_ENTITY_ID,
ATTR_SERVICE, ATTR_SERVICE,
EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_CLOSE,
EVENT_SERVICE_REMOVED, EVENT_SERVICE_REMOVED,
SERVICE_TURN_ON,
) )
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -416,9 +414,7 @@ class AssistAPI(API):
): ):
continue continue
script_tool = ScriptTool(self.hass, state.entity_id) tools.append(ScriptTool(self.hass, state.entity_id))
if script_tool.parameters.schema:
tools.append(script_tool)
return tools return tools
@ -702,10 +698,9 @@ class ScriptTool(Tool):
script_entity_id: str, script_entity_id: str,
) -> None: ) -> None:
"""Init the class.""" """Init the class."""
self.name = split_entity_id(script_entity_id)[1] self._object_id = self.name = split_entity_id(script_entity_id)[1]
if self.name[0].isdigit(): if self.name[0].isdigit():
self.name = "_" + self.name self.name = "_" + self.name
self._entity_id = script_entity_id
self.description, self.parameters = _get_cached_script_parameters( self.description, self.parameters = _get_cached_script_parameters(
hass, script_entity_id hass, script_entity_id
@ -745,14 +740,13 @@ class ScriptTool(Tool):
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
tool_input.tool_args[field] = floor tool_input.tool_args[field] = floor
await hass.services.async_call( result = await hass.services.async_call(
SCRIPT_DOMAIN, SCRIPT_DOMAIN,
SERVICE_TURN_ON, self._object_id,
{ tool_input.tool_args,
ATTR_ENTITY_ID: self._entity_id,
ATTR_VARIABLES: tool_input.tool_args,
},
context=llm_context.context, context=llm_context.context,
blocking=True,
return_response=True,
) )
return {"success": True} return {"success": True, "result": result}

View File

@ -656,7 +656,10 @@ async def test_script_tool(
"script": { "script": {
"test_script": { "test_script": {
"description": "This is a test script", "description": "This is a test script",
"sequence": [], "sequence": [
{"variables": {"result": {"drinks": 2}}},
{"stop": True, "response_variable": "result"},
],
"fields": { "fields": {
"beer": {"description": "Number of beers", "required": True}, "beer": {"description": "Number of beers", "required": True},
"wine": {"selector": {"number": {"min": 0, "max": 3}}}, "wine": {"selector": {"number": {"min": 0, "max": 3}}},
@ -692,7 +695,7 @@ async def test_script_tool(
api = await llm.async_get_api(hass, "assist", llm_context) api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)] tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
assert len(tools) == 1 assert len(tools) == 2
tool = tools[0] tool = tools[0]
assert tool.name == "test_script" assert tool.name == "test_script"
@ -719,6 +722,7 @@ async def test_script_tool(
"script_with_no_fields": ("This is another test script", vol.Schema({})), "script_with_no_fields": ("This is another test script", vol.Schema({})),
} }
# Test script with response
tool_input = llm.ToolInput( tool_input = llm.ToolInput(
tool_name="test_script", tool_name="test_script",
tool_args={ tool_args={
@ -731,26 +735,56 @@ async def test_script_tool(
}, },
) )
with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call: with patch(
"homeassistant.core.ServiceRegistry.async_call",
side_effect=hass.services.async_call,
) as mock_service_call:
response = await api.async_call_tool(tool_input) response = await api.async_call_tool(tool_input)
mock_service_call.assert_awaited_once_with( mock_service_call.assert_awaited_once_with(
"script", "script",
"turn_on", "test_script",
{ {
"entity_id": "script.test_script", "beer": "3",
"variables": { "wine": 0,
"beer": "3", "where": area.id,
"wine": 0, "area_list": [area.id],
"where": area.id, "floor": floor.floor_id,
"area_list": [area.id], "floor_list": [floor.floor_id],
"floor": floor.floor_id,
"floor_list": [floor.floor_id],
},
}, },
context=context, context=context,
blocking=True,
return_response=True,
) )
assert response == {"success": True} assert response == {
"success": True,
"result": {"drinks": 2},
}
# Test script with no response
tool_input = llm.ToolInput(
tool_name="script_with_no_fields",
tool_args={},
)
with patch(
"homeassistant.core.ServiceRegistry.async_call",
side_effect=hass.services.async_call,
) as mock_service_call:
response = await api.async_call_tool(tool_input)
mock_service_call.assert_awaited_once_with(
"script",
"script_with_no_fields",
{},
context=context,
blocking=True,
return_response=True,
)
assert response == {
"success": True,
"result": {},
}
# Test reload script with new parameters # Test reload script with new parameters
config = { config = {
@ -782,7 +816,7 @@ async def test_script_tool(
api = await llm.async_get_api(hass, "assist", llm_context) api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)] tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
assert len(tools) == 1 assert len(tools) == 2
tool = tools[0] tool = tools[0]
assert tool.name == "test_script" assert tool.name == "test_script"