mirror of
https://github.com/home-assistant/core.git
synced 2025-07-07 13:27:09 +00:00
Allow an LLM to see script response values (#131683)
This commit is contained in:
parent
46fe3dcbf1
commit
7e03100af2
@ -22,15 +22,13 @@ from homeassistant.components.conversation import (
|
||||
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
||||
from homeassistant.components.homeassistant import async_should_expose
|
||||
from homeassistant.components.intent import async_device_supports_timers
|
||||
from homeassistant.components.script import ATTR_VARIABLES, DOMAIN as SCRIPT_DOMAIN
|
||||
from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
|
||||
from homeassistant.components.weather import INTENT_GET_WEATHER
|
||||
from homeassistant.const import (
|
||||
ATTR_DOMAIN,
|
||||
ATTR_ENTITY_ID,
|
||||
ATTR_SERVICE,
|
||||
EVENT_HOMEASSISTANT_CLOSE,
|
||||
EVENT_SERVICE_REMOVED,
|
||||
SERVICE_TURN_ON,
|
||||
)
|
||||
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@ -416,9 +414,7 @@ class AssistAPI(API):
|
||||
):
|
||||
continue
|
||||
|
||||
script_tool = ScriptTool(self.hass, state.entity_id)
|
||||
if script_tool.parameters.schema:
|
||||
tools.append(script_tool)
|
||||
tools.append(ScriptTool(self.hass, state.entity_id))
|
||||
|
||||
return tools
|
||||
|
||||
@ -702,10 +698,9 @@ class ScriptTool(Tool):
|
||||
script_entity_id: str,
|
||||
) -> None:
|
||||
"""Init the class."""
|
||||
self.name = split_entity_id(script_entity_id)[1]
|
||||
self._object_id = self.name = split_entity_id(script_entity_id)[1]
|
||||
if self.name[0].isdigit():
|
||||
self.name = "_" + self.name
|
||||
self._entity_id = script_entity_id
|
||||
|
||||
self.description, self.parameters = _get_cached_script_parameters(
|
||||
hass, script_entity_id
|
||||
@ -745,14 +740,13 @@ class ScriptTool(Tool):
|
||||
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
|
||||
tool_input.tool_args[field] = floor
|
||||
|
||||
await hass.services.async_call(
|
||||
result = await hass.services.async_call(
|
||||
SCRIPT_DOMAIN,
|
||||
SERVICE_TURN_ON,
|
||||
{
|
||||
ATTR_ENTITY_ID: self._entity_id,
|
||||
ATTR_VARIABLES: tool_input.tool_args,
|
||||
},
|
||||
self._object_id,
|
||||
tool_input.tool_args,
|
||||
context=llm_context.context,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
return {"success": True}
|
||||
return {"success": True, "result": result}
|
||||
|
@ -656,7 +656,10 @@ async def test_script_tool(
|
||||
"script": {
|
||||
"test_script": {
|
||||
"description": "This is a test script",
|
||||
"sequence": [],
|
||||
"sequence": [
|
||||
{"variables": {"result": {"drinks": 2}}},
|
||||
{"stop": True, "response_variable": "result"},
|
||||
],
|
||||
"fields": {
|
||||
"beer": {"description": "Number of beers", "required": True},
|
||||
"wine": {"selector": {"number": {"min": 0, "max": 3}}},
|
||||
@ -692,7 +695,7 @@ async def test_script_tool(
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||
assert len(tools) == 1
|
||||
assert len(tools) == 2
|
||||
|
||||
tool = tools[0]
|
||||
assert tool.name == "test_script"
|
||||
@ -719,6 +722,7 @@ async def test_script_tool(
|
||||
"script_with_no_fields": ("This is another test script", vol.Schema({})),
|
||||
}
|
||||
|
||||
# Test script with response
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name="test_script",
|
||||
tool_args={
|
||||
@ -731,26 +735,56 @@ async def test_script_tool(
|
||||
},
|
||||
)
|
||||
|
||||
with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call:
|
||||
with patch(
|
||||
"homeassistant.core.ServiceRegistry.async_call",
|
||||
side_effect=hass.services.async_call,
|
||||
) as mock_service_call:
|
||||
response = await api.async_call_tool(tool_input)
|
||||
|
||||
mock_service_call.assert_awaited_once_with(
|
||||
"script",
|
||||
"turn_on",
|
||||
"test_script",
|
||||
{
|
||||
"entity_id": "script.test_script",
|
||||
"variables": {
|
||||
"beer": "3",
|
||||
"wine": 0,
|
||||
"where": area.id,
|
||||
"area_list": [area.id],
|
||||
"floor": floor.floor_id,
|
||||
"floor_list": [floor.floor_id],
|
||||
},
|
||||
"beer": "3",
|
||||
"wine": 0,
|
||||
"where": area.id,
|
||||
"area_list": [area.id],
|
||||
"floor": floor.floor_id,
|
||||
"floor_list": [floor.floor_id],
|
||||
},
|
||||
context=context,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
assert response == {"success": True}
|
||||
assert response == {
|
||||
"success": True,
|
||||
"result": {"drinks": 2},
|
||||
}
|
||||
|
||||
# Test script with no response
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name="script_with_no_fields",
|
||||
tool_args={},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.core.ServiceRegistry.async_call",
|
||||
side_effect=hass.services.async_call,
|
||||
) as mock_service_call:
|
||||
response = await api.async_call_tool(tool_input)
|
||||
|
||||
mock_service_call.assert_awaited_once_with(
|
||||
"script",
|
||||
"script_with_no_fields",
|
||||
{},
|
||||
context=context,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
assert response == {
|
||||
"success": True,
|
||||
"result": {},
|
||||
}
|
||||
|
||||
# Test reload script with new parameters
|
||||
config = {
|
||||
@ -782,7 +816,7 @@ async def test_script_tool(
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||
assert len(tools) == 1
|
||||
assert len(tools) == 2
|
||||
|
||||
tool = tools[0]
|
||||
assert tool.name == "test_script"
|
||||
|
Loading…
x
Reference in New Issue
Block a user