mirror of
https://github.com/home-assistant/core.git
synced 2025-07-14 00:37:13 +00:00
Rename llm.ToolContext to llm.LLMContext (#118566)
This commit is contained in:
parent
6656f7d6b9
commit
6dd01dbff7
@ -169,7 +169,7 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
llm_api = await llm.async_get_api(
|
llm_api = await llm.async_get_api(
|
||||||
self.hass,
|
self.hass,
|
||||||
self.entry.options[CONF_LLM_HASS_API],
|
self.entry.options[CONF_LLM_HASS_API],
|
||||||
llm.ToolContext(
|
llm.LLMContext(
|
||||||
platform=DOMAIN,
|
platform=DOMAIN,
|
||||||
context=user_input.context,
|
context=user_input.context,
|
||||||
user_prompt=user_input.text,
|
user_prompt=user_input.text,
|
||||||
|
@ -119,7 +119,7 @@ class OpenAIConversationEntity(
|
|||||||
llm_api = await llm.async_get_api(
|
llm_api = await llm.async_get_api(
|
||||||
self.hass,
|
self.hass,
|
||||||
options[CONF_LLM_HASS_API],
|
options[CONF_LLM_HASS_API],
|
||||||
llm.ToolContext(
|
llm.LLMContext(
|
||||||
platform=DOMAIN,
|
platform=DOMAIN,
|
||||||
context=user_input.context,
|
context=user_input.context,
|
||||||
user_prompt=user_input.text,
|
user_prompt=user_input.text,
|
||||||
|
@ -71,7 +71,7 @@ def async_register_api(hass: HomeAssistant, api: API) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def async_get_api(
|
async def async_get_api(
|
||||||
hass: HomeAssistant, api_id: str, tool_context: ToolContext
|
hass: HomeAssistant, api_id: str, llm_context: LLMContext
|
||||||
) -> APIInstance:
|
) -> APIInstance:
|
||||||
"""Get an API."""
|
"""Get an API."""
|
||||||
apis = _async_get_apis(hass)
|
apis = _async_get_apis(hass)
|
||||||
@ -79,7 +79,7 @@ async def async_get_api(
|
|||||||
if api_id not in apis:
|
if api_id not in apis:
|
||||||
raise HomeAssistantError(f"API {api_id} not found")
|
raise HomeAssistantError(f"API {api_id} not found")
|
||||||
|
|
||||||
return await apis[api_id].async_get_api_instance(tool_context)
|
return await apis[api_id].async_get_api_instance(llm_context)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -89,7 +89,7 @@ def async_get_apis(hass: HomeAssistant) -> list[API]:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class ToolContext:
|
class LLMContext:
|
||||||
"""Tool input to be processed."""
|
"""Tool input to be processed."""
|
||||||
|
|
||||||
platform: str
|
platform: str
|
||||||
@ -117,7 +117,7 @@ class Tool:
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def async_call(
|
async def async_call(
|
||||||
self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext
|
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||||
) -> JsonObjectType:
|
) -> JsonObjectType:
|
||||||
"""Call the tool."""
|
"""Call the tool."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -133,7 +133,7 @@ class APIInstance:
|
|||||||
|
|
||||||
api: API
|
api: API
|
||||||
api_prompt: str
|
api_prompt: str
|
||||||
tool_context: ToolContext
|
llm_context: LLMContext
|
||||||
tools: list[Tool]
|
tools: list[Tool]
|
||||||
|
|
||||||
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
||||||
@ -149,7 +149,7 @@ class APIInstance:
|
|||||||
else:
|
else:
|
||||||
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
|
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
|
||||||
|
|
||||||
return await tool.async_call(self.api.hass, tool_input, self.tool_context)
|
return await tool.async_call(self.api.hass, tool_input, self.llm_context)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True, kw_only=True)
|
@dataclass(slots=True, kw_only=True)
|
||||||
@ -161,7 +161,7 @@ class API(ABC):
|
|||||||
name: str
|
name: str
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance:
|
async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
|
||||||
"""Return the instance of the API."""
|
"""Return the instance of the API."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -182,20 +182,20 @@ class IntentTool(Tool):
|
|||||||
self.parameters = vol.Schema(slot_schema)
|
self.parameters = vol.Schema(slot_schema)
|
||||||
|
|
||||||
async def async_call(
|
async def async_call(
|
||||||
self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext
|
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||||
) -> JsonObjectType:
|
) -> JsonObjectType:
|
||||||
"""Handle the intent."""
|
"""Handle the intent."""
|
||||||
slots = {key: {"value": val} for key, val in tool_input.tool_args.items()}
|
slots = {key: {"value": val} for key, val in tool_input.tool_args.items()}
|
||||||
intent_response = await intent.async_handle(
|
intent_response = await intent.async_handle(
|
||||||
hass=hass,
|
hass=hass,
|
||||||
platform=tool_context.platform,
|
platform=llm_context.platform,
|
||||||
intent_type=self.name,
|
intent_type=self.name,
|
||||||
slots=slots,
|
slots=slots,
|
||||||
text_input=tool_context.user_prompt,
|
text_input=llm_context.user_prompt,
|
||||||
context=tool_context.context,
|
context=llm_context.context,
|
||||||
language=tool_context.language,
|
language=llm_context.language,
|
||||||
assistant=tool_context.assistant,
|
assistant=llm_context.assistant,
|
||||||
device_id=tool_context.device_id,
|
device_id=llm_context.device_id,
|
||||||
)
|
)
|
||||||
response = intent_response.as_dict()
|
response = intent_response.as_dict()
|
||||||
del response["language"]
|
del response["language"]
|
||||||
@ -224,25 +224,25 @@ class AssistAPI(API):
|
|||||||
name="Assist",
|
name="Assist",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance:
|
async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
|
||||||
"""Return the instance of the API."""
|
"""Return the instance of the API."""
|
||||||
if tool_context.assistant:
|
if llm_context.assistant:
|
||||||
exposed_entities: dict | None = _get_exposed_entities(
|
exposed_entities: dict | None = _get_exposed_entities(
|
||||||
self.hass, tool_context.assistant
|
self.hass, llm_context.assistant
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exposed_entities = None
|
exposed_entities = None
|
||||||
|
|
||||||
return APIInstance(
|
return APIInstance(
|
||||||
api=self,
|
api=self,
|
||||||
api_prompt=self._async_get_api_prompt(tool_context, exposed_entities),
|
api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
|
||||||
tool_context=tool_context,
|
llm_context=llm_context,
|
||||||
tools=self._async_get_tools(tool_context, exposed_entities),
|
tools=self._async_get_tools(llm_context, exposed_entities),
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_get_api_prompt(
|
def _async_get_api_prompt(
|
||||||
self, tool_context: ToolContext, exposed_entities: dict | None
|
self, llm_context: LLMContext, exposed_entities: dict | None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return the prompt for the API."""
|
"""Return the prompt for the API."""
|
||||||
if not exposed_entities:
|
if not exposed_entities:
|
||||||
@ -263,9 +263,9 @@ class AssistAPI(API):
|
|||||||
]
|
]
|
||||||
area: ar.AreaEntry | None = None
|
area: ar.AreaEntry | None = None
|
||||||
floor: fr.FloorEntry | None = None
|
floor: fr.FloorEntry | None = None
|
||||||
if tool_context.device_id:
|
if llm_context.device_id:
|
||||||
device_reg = dr.async_get(self.hass)
|
device_reg = dr.async_get(self.hass)
|
||||||
device = device_reg.async_get(tool_context.device_id)
|
device = device_reg.async_get(llm_context.device_id)
|
||||||
|
|
||||||
if device:
|
if device:
|
||||||
area_reg = ar.async_get(self.hass)
|
area_reg = ar.async_get(self.hass)
|
||||||
@ -286,8 +286,8 @@ class AssistAPI(API):
|
|||||||
"ask user to specify an area, unless there is only one device of that type."
|
"ask user to specify an area, unless there is only one device of that type."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not tool_context.device_id or not async_device_supports_timers(
|
if not llm_context.device_id or not async_device_supports_timers(
|
||||||
self.hass, tool_context.device_id
|
self.hass, llm_context.device_id
|
||||||
):
|
):
|
||||||
prompt.append("This device does not support timers.")
|
prompt.append("This device does not support timers.")
|
||||||
|
|
||||||
@ -301,12 +301,12 @@ class AssistAPI(API):
|
|||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_get_tools(
|
def _async_get_tools(
|
||||||
self, tool_context: ToolContext, exposed_entities: dict | None
|
self, llm_context: LLMContext, exposed_entities: dict | None
|
||||||
) -> list[Tool]:
|
) -> list[Tool]:
|
||||||
"""Return a list of LLM tools."""
|
"""Return a list of LLM tools."""
|
||||||
ignore_intents = self.IGNORE_INTENTS
|
ignore_intents = self.IGNORE_INTENTS
|
||||||
if not tool_context.device_id or not async_device_supports_timers(
|
if not llm_context.device_id or not async_device_supports_timers(
|
||||||
self.hass, tool_context.device_id
|
self.hass, llm_context.device_id
|
||||||
):
|
):
|
||||||
ignore_intents = ignore_intents | {
|
ignore_intents = ignore_intents | {
|
||||||
intent.INTENT_START_TIMER,
|
intent.INTENT_START_TIMER,
|
||||||
|
@ -231,7 +231,7 @@ async def test_function_call(
|
|||||||
"param2": "param2's value",
|
"param2": "param2's value",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
llm.ToolContext(
|
llm.LLMContext(
|
||||||
platform="google_generative_ai_conversation",
|
platform="google_generative_ai_conversation",
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt="Please call the test function",
|
user_prompt="Please call the test function",
|
||||||
@ -330,7 +330,7 @@ async def test_function_exception(
|
|||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": 1},
|
tool_args={"param1": 1},
|
||||||
),
|
),
|
||||||
llm.ToolContext(
|
llm.LLMContext(
|
||||||
platform="google_generative_ai_conversation",
|
platform="google_generative_ai_conversation",
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt="Please call the test function",
|
user_prompt="Please call the test function",
|
||||||
|
@ -192,7 +192,7 @@ async def test_function_call(
|
|||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": "test_value"},
|
tool_args={"param1": "test_value"},
|
||||||
),
|
),
|
||||||
llm.ToolContext(
|
llm.LLMContext(
|
||||||
platform="openai_conversation",
|
platform="openai_conversation",
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt="Please call the test function",
|
user_prompt="Please call the test function",
|
||||||
@ -324,7 +324,7 @@ async def test_function_exception(
|
|||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": "test_value"},
|
tool_args={"param1": "test_value"},
|
||||||
),
|
),
|
||||||
llm.ToolContext(
|
llm.LLMContext(
|
||||||
platform="openai_conversation",
|
platform="openai_conversation",
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt="Please call the test function",
|
user_prompt="Please call the test function",
|
||||||
|
@ -24,9 +24,9 @@ from tests.common import MockConfigEntry
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tool_input_context() -> llm.ToolContext:
|
def llm_context() -> llm.LLMContext:
|
||||||
"""Return tool input context."""
|
"""Return tool input context."""
|
||||||
return llm.ToolContext(
|
return llm.LLMContext(
|
||||||
platform="",
|
platform="",
|
||||||
context=None,
|
context=None,
|
||||||
user_prompt=None,
|
user_prompt=None,
|
||||||
@ -37,29 +37,27 @@ def tool_input_context() -> llm.ToolContext:
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_api_no_existing(
|
async def test_get_api_no_existing(
|
||||||
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test getting an llm api where no config exists."""
|
"""Test getting an llm api where no config exists."""
|
||||||
with pytest.raises(HomeAssistantError):
|
with pytest.raises(HomeAssistantError):
|
||||||
await llm.async_get_api(hass, "non-existing", tool_input_context)
|
await llm.async_get_api(hass, "non-existing", llm_context)
|
||||||
|
|
||||||
|
|
||||||
async def test_register_api(
|
async def test_register_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
|
||||||
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
|
||||||
) -> None:
|
|
||||||
"""Test registering an llm api."""
|
"""Test registering an llm api."""
|
||||||
|
|
||||||
class MyAPI(llm.API):
|
class MyAPI(llm.API):
|
||||||
async def async_get_api_instance(
|
async def async_get_api_instance(
|
||||||
self, tool_input: llm.ToolInput
|
self, tool_context: llm.ToolInput
|
||||||
) -> llm.APIInstance:
|
) -> llm.APIInstance:
|
||||||
"""Return a list of tools."""
|
"""Return a list of tools."""
|
||||||
return llm.APIInstance(self, "", [], tool_input_context)
|
return llm.APIInstance(self, "", [], llm_context)
|
||||||
|
|
||||||
api = MyAPI(hass=hass, id="test", name="Test")
|
api = MyAPI(hass=hass, id="test", name="Test")
|
||||||
llm.async_register_api(hass, api)
|
llm.async_register_api(hass, api)
|
||||||
|
|
||||||
instance = await llm.async_get_api(hass, "test", tool_input_context)
|
instance = await llm.async_get_api(hass, "test", llm_context)
|
||||||
assert instance.api is api
|
assert instance.api is api
|
||||||
assert api in llm.async_get_apis(hass)
|
assert api in llm.async_get_apis(hass)
|
||||||
|
|
||||||
@ -68,10 +66,10 @@ async def test_register_api(
|
|||||||
|
|
||||||
|
|
||||||
async def test_call_tool_no_existing(
|
async def test_call_tool_no_existing(
|
||||||
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test calling an llm tool where no config exists."""
|
"""Test calling an llm tool where no config exists."""
|
||||||
instance = await llm.async_get_api(hass, "assist", tool_input_context)
|
instance = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
with pytest.raises(HomeAssistantError):
|
with pytest.raises(HomeAssistantError):
|
||||||
await instance.async_call_tool(
|
await instance.async_call_tool(
|
||||||
llm.ToolInput("test_tool", {}),
|
llm.ToolInput("test_tool", {}),
|
||||||
@ -93,7 +91,7 @@ async def test_assist_api(
|
|||||||
).write_unavailable_state(hass)
|
).write_unavailable_state(hass)
|
||||||
|
|
||||||
test_context = Context()
|
test_context = Context()
|
||||||
tool_context = llm.ToolContext(
|
llm_context = llm.LLMContext(
|
||||||
platform="test_platform",
|
platform="test_platform",
|
||||||
context=test_context,
|
context=test_context,
|
||||||
user_prompt="test_text",
|
user_prompt="test_text",
|
||||||
@ -116,19 +114,19 @@ async def test_assist_api(
|
|||||||
intent.async_register(hass, intent_handler)
|
intent.async_register(hass, intent_handler)
|
||||||
|
|
||||||
assert len(llm.async_get_apis(hass)) == 1
|
assert len(llm.async_get_apis(hass)) == 1
|
||||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
assert len(api.tools) == 0
|
assert len(api.tools) == 0
|
||||||
|
|
||||||
# Match all
|
# Match all
|
||||||
intent_handler.platforms = None
|
intent_handler.platforms = None
|
||||||
|
|
||||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
assert len(api.tools) == 1
|
assert len(api.tools) == 1
|
||||||
|
|
||||||
# Match specific domain
|
# Match specific domain
|
||||||
intent_handler.platforms = {"light"}
|
intent_handler.platforms = {"light"}
|
||||||
|
|
||||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
assert len(api.tools) == 1
|
assert len(api.tools) == 1
|
||||||
tool = api.tools[0]
|
tool = api.tools[0]
|
||||||
assert tool.name == "test_intent"
|
assert tool.name == "test_intent"
|
||||||
@ -176,25 +174,25 @@ async def test_assist_api(
|
|||||||
|
|
||||||
|
|
||||||
async def test_assist_api_get_timer_tools(
|
async def test_assist_api_get_timer_tools(
|
||||||
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test getting timer tools with Assist API."""
|
"""Test getting timer tools with Assist API."""
|
||||||
assert await async_setup_component(hass, "homeassistant", {})
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
assert await async_setup_component(hass, "intent", {})
|
assert await async_setup_component(hass, "intent", {})
|
||||||
api = await llm.async_get_api(hass, "assist", tool_input_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
|
|
||||||
assert "HassStartTimer" not in [tool.name for tool in api.tools]
|
assert "HassStartTimer" not in [tool.name for tool in api.tools]
|
||||||
|
|
||||||
tool_input_context.device_id = "test_device"
|
llm_context.device_id = "test_device"
|
||||||
|
|
||||||
async_register_timer_handler(hass, "test_device", lambda *args: None)
|
async_register_timer_handler(hass, "test_device", lambda *args: None)
|
||||||
|
|
||||||
api = await llm.async_get_api(hass, "assist", tool_input_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
assert "HassStartTimer" in [tool.name for tool in api.tools]
|
assert "HassStartTimer" in [tool.name for tool in api.tools]
|
||||||
|
|
||||||
|
|
||||||
async def test_assist_api_description(
|
async def test_assist_api_description(
|
||||||
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test intent description with Assist API."""
|
"""Test intent description with Assist API."""
|
||||||
|
|
||||||
@ -205,7 +203,7 @@ async def test_assist_api_description(
|
|||||||
intent.async_register(hass, MyIntentHandler())
|
intent.async_register(hass, MyIntentHandler())
|
||||||
|
|
||||||
assert len(llm.async_get_apis(hass)) == 1
|
assert len(llm.async_get_apis(hass)) == 1
|
||||||
api = await llm.async_get_api(hass, "assist", tool_input_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
assert len(api.tools) == 1
|
assert len(api.tools) == 1
|
||||||
tool = api.tools[0]
|
tool = api.tools[0]
|
||||||
assert tool.name == "test_intent"
|
assert tool.name == "test_intent"
|
||||||
@ -223,7 +221,7 @@ async def test_assist_api_prompt(
|
|||||||
assert await async_setup_component(hass, "homeassistant", {})
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
assert await async_setup_component(hass, "intent", {})
|
assert await async_setup_component(hass, "intent", {})
|
||||||
context = Context()
|
context = Context()
|
||||||
tool_context = llm.ToolContext(
|
llm_context = llm.LLMContext(
|
||||||
platform="test_platform",
|
platform="test_platform",
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt="test_text",
|
user_prompt="test_text",
|
||||||
@ -231,7 +229,7 @@ async def test_assist_api_prompt(
|
|||||||
assistant="conversation",
|
assistant="conversation",
|
||||||
device_id=None,
|
device_id=None,
|
||||||
)
|
)
|
||||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
assert api.api_prompt == (
|
assert api.api_prompt == (
|
||||||
"Only if the user wants to control a device, tell them to expose entities to their "
|
"Only if the user wants to control a device, tell them to expose entities to their "
|
||||||
"voice assistant in Home Assistant."
|
"voice assistant in Home Assistant."
|
||||||
@ -360,7 +358,7 @@ async def test_assist_api_prompt(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
exposed_entities = llm._get_exposed_entities(hass, tool_context.assistant)
|
exposed_entities = llm._get_exposed_entities(hass, llm_context.assistant)
|
||||||
assert exposed_entities == {
|
assert exposed_entities == {
|
||||||
"light.1": {
|
"light.1": {
|
||||||
"areas": "Test Area 2",
|
"areas": "Test Area 2",
|
||||||
@ -435,7 +433,7 @@ async def test_assist_api_prompt(
|
|||||||
"When a user asks to turn on all devices of a specific type, "
|
"When a user asks to turn on all devices of a specific type, "
|
||||||
"ask user to specify an area, unless there is only one device of that type."
|
"ask user to specify an area, unless there is only one device of that type."
|
||||||
)
|
)
|
||||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
assert api.api_prompt == (
|
assert api.api_prompt == (
|
||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
@ -444,12 +442,12 @@ async def test_assist_api_prompt(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Fake that request is made from a specific device ID with an area
|
# Fake that request is made from a specific device ID with an area
|
||||||
tool_context.device_id = device.id
|
llm_context.device_id = device.id
|
||||||
area_prompt = (
|
area_prompt = (
|
||||||
"You are in area Test Area and all generic commands like 'turn on the lights' "
|
"You are in area Test Area and all generic commands like 'turn on the lights' "
|
||||||
"should target this area."
|
"should target this area."
|
||||||
)
|
)
|
||||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
assert api.api_prompt == (
|
assert api.api_prompt == (
|
||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
@ -464,7 +462,7 @@ async def test_assist_api_prompt(
|
|||||||
"You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' "
|
"You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' "
|
||||||
"should target this area."
|
"should target this area."
|
||||||
)
|
)
|
||||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
assert api.api_prompt == (
|
assert api.api_prompt == (
|
||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
@ -475,7 +473,7 @@ async def test_assist_api_prompt(
|
|||||||
# Register device for timers
|
# Register device for timers
|
||||||
async_register_timer_handler(hass, device.id, lambda *args: None)
|
async_register_timer_handler(hass, device.id, lambda *args: None)
|
||||||
|
|
||||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
# The no_timer_prompt is gone
|
# The no_timer_prompt is gone
|
||||||
assert api.api_prompt == (
|
assert api.api_prompt == (
|
||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user