Rename llm.ToolContext to llm.LLMContext (#118566)

This commit is contained in:
Paulus Schoutsen 2024-05-31 11:11:24 -04:00 committed by GitHub
parent 6656f7d6b9
commit 6dd01dbff7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 62 additions and 64 deletions

View File

@ -169,7 +169,7 @@ class GoogleGenerativeAIConversationEntity(
llm_api = await llm.async_get_api( llm_api = await llm.async_get_api(
self.hass, self.hass,
self.entry.options[CONF_LLM_HASS_API], self.entry.options[CONF_LLM_HASS_API],
llm.ToolContext( llm.LLMContext(
platform=DOMAIN, platform=DOMAIN,
context=user_input.context, context=user_input.context,
user_prompt=user_input.text, user_prompt=user_input.text,

View File

@ -119,7 +119,7 @@ class OpenAIConversationEntity(
llm_api = await llm.async_get_api( llm_api = await llm.async_get_api(
self.hass, self.hass,
options[CONF_LLM_HASS_API], options[CONF_LLM_HASS_API],
llm.ToolContext( llm.LLMContext(
platform=DOMAIN, platform=DOMAIN,
context=user_input.context, context=user_input.context,
user_prompt=user_input.text, user_prompt=user_input.text,

View File

@ -71,7 +71,7 @@ def async_register_api(hass: HomeAssistant, api: API) -> None:
async def async_get_api( async def async_get_api(
hass: HomeAssistant, api_id: str, tool_context: ToolContext hass: HomeAssistant, api_id: str, llm_context: LLMContext
) -> APIInstance: ) -> APIInstance:
"""Get an API.""" """Get an API."""
apis = _async_get_apis(hass) apis = _async_get_apis(hass)
@ -79,7 +79,7 @@ async def async_get_api(
if api_id not in apis: if api_id not in apis:
raise HomeAssistantError(f"API {api_id} not found") raise HomeAssistantError(f"API {api_id} not found")
return await apis[api_id].async_get_api_instance(tool_context) return await apis[api_id].async_get_api_instance(llm_context)
@callback @callback
@ -89,7 +89,7 @@ def async_get_apis(hass: HomeAssistant) -> list[API]:
@dataclass(slots=True) @dataclass(slots=True)
class ToolContext: class LLMContext:
"""Tool input to be processed.""" """Tool input to be processed."""
platform: str platform: str
@ -117,7 +117,7 @@ class Tool:
@abstractmethod @abstractmethod
async def async_call( async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
) -> JsonObjectType: ) -> JsonObjectType:
"""Call the tool.""" """Call the tool."""
raise NotImplementedError raise NotImplementedError
@ -133,7 +133,7 @@ class APIInstance:
api: API api: API
api_prompt: str api_prompt: str
tool_context: ToolContext llm_context: LLMContext
tools: list[Tool] tools: list[Tool]
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
@ -149,7 +149,7 @@ class APIInstance:
else: else:
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found') raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
return await tool.async_call(self.api.hass, tool_input, self.tool_context) return await tool.async_call(self.api.hass, tool_input, self.llm_context)
@dataclass(slots=True, kw_only=True) @dataclass(slots=True, kw_only=True)
@ -161,7 +161,7 @@ class API(ABC):
name: str name: str
@abstractmethod @abstractmethod
async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance: async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
"""Return the instance of the API.""" """Return the instance of the API."""
raise NotImplementedError raise NotImplementedError
@ -182,20 +182,20 @@ class IntentTool(Tool):
self.parameters = vol.Schema(slot_schema) self.parameters = vol.Schema(slot_schema)
async def async_call( async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
) -> JsonObjectType: ) -> JsonObjectType:
"""Handle the intent.""" """Handle the intent."""
slots = {key: {"value": val} for key, val in tool_input.tool_args.items()} slots = {key: {"value": val} for key, val in tool_input.tool_args.items()}
intent_response = await intent.async_handle( intent_response = await intent.async_handle(
hass=hass, hass=hass,
platform=tool_context.platform, platform=llm_context.platform,
intent_type=self.name, intent_type=self.name,
slots=slots, slots=slots,
text_input=tool_context.user_prompt, text_input=llm_context.user_prompt,
context=tool_context.context, context=llm_context.context,
language=tool_context.language, language=llm_context.language,
assistant=tool_context.assistant, assistant=llm_context.assistant,
device_id=tool_context.device_id, device_id=llm_context.device_id,
) )
response = intent_response.as_dict() response = intent_response.as_dict()
del response["language"] del response["language"]
@ -224,25 +224,25 @@ class AssistAPI(API):
name="Assist", name="Assist",
) )
async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance: async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
"""Return the instance of the API.""" """Return the instance of the API."""
if tool_context.assistant: if llm_context.assistant:
exposed_entities: dict | None = _get_exposed_entities( exposed_entities: dict | None = _get_exposed_entities(
self.hass, tool_context.assistant self.hass, llm_context.assistant
) )
else: else:
exposed_entities = None exposed_entities = None
return APIInstance( return APIInstance(
api=self, api=self,
api_prompt=self._async_get_api_prompt(tool_context, exposed_entities), api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
tool_context=tool_context, llm_context=llm_context,
tools=self._async_get_tools(tool_context, exposed_entities), tools=self._async_get_tools(llm_context, exposed_entities),
) )
@callback @callback
def _async_get_api_prompt( def _async_get_api_prompt(
self, tool_context: ToolContext, exposed_entities: dict | None self, llm_context: LLMContext, exposed_entities: dict | None
) -> str: ) -> str:
"""Return the prompt for the API.""" """Return the prompt for the API."""
if not exposed_entities: if not exposed_entities:
@ -263,9 +263,9 @@ class AssistAPI(API):
] ]
area: ar.AreaEntry | None = None area: ar.AreaEntry | None = None
floor: fr.FloorEntry | None = None floor: fr.FloorEntry | None = None
if tool_context.device_id: if llm_context.device_id:
device_reg = dr.async_get(self.hass) device_reg = dr.async_get(self.hass)
device = device_reg.async_get(tool_context.device_id) device = device_reg.async_get(llm_context.device_id)
if device: if device:
area_reg = ar.async_get(self.hass) area_reg = ar.async_get(self.hass)
@ -286,8 +286,8 @@ class AssistAPI(API):
"ask user to specify an area, unless there is only one device of that type." "ask user to specify an area, unless there is only one device of that type."
) )
if not tool_context.device_id or not async_device_supports_timers( if not llm_context.device_id or not async_device_supports_timers(
self.hass, tool_context.device_id self.hass, llm_context.device_id
): ):
prompt.append("This device does not support timers.") prompt.append("This device does not support timers.")
@ -301,12 +301,12 @@ class AssistAPI(API):
@callback @callback
def _async_get_tools( def _async_get_tools(
self, tool_context: ToolContext, exposed_entities: dict | None self, llm_context: LLMContext, exposed_entities: dict | None
) -> list[Tool]: ) -> list[Tool]:
"""Return a list of LLM tools.""" """Return a list of LLM tools."""
ignore_intents = self.IGNORE_INTENTS ignore_intents = self.IGNORE_INTENTS
if not tool_context.device_id or not async_device_supports_timers( if not llm_context.device_id or not async_device_supports_timers(
self.hass, tool_context.device_id self.hass, llm_context.device_id
): ):
ignore_intents = ignore_intents | { ignore_intents = ignore_intents | {
intent.INTENT_START_TIMER, intent.INTENT_START_TIMER,

View File

@ -231,7 +231,7 @@ async def test_function_call(
"param2": "param2's value", "param2": "param2's value",
}, },
), ),
llm.ToolContext( llm.LLMContext(
platform="google_generative_ai_conversation", platform="google_generative_ai_conversation",
context=context, context=context,
user_prompt="Please call the test function", user_prompt="Please call the test function",
@ -330,7 +330,7 @@ async def test_function_exception(
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": 1}, tool_args={"param1": 1},
), ),
llm.ToolContext( llm.LLMContext(
platform="google_generative_ai_conversation", platform="google_generative_ai_conversation",
context=context, context=context,
user_prompt="Please call the test function", user_prompt="Please call the test function",

View File

@ -192,7 +192,7 @@ async def test_function_call(
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),
llm.ToolContext( llm.LLMContext(
platform="openai_conversation", platform="openai_conversation",
context=context, context=context,
user_prompt="Please call the test function", user_prompt="Please call the test function",
@ -324,7 +324,7 @@ async def test_function_exception(
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
), ),
llm.ToolContext( llm.LLMContext(
platform="openai_conversation", platform="openai_conversation",
context=context, context=context,
user_prompt="Please call the test function", user_prompt="Please call the test function",

View File

@ -24,9 +24,9 @@ from tests.common import MockConfigEntry
@pytest.fixture @pytest.fixture
def tool_input_context() -> llm.ToolContext: def llm_context() -> llm.LLMContext:
"""Return tool input context.""" """Return tool input context."""
return llm.ToolContext( return llm.LLMContext(
platform="", platform="",
context=None, context=None,
user_prompt=None, user_prompt=None,
@ -37,29 +37,27 @@ def tool_input_context() -> llm.ToolContext:
async def test_get_api_no_existing( async def test_get_api_no_existing(
hass: HomeAssistant, tool_input_context: llm.ToolContext hass: HomeAssistant, llm_context: llm.LLMContext
) -> None: ) -> None:
"""Test getting an llm api where no config exists.""" """Test getting an llm api where no config exists."""
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
await llm.async_get_api(hass, "non-existing", tool_input_context) await llm.async_get_api(hass, "non-existing", llm_context)
async def test_register_api( async def test_register_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
hass: HomeAssistant, tool_input_context: llm.ToolContext
) -> None:
"""Test registering an llm api.""" """Test registering an llm api."""
class MyAPI(llm.API): class MyAPI(llm.API):
async def async_get_api_instance( async def async_get_api_instance(
self, tool_input: llm.ToolInput self, tool_context: llm.ToolInput
) -> llm.APIInstance: ) -> llm.APIInstance:
"""Return a list of tools.""" """Return a list of tools."""
return llm.APIInstance(self, "", [], tool_input_context) return llm.APIInstance(self, "", [], llm_context)
api = MyAPI(hass=hass, id="test", name="Test") api = MyAPI(hass=hass, id="test", name="Test")
llm.async_register_api(hass, api) llm.async_register_api(hass, api)
instance = await llm.async_get_api(hass, "test", tool_input_context) instance = await llm.async_get_api(hass, "test", llm_context)
assert instance.api is api assert instance.api is api
assert api in llm.async_get_apis(hass) assert api in llm.async_get_apis(hass)
@ -68,10 +66,10 @@ async def test_register_api(
async def test_call_tool_no_existing( async def test_call_tool_no_existing(
hass: HomeAssistant, tool_input_context: llm.ToolContext hass: HomeAssistant, llm_context: llm.LLMContext
) -> None: ) -> None:
"""Test calling an llm tool where no config exists.""" """Test calling an llm tool where no config exists."""
instance = await llm.async_get_api(hass, "assist", tool_input_context) instance = await llm.async_get_api(hass, "assist", llm_context)
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
await instance.async_call_tool( await instance.async_call_tool(
llm.ToolInput("test_tool", {}), llm.ToolInput("test_tool", {}),
@ -93,7 +91,7 @@ async def test_assist_api(
).write_unavailable_state(hass) ).write_unavailable_state(hass)
test_context = Context() test_context = Context()
tool_context = llm.ToolContext( llm_context = llm.LLMContext(
platform="test_platform", platform="test_platform",
context=test_context, context=test_context,
user_prompt="test_text", user_prompt="test_text",
@ -116,19 +114,19 @@ async def test_assist_api(
intent.async_register(hass, intent_handler) intent.async_register(hass, intent_handler)
assert len(llm.async_get_apis(hass)) == 1 assert len(llm.async_get_apis(hass)) == 1
api = await llm.async_get_api(hass, "assist", tool_context) api = await llm.async_get_api(hass, "assist", llm_context)
assert len(api.tools) == 0 assert len(api.tools) == 0
# Match all # Match all
intent_handler.platforms = None intent_handler.platforms = None
api = await llm.async_get_api(hass, "assist", tool_context) api = await llm.async_get_api(hass, "assist", llm_context)
assert len(api.tools) == 1 assert len(api.tools) == 1
# Match specific domain # Match specific domain
intent_handler.platforms = {"light"} intent_handler.platforms = {"light"}
api = await llm.async_get_api(hass, "assist", tool_context) api = await llm.async_get_api(hass, "assist", llm_context)
assert len(api.tools) == 1 assert len(api.tools) == 1
tool = api.tools[0] tool = api.tools[0]
assert tool.name == "test_intent" assert tool.name == "test_intent"
@ -176,25 +174,25 @@ async def test_assist_api(
async def test_assist_api_get_timer_tools( async def test_assist_api_get_timer_tools(
hass: HomeAssistant, tool_input_context: llm.ToolContext hass: HomeAssistant, llm_context: llm.LLMContext
) -> None: ) -> None:
"""Test getting timer tools with Assist API.""" """Test getting timer tools with Assist API."""
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {}) assert await async_setup_component(hass, "intent", {})
api = await llm.async_get_api(hass, "assist", tool_input_context) api = await llm.async_get_api(hass, "assist", llm_context)
assert "HassStartTimer" not in [tool.name for tool in api.tools] assert "HassStartTimer" not in [tool.name for tool in api.tools]
tool_input_context.device_id = "test_device" llm_context.device_id = "test_device"
async_register_timer_handler(hass, "test_device", lambda *args: None) async_register_timer_handler(hass, "test_device", lambda *args: None)
api = await llm.async_get_api(hass, "assist", tool_input_context) api = await llm.async_get_api(hass, "assist", llm_context)
assert "HassStartTimer" in [tool.name for tool in api.tools] assert "HassStartTimer" in [tool.name for tool in api.tools]
async def test_assist_api_description( async def test_assist_api_description(
hass: HomeAssistant, tool_input_context: llm.ToolContext hass: HomeAssistant, llm_context: llm.LLMContext
) -> None: ) -> None:
"""Test intent description with Assist API.""" """Test intent description with Assist API."""
@ -205,7 +203,7 @@ async def test_assist_api_description(
intent.async_register(hass, MyIntentHandler()) intent.async_register(hass, MyIntentHandler())
assert len(llm.async_get_apis(hass)) == 1 assert len(llm.async_get_apis(hass)) == 1
api = await llm.async_get_api(hass, "assist", tool_input_context) api = await llm.async_get_api(hass, "assist", llm_context)
assert len(api.tools) == 1 assert len(api.tools) == 1
tool = api.tools[0] tool = api.tools[0]
assert tool.name == "test_intent" assert tool.name == "test_intent"
@ -223,7 +221,7 @@ async def test_assist_api_prompt(
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {}) assert await async_setup_component(hass, "intent", {})
context = Context() context = Context()
tool_context = llm.ToolContext( llm_context = llm.LLMContext(
platform="test_platform", platform="test_platform",
context=context, context=context,
user_prompt="test_text", user_prompt="test_text",
@ -231,7 +229,7 @@ async def test_assist_api_prompt(
assistant="conversation", assistant="conversation",
device_id=None, device_id=None,
) )
api = await llm.async_get_api(hass, "assist", tool_context) api = await llm.async_get_api(hass, "assist", llm_context)
assert api.api_prompt == ( assert api.api_prompt == (
"Only if the user wants to control a device, tell them to expose entities to their " "Only if the user wants to control a device, tell them to expose entities to their "
"voice assistant in Home Assistant." "voice assistant in Home Assistant."
@ -360,7 +358,7 @@ async def test_assist_api_prompt(
) )
) )
exposed_entities = llm._get_exposed_entities(hass, tool_context.assistant) exposed_entities = llm._get_exposed_entities(hass, llm_context.assistant)
assert exposed_entities == { assert exposed_entities == {
"light.1": { "light.1": {
"areas": "Test Area 2", "areas": "Test Area 2",
@ -435,7 +433,7 @@ async def test_assist_api_prompt(
"When a user asks to turn on all devices of a specific type, " "When a user asks to turn on all devices of a specific type, "
"ask user to specify an area, unless there is only one device of that type." "ask user to specify an area, unless there is only one device of that type."
) )
api = await llm.async_get_api(hass, "assist", tool_context) api = await llm.async_get_api(hass, "assist", llm_context)
assert api.api_prompt == ( assert api.api_prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
{area_prompt} {area_prompt}
@ -444,12 +442,12 @@ async def test_assist_api_prompt(
) )
# Fake that request is made from a specific device ID with an area # Fake that request is made from a specific device ID with an area
tool_context.device_id = device.id llm_context.device_id = device.id
area_prompt = ( area_prompt = (
"You are in area Test Area and all generic commands like 'turn on the lights' " "You are in area Test Area and all generic commands like 'turn on the lights' "
"should target this area." "should target this area."
) )
api = await llm.async_get_api(hass, "assist", tool_context) api = await llm.async_get_api(hass, "assist", llm_context)
assert api.api_prompt == ( assert api.api_prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
{area_prompt} {area_prompt}
@ -464,7 +462,7 @@ async def test_assist_api_prompt(
"You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' " "You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' "
"should target this area." "should target this area."
) )
api = await llm.async_get_api(hass, "assist", tool_context) api = await llm.async_get_api(hass, "assist", llm_context)
assert api.api_prompt == ( assert api.api_prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
{area_prompt} {area_prompt}
@ -475,7 +473,7 @@ async def test_assist_api_prompt(
# Register device for timers # Register device for timers
async_register_timer_handler(hass, device.id, lambda *args: None) async_register_timer_handler(hass, device.id, lambda *args: None)
api = await llm.async_get_api(hass, "assist", tool_context) api = await llm.async_get_api(hass, "assist", llm_context)
# The no_timer_prompt is gone # The no_timer_prompt is gone
assert api.api_prompt == ( assert api.api_prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}