diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index e7aaabb912d..d722403a0be 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -169,7 +169,7 @@ class GoogleGenerativeAIConversationEntity( llm_api = await llm.async_get_api( self.hass, self.entry.options[CONF_LLM_HASS_API], - llm.ToolContext( + llm.LLMContext( platform=DOMAIN, context=user_input.context, user_prompt=user_input.text, diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index afc5396e0ba..26acfda979d 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -119,7 +119,7 @@ class OpenAIConversationEntity( llm_api = await llm.async_get_api( self.hass, options[CONF_LLM_HASS_API], - llm.ToolContext( + llm.LLMContext( platform=DOMAIN, context=user_input.context, user_prompt=user_input.text, diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index b4b5f9137c4..dd380795227 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -71,7 +71,7 @@ def async_register_api(hass: HomeAssistant, api: API) -> None: async def async_get_api( - hass: HomeAssistant, api_id: str, tool_context: ToolContext + hass: HomeAssistant, api_id: str, llm_context: LLMContext ) -> APIInstance: """Get an API.""" apis = _async_get_apis(hass) @@ -79,7 +79,7 @@ async def async_get_api( if api_id not in apis: raise HomeAssistantError(f"API {api_id} not found") - return await apis[api_id].async_get_api_instance(tool_context) + return await apis[api_id].async_get_api_instance(llm_context) @callback @@ -89,7 +89,7 @@ def async_get_apis(hass: HomeAssistant) -> list[API]: @dataclass(slots=True) -class ToolContext: +class LLMContext: """Tool input to be processed.""" platform: str @@ -117,7 +117,7 @@ class Tool: @abstractmethod async def async_call( - self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext + self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext ) -> JsonObjectType: """Call the tool.""" raise NotImplementedError @@ -133,7 +133,7 @@ class APIInstance: api: API api_prompt: str - tool_context: ToolContext + llm_context: LLMContext tools: list[Tool] async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: @@ -149,7 +149,7 @@ class APIInstance: else: raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found') - return await tool.async_call(self.api.hass, tool_input, self.tool_context) + return await tool.async_call(self.api.hass, tool_input, self.llm_context) @dataclass(slots=True, kw_only=True) @@ -161,7 +161,7 @@ class API(ABC): name: str @abstractmethod - async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance: + async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance: """Return the instance of the API.""" raise NotImplementedError @@ -182,20 +182,20 @@ class IntentTool(Tool): self.parameters = vol.Schema(slot_schema) async def async_call( - self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext + self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext ) -> JsonObjectType: """Handle the intent.""" slots = {key: {"value": val} for key, val in tool_input.tool_args.items()} intent_response = await intent.async_handle( hass=hass, - platform=tool_context.platform, + platform=llm_context.platform, intent_type=self.name, slots=slots, - text_input=tool_context.user_prompt, - context=tool_context.context, - language=tool_context.language, - assistant=tool_context.assistant, - device_id=tool_context.device_id, + text_input=llm_context.user_prompt, + context=llm_context.context, + language=llm_context.language, + assistant=llm_context.assistant, + device_id=llm_context.device_id, ) response = intent_response.as_dict() del response["language"] @@ -224,25 +224,25 @@ class AssistAPI(API): name="Assist", ) - async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance: + async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance: """Return the instance of the API.""" - if tool_context.assistant: + if llm_context.assistant: exposed_entities: dict | None = _get_exposed_entities( - self.hass, tool_context.assistant + self.hass, llm_context.assistant ) else: exposed_entities = None return APIInstance( api=self, - api_prompt=self._async_get_api_prompt(tool_context, exposed_entities), - tool_context=tool_context, - tools=self._async_get_tools(tool_context, exposed_entities), + api_prompt=self._async_get_api_prompt(llm_context, exposed_entities), + llm_context=llm_context, + tools=self._async_get_tools(llm_context, exposed_entities), ) @callback def _async_get_api_prompt( - self, tool_context: ToolContext, exposed_entities: dict | None + self, llm_context: LLMContext, exposed_entities: dict | None ) -> str: """Return the prompt for the API.""" if not exposed_entities: @@ -263,9 +263,9 @@ class AssistAPI(API): ] area: ar.AreaEntry | None = None floor: fr.FloorEntry | None = None - if tool_context.device_id: + if llm_context.device_id: device_reg = dr.async_get(self.hass) - device = device_reg.async_get(tool_context.device_id) + device = device_reg.async_get(llm_context.device_id) if device: area_reg = ar.async_get(self.hass) @@ -286,8 +286,8 @@ class AssistAPI(API): "ask user to specify an area, unless there is only one device of that type." ) - if not tool_context.device_id or not async_device_supports_timers( - self.hass, tool_context.device_id + if not llm_context.device_id or not async_device_supports_timers( + self.hass, llm_context.device_id ): prompt.append("This device does not support timers.") @@ -301,12 +301,12 @@ class AssistAPI(API): @callback def _async_get_tools( - self, tool_context: ToolContext, exposed_entities: dict | None + self, llm_context: LLMContext, exposed_entities: dict | None ) -> list[Tool]: """Return a list of LLM tools.""" ignore_intents = self.IGNORE_INTENTS - if not tool_context.device_id or not async_device_supports_timers( - self.hass, tool_context.device_id + if not llm_context.device_id or not async_device_supports_timers( + self.hass, llm_context.device_id ): ignore_intents = ignore_intents | { intent.INTENT_START_TIMER, diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index b282895baef..19a855aa17f 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -231,7 +231,7 @@ async def test_function_call( "param2": "param2's value", }, ), - llm.ToolContext( + llm.LLMContext( platform="google_generative_ai_conversation", context=context, user_prompt="Please call the test function", @@ -330,7 +330,7 @@ async def test_function_exception( tool_name="test_tool", tool_args={"param1": 1}, ), - llm.ToolContext( + llm.LLMContext( platform="google_generative_ai_conversation", context=context, user_prompt="Please call the test function", diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 4d16973ddfc..10829db7575 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -192,7 +192,7 @@ async def test_function_call( tool_name="test_tool", tool_args={"param1": "test_value"}, ), - llm.ToolContext( + llm.LLMContext( platform="openai_conversation", context=context, user_prompt="Please call the test function", @@ -324,7 +324,7 @@ async def test_function_exception( tool_name="test_tool", tool_args={"param1": "test_value"}, ), - llm.ToolContext( + llm.LLMContext( platform="openai_conversation", context=context, user_prompt="Please call the test function", diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 355abf2fe5d..9c07295dec7 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -24,9 +24,9 @@ from tests.common import MockConfigEntry @pytest.fixture -def tool_input_context() -> llm.ToolContext: +def llm_context() -> llm.LLMContext: """Return tool input context.""" - return llm.ToolContext( + return llm.LLMContext( platform="", context=None, user_prompt=None, @@ -37,29 +37,27 @@ def tool_input_context() -> llm.ToolContext: async def test_get_api_no_existing( - hass: HomeAssistant, tool_input_context: llm.ToolContext + hass: HomeAssistant, llm_context: llm.LLMContext ) -> None: """Test getting an llm api where no config exists.""" with pytest.raises(HomeAssistantError): - await llm.async_get_api(hass, "non-existing", tool_input_context) + await llm.async_get_api(hass, "non-existing", llm_context) -async def test_register_api( - hass: HomeAssistant, tool_input_context: llm.ToolContext -) -> None: +async def test_register_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None: """Test registering an llm api.""" class MyAPI(llm.API): async def async_get_api_instance( - self, tool_input: llm.ToolInput + self, tool_context: llm.ToolInput ) -> llm.APIInstance: """Return a list of tools.""" - return llm.APIInstance(self, "", [], tool_input_context) + return llm.APIInstance(self, "", [], llm_context) api = MyAPI(hass=hass, id="test", name="Test") llm.async_register_api(hass, api) - instance = await llm.async_get_api(hass, "test", tool_input_context) + instance = await llm.async_get_api(hass, "test", llm_context) assert instance.api is api assert api in llm.async_get_apis(hass) @@ -68,10 +66,10 @@ async def test_register_api( async def test_call_tool_no_existing( - hass: HomeAssistant, tool_input_context: llm.ToolContext + hass: HomeAssistant, llm_context: llm.LLMContext ) -> None: """Test calling an llm tool where no config exists.""" - instance = await llm.async_get_api(hass, "assist", tool_input_context) + instance = await llm.async_get_api(hass, "assist", llm_context) with pytest.raises(HomeAssistantError): await instance.async_call_tool( llm.ToolInput("test_tool", {}), @@ -93,7 +91,7 @@ async def test_assist_api( ).write_unavailable_state(hass) test_context = Context() - tool_context = llm.ToolContext( + llm_context = llm.LLMContext( platform="test_platform", context=test_context, user_prompt="test_text", @@ -116,19 +114,19 @@ async def test_assist_api( intent.async_register(hass, intent_handler) assert len(llm.async_get_apis(hass)) == 1 - api = await llm.async_get_api(hass, "assist", tool_context) + api = await llm.async_get_api(hass, "assist", llm_context) assert len(api.tools) == 0 # Match all intent_handler.platforms = None - api = await llm.async_get_api(hass, "assist", tool_context) + api = await llm.async_get_api(hass, "assist", llm_context) assert len(api.tools) == 1 # Match specific domain intent_handler.platforms = {"light"} - api = await llm.async_get_api(hass, "assist", tool_context) + api = await llm.async_get_api(hass, "assist", llm_context) assert len(api.tools) == 1 tool = api.tools[0] assert tool.name == "test_intent" @@ -176,25 +174,25 @@ async def test_assist_api( async def test_assist_api_get_timer_tools( - hass: HomeAssistant, tool_input_context: llm.ToolContext + hass: HomeAssistant, llm_context: llm.LLMContext ) -> None: """Test getting timer tools with Assist API.""" assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "intent", {}) - api = await llm.async_get_api(hass, "assist", tool_input_context) + api = await llm.async_get_api(hass, "assist", llm_context) assert "HassStartTimer" not in [tool.name for tool in api.tools] - tool_input_context.device_id = "test_device" + llm_context.device_id = "test_device" async_register_timer_handler(hass, "test_device", lambda *args: None) - api = await llm.async_get_api(hass, "assist", tool_input_context) + api = await llm.async_get_api(hass, "assist", llm_context) assert "HassStartTimer" in [tool.name for tool in api.tools] async def test_assist_api_description( - hass: HomeAssistant, tool_input_context: llm.ToolContext + hass: HomeAssistant, llm_context: llm.LLMContext ) -> None: """Test intent description with Assist API.""" @@ -205,7 +203,7 @@ async def test_assist_api_description( intent.async_register(hass, MyIntentHandler()) assert len(llm.async_get_apis(hass)) == 1 - api = await llm.async_get_api(hass, "assist", tool_input_context) + api = await llm.async_get_api(hass, "assist", llm_context) assert len(api.tools) == 1 tool = api.tools[0] assert tool.name == "test_intent" @@ -223,7 +221,7 @@ async def test_assist_api_prompt( assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "intent", {}) context = Context() - tool_context = llm.ToolContext( + llm_context = llm.LLMContext( platform="test_platform", context=context, user_prompt="test_text", @@ -231,7 +229,7 @@ async def test_assist_api_prompt( assistant="conversation", device_id=None, ) - api = await llm.async_get_api(hass, "assist", tool_context) + api = await llm.async_get_api(hass, "assist", llm_context) assert api.api_prompt == ( "Only if the user wants to control a device, tell them to expose entities to their " "voice assistant in Home Assistant." @@ -360,7 +358,7 @@ async def test_assist_api_prompt( ) ) - exposed_entities = llm._get_exposed_entities(hass, tool_context.assistant) + exposed_entities = llm._get_exposed_entities(hass, llm_context.assistant) assert exposed_entities == { "light.1": { "areas": "Test Area 2", @@ -435,7 +433,7 @@ async def test_assist_api_prompt( "When a user asks to turn on all devices of a specific type, " "ask user to specify an area, unless there is only one device of that type." ) - api = await llm.async_get_api(hass, "assist", tool_context) + api = await llm.async_get_api(hass, "assist", llm_context) assert api.api_prompt == ( f"""{first_part_prompt} {area_prompt} @@ -444,12 +442,12 @@ async def test_assist_api_prompt( ) # Fake that request is made from a specific device ID with an area - tool_context.device_id = device.id + llm_context.device_id = device.id area_prompt = ( "You are in area Test Area and all generic commands like 'turn on the lights' " "should target this area." ) - api = await llm.async_get_api(hass, "assist", tool_context) + api = await llm.async_get_api(hass, "assist", llm_context) assert api.api_prompt == ( f"""{first_part_prompt} {area_prompt} @@ -464,7 +462,7 @@ async def test_assist_api_prompt( "You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' " "should target this area." ) - api = await llm.async_get_api(hass, "assist", tool_context) + api = await llm.async_get_api(hass, "assist", llm_context) assert api.api_prompt == ( f"""{first_part_prompt} {area_prompt} @@ -475,7 +473,7 @@ async def test_assist_api_prompt( # Register device for timers async_register_timer_handler(hass, device.id, lambda *args: None) - api = await llm.async_get_api(hass, "assist", tool_context) + api = await llm.async_get_api(hass, "assist", llm_context) # The no_timer_prompt is gone assert api.api_prompt == ( f"""{first_part_prompt}