mirror of
https://github.com/home-assistant/core.git
synced 2025-07-07 21:37:07 +00:00
Allow an LLM to see script response values (#131683)
This commit is contained in:
parent
46fe3dcbf1
commit
7e03100af2
@ -22,15 +22,13 @@ from homeassistant.components.conversation import (
|
|||||||
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
||||||
from homeassistant.components.homeassistant import async_should_expose
|
from homeassistant.components.homeassistant import async_should_expose
|
||||||
from homeassistant.components.intent import async_device_supports_timers
|
from homeassistant.components.intent import async_device_supports_timers
|
||||||
from homeassistant.components.script import ATTR_VARIABLES, DOMAIN as SCRIPT_DOMAIN
|
from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
|
||||||
from homeassistant.components.weather import INTENT_GET_WEATHER
|
from homeassistant.components.weather import INTENT_GET_WEATHER
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_DOMAIN,
|
ATTR_DOMAIN,
|
||||||
ATTR_ENTITY_ID,
|
|
||||||
ATTR_SERVICE,
|
ATTR_SERVICE,
|
||||||
EVENT_HOMEASSISTANT_CLOSE,
|
EVENT_HOMEASSISTANT_CLOSE,
|
||||||
EVENT_SERVICE_REMOVED,
|
EVENT_SERVICE_REMOVED,
|
||||||
SERVICE_TURN_ON,
|
|
||||||
)
|
)
|
||||||
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
|
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@ -416,9 +414,7 @@ class AssistAPI(API):
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
script_tool = ScriptTool(self.hass, state.entity_id)
|
tools.append(ScriptTool(self.hass, state.entity_id))
|
||||||
if script_tool.parameters.schema:
|
|
||||||
tools.append(script_tool)
|
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
@ -702,10 +698,9 @@ class ScriptTool(Tool):
|
|||||||
script_entity_id: str,
|
script_entity_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Init the class."""
|
"""Init the class."""
|
||||||
self.name = split_entity_id(script_entity_id)[1]
|
self._object_id = self.name = split_entity_id(script_entity_id)[1]
|
||||||
if self.name[0].isdigit():
|
if self.name[0].isdigit():
|
||||||
self.name = "_" + self.name
|
self.name = "_" + self.name
|
||||||
self._entity_id = script_entity_id
|
|
||||||
|
|
||||||
self.description, self.parameters = _get_cached_script_parameters(
|
self.description, self.parameters = _get_cached_script_parameters(
|
||||||
hass, script_entity_id
|
hass, script_entity_id
|
||||||
@ -745,14 +740,13 @@ class ScriptTool(Tool):
|
|||||||
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
|
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
|
||||||
tool_input.tool_args[field] = floor
|
tool_input.tool_args[field] = floor
|
||||||
|
|
||||||
await hass.services.async_call(
|
result = await hass.services.async_call(
|
||||||
SCRIPT_DOMAIN,
|
SCRIPT_DOMAIN,
|
||||||
SERVICE_TURN_ON,
|
self._object_id,
|
||||||
{
|
tool_input.tool_args,
|
||||||
ATTR_ENTITY_ID: self._entity_id,
|
|
||||||
ATTR_VARIABLES: tool_input.tool_args,
|
|
||||||
},
|
|
||||||
context=llm_context.context,
|
context=llm_context.context,
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"success": True}
|
return {"success": True, "result": result}
|
||||||
|
@ -656,7 +656,10 @@ async def test_script_tool(
|
|||||||
"script": {
|
"script": {
|
||||||
"test_script": {
|
"test_script": {
|
||||||
"description": "This is a test script",
|
"description": "This is a test script",
|
||||||
"sequence": [],
|
"sequence": [
|
||||||
|
{"variables": {"result": {"drinks": 2}}},
|
||||||
|
{"stop": True, "response_variable": "result"},
|
||||||
|
],
|
||||||
"fields": {
|
"fields": {
|
||||||
"beer": {"description": "Number of beers", "required": True},
|
"beer": {"description": "Number of beers", "required": True},
|
||||||
"wine": {"selector": {"number": {"min": 0, "max": 3}}},
|
"wine": {"selector": {"number": {"min": 0, "max": 3}}},
|
||||||
@ -692,7 +695,7 @@ async def test_script_tool(
|
|||||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
|
|
||||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||||
assert len(tools) == 1
|
assert len(tools) == 2
|
||||||
|
|
||||||
tool = tools[0]
|
tool = tools[0]
|
||||||
assert tool.name == "test_script"
|
assert tool.name == "test_script"
|
||||||
@ -719,6 +722,7 @@ async def test_script_tool(
|
|||||||
"script_with_no_fields": ("This is another test script", vol.Schema({})),
|
"script_with_no_fields": ("This is another test script", vol.Schema({})),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Test script with response
|
||||||
tool_input = llm.ToolInput(
|
tool_input = llm.ToolInput(
|
||||||
tool_name="test_script",
|
tool_name="test_script",
|
||||||
tool_args={
|
tool_args={
|
||||||
@ -731,26 +735,56 @@ async def test_script_tool(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call:
|
with patch(
|
||||||
|
"homeassistant.core.ServiceRegistry.async_call",
|
||||||
|
side_effect=hass.services.async_call,
|
||||||
|
) as mock_service_call:
|
||||||
response = await api.async_call_tool(tool_input)
|
response = await api.async_call_tool(tool_input)
|
||||||
|
|
||||||
mock_service_call.assert_awaited_once_with(
|
mock_service_call.assert_awaited_once_with(
|
||||||
"script",
|
"script",
|
||||||
"turn_on",
|
"test_script",
|
||||||
{
|
{
|
||||||
"entity_id": "script.test_script",
|
"beer": "3",
|
||||||
"variables": {
|
"wine": 0,
|
||||||
"beer": "3",
|
"where": area.id,
|
||||||
"wine": 0,
|
"area_list": [area.id],
|
||||||
"where": area.id,
|
"floor": floor.floor_id,
|
||||||
"area_list": [area.id],
|
"floor_list": [floor.floor_id],
|
||||||
"floor": floor.floor_id,
|
|
||||||
"floor_list": [floor.floor_id],
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
context=context,
|
context=context,
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
)
|
)
|
||||||
assert response == {"success": True}
|
assert response == {
|
||||||
|
"success": True,
|
||||||
|
"result": {"drinks": 2},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test script with no response
|
||||||
|
tool_input = llm.ToolInput(
|
||||||
|
tool_name="script_with_no_fields",
|
||||||
|
tool_args={},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.core.ServiceRegistry.async_call",
|
||||||
|
side_effect=hass.services.async_call,
|
||||||
|
) as mock_service_call:
|
||||||
|
response = await api.async_call_tool(tool_input)
|
||||||
|
|
||||||
|
mock_service_call.assert_awaited_once_with(
|
||||||
|
"script",
|
||||||
|
"script_with_no_fields",
|
||||||
|
{},
|
||||||
|
context=context,
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
assert response == {
|
||||||
|
"success": True,
|
||||||
|
"result": {},
|
||||||
|
}
|
||||||
|
|
||||||
# Test reload script with new parameters
|
# Test reload script with new parameters
|
||||||
config = {
|
config = {
|
||||||
@ -782,7 +816,7 @@ async def test_script_tool(
|
|||||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
|
|
||||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||||
assert len(tools) == 1
|
assert len(tools) == 2
|
||||||
|
|
||||||
tool = tools[0]
|
tool = tools[0]
|
||||||
assert tool.name == "test_script"
|
assert tool.name == "test_script"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user