From 7554ca9460af51ff462a1f4189f37790f34b81cd Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 24 May 2024 16:04:48 -0400 Subject: [PATCH] Allow llm API to render dynamic template prompt (#118055) * Allow llm API to render dynamic template prompt * Make rendering api prompt async so it can become a RAG * Fix test --- .../config_flow.py | 124 +++++++--------- .../const.py | 19 +-- .../conversation.py | 46 +++--- .../strings.json | 8 +- .../components/openai_conversation/const.py | 18 +-- .../openai_conversation/conversation.py | 50 ++++--- homeassistant/helpers/llm.py | 11 +- .../snapshots/test_conversation.ambr | 66 +-------- .../test_config_flow.py | 9 +- .../openai_conversation/test_conversation.py | 139 +----------------- tests/helpers/test_llm.py | 6 +- 11 files changed, 137 insertions(+), 359 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/config_flow.py b/homeassistant/components/google_generative_ai_conversation/config_flow.py index 2f9040344b3..3845d7f4e92 100644 --- a/homeassistant/components/google_generative_ai_conversation/config_flow.py +++ b/homeassistant/components/google_generative_ai_conversation/config_flow.py @@ -36,7 +36,6 @@ from .const import ( CONF_PROMPT, CONF_RECOMMENDED, CONF_TEMPERATURE, - CONF_TONE_PROMPT, CONF_TOP_K, CONF_TOP_P, DEFAULT_PROMPT, @@ -59,7 +58,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema( RECOMMENDED_OPTIONS = { CONF_RECOMMENDED: True, CONF_LLM_HASS_API: llm.LLM_API_ASSIST, - CONF_TONE_PROMPT: "", + CONF_PROMPT: "", } @@ -142,16 +141,11 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow): # Re-render the options again, now with the recommended options shown/hidden self.last_rendered_recommended = user_input[CONF_RECOMMENDED] - # If we switch to not recommended, generate used prompt. - if user_input[CONF_RECOMMENDED]: - options = RECOMMENDED_OPTIONS - else: - options = { - CONF_RECOMMENDED: False, - CONF_PROMPT: DEFAULT_PROMPT - + "\n" - + user_input.get(CONF_TONE_PROMPT, ""), - } + options = { + CONF_RECOMMENDED: user_input[CONF_RECOMMENDED], + CONF_PROMPT: user_input[CONF_PROMPT], + CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API], + } schema = await google_generative_ai_config_option_schema(self.hass, options) return self.async_show_form( @@ -179,22 +173,24 @@ async def google_generative_ai_config_option_schema( for api in llm.async_get_apis(hass) ) + schema = { + vol.Optional( + CONF_PROMPT, + description={"suggested_value": options.get(CONF_PROMPT)}, + default=DEFAULT_PROMPT, + ): TemplateSelector(), + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": options.get(CONF_LLM_HASS_API)}, + default="none", + ): SelectSelector(SelectSelectorConfig(options=hass_apis)), + vol.Required( + CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False) + ): bool, + } + if options.get(CONF_RECOMMENDED): - return { - vol.Required( - CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False) - ): bool, - vol.Optional( - CONF_TONE_PROMPT, - description={"suggested_value": options.get(CONF_TONE_PROMPT)}, - default="", - ): TemplateSelector(), - vol.Optional( - CONF_LLM_HASS_API, - description={"suggested_value": options.get(CONF_LLM_HASS_API)}, - default="none", - ): SelectSelector(SelectSelectorConfig(options=hass_apis)), - } + return schema api_models = await hass.async_add_executor_job(partial(genai.list_models)) @@ -211,45 +207,35 @@ async def google_generative_ai_config_option_schema( ) ] - return { - vol.Required( - CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False) - ): bool, - vol.Optional( - CONF_CHAT_MODEL, - description={"suggested_value": options.get(CONF_CHAT_MODEL)}, - default=RECOMMENDED_CHAT_MODEL, - ): SelectSelector( - SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models) - ), - vol.Optional( - CONF_PROMPT, - description={"suggested_value": options.get(CONF_PROMPT)}, - default=DEFAULT_PROMPT, - ): TemplateSelector(), - vol.Optional( - CONF_LLM_HASS_API, - description={"suggested_value": options.get(CONF_LLM_HASS_API)}, - default="none", - ): SelectSelector(SelectSelectorConfig(options=hass_apis)), - vol.Optional( - CONF_TEMPERATURE, - description={"suggested_value": options.get(CONF_TEMPERATURE)}, - default=RECOMMENDED_TEMPERATURE, - ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), - vol.Optional( - CONF_TOP_P, - description={"suggested_value": options.get(CONF_TOP_P)}, - default=RECOMMENDED_TOP_P, - ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), - vol.Optional( - CONF_TOP_K, - description={"suggested_value": options.get(CONF_TOP_K)}, - default=RECOMMENDED_TOP_K, - ): int, - vol.Optional( - CONF_MAX_TOKENS, - description={"suggested_value": options.get(CONF_MAX_TOKENS)}, - default=RECOMMENDED_MAX_TOKENS, - ): int, - } + schema.update( + { + vol.Optional( + CONF_CHAT_MODEL, + description={"suggested_value": options.get(CONF_CHAT_MODEL)}, + default=RECOMMENDED_CHAT_MODEL, + ): SelectSelector( + SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models) + ), + vol.Optional( + CONF_TEMPERATURE, + description={"suggested_value": options.get(CONF_TEMPERATURE)}, + default=RECOMMENDED_TEMPERATURE, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_TOP_P, + description={"suggested_value": options.get(CONF_TOP_P)}, + default=RECOMMENDED_TOP_P, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_TOP_K, + description={"suggested_value": options.get(CONF_TOP_K)}, + default=RECOMMENDED_TOP_K, + ): int, + vol.Optional( + CONF_MAX_TOKENS, + description={"suggested_value": options.get(CONF_MAX_TOKENS)}, + default=RECOMMENDED_MAX_TOKENS, + ): int, + } + ) + return schema diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py index 53a1e2a74a9..9a16a31abd7 100644 --- a/homeassistant/components/google_generative_ai_conversation/const.py +++ b/homeassistant/components/google_generative_ai_conversation/const.py @@ -5,24 +5,7 @@ import logging DOMAIN = "google_generative_ai_conversation" LOGGER = logging.getLogger(__package__) CONF_PROMPT = "prompt" -CONF_TONE_PROMPT = "tone_prompt" -DEFAULT_PROMPT = """This smart home is controlled by Home Assistant. - -An overview of the areas and the devices in this smart home: -{%- for area in areas() %} - {%- set area_info = namespace(printed=false) %} - {%- for device in area_devices(area) -%} - {%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") and device_attr(device, "name") %} - {%- if not area_info.printed %} - -{{ area_name(area) }}: - {%- set area_info.printed = true %} - {%- endif %} -- {{ device_attr(device, "name") }}{% if device_attr(device, "model") and (device_attr(device, "model") | string) not in (device_attr(device, "name") | string) %} ({{ device_attr(device, "model") }}){% endif %} - {%- endif %} - {%- endfor %} -{%- endfor %} -""" +DEFAULT_PROMPT = "Answer in plain text. Keep it simple and to the point." CONF_RECOMMENDED = "recommended" CONF_CHAT_MODEL = "chat_model" diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index b68ab39d53b..2bc79ac8dde 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -25,7 +25,6 @@ from .const import ( CONF_MAX_TOKENS, CONF_PROMPT, CONF_TEMPERATURE, - CONF_TONE_PROMPT, CONF_TOP_K, CONF_TOP_P, DEFAULT_PROMPT, @@ -179,12 +178,32 @@ class GoogleGenerativeAIConversationEntity( conversation_id = ulid.ulid_now() messages = [{}, {}] - raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) - if tone_prompt := self.entry.options.get(CONF_TONE_PROMPT): - raw_prompt += "\n" + tone_prompt - try: - prompt = self._async_generate_prompt(raw_prompt, llm_api) + prompt = template.Template( + self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass + ).async_render( + { + "ha_name": self.hass.config.location_name, + }, + parse_result=False, + ) + + if llm_api: + empty_tool_input = llm.ToolInput( + tool_name="", + tool_args={}, + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=conversation.DOMAIN, + device_id=user_input.device_id, + ) + + prompt = ( + await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt + ) + except TemplateError as err: LOGGER.error("Error rendering prompt: %s", err) intent_response.async_set_error( @@ -271,18 +290,3 @@ class GoogleGenerativeAIConversationEntity( return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) - - def _async_generate_prompt(self, raw_prompt: str, llm_api: llm.API | None) -> str: - """Generate a prompt for the user.""" - raw_prompt += "\n" - if llm_api: - raw_prompt += llm_api.prompt_template - else: - raw_prompt += llm.PROMPT_NO_API_CONFIGURED - - return template.Template(raw_prompt, self.hass).async_render( - { - "ha_name": self.hass.config.location_name, - }, - parse_result=False, - ) diff --git a/homeassistant/components/google_generative_ai_conversation/strings.json b/homeassistant/components/google_generative_ai_conversation/strings.json index 8a961c9e3d3..f35561a6aa6 100644 --- a/homeassistant/components/google_generative_ai_conversation/strings.json +++ b/homeassistant/components/google_generative_ai_conversation/strings.json @@ -18,9 +18,8 @@ "step": { "init": { "data": { - "recommended": "Recommended settings", - "prompt": "Prompt", - "tone_prompt": "Tone", + "recommended": "Recommended model settings", + "prompt": "Instructions", "chat_model": "[%key:common::generic::model%]", "temperature": "Temperature", "top_p": "Top P", @@ -29,8 +28,7 @@ "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" }, "data_description": { - "prompt": "Extra data to provide to the LLM. This can be a template.", - "tone_prompt": "Instructions for the LLM on the style of the generated text. This can be a template." + "prompt": "Instruct how the LLM should respond. This can be a template." } } } diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py index c50b66c1320..27ef86bf918 100644 --- a/homeassistant/components/openai_conversation/const.py +++ b/homeassistant/components/openai_conversation/const.py @@ -5,23 +5,7 @@ import logging DOMAIN = "openai_conversation" LOGGER = logging.getLogger(__package__) CONF_PROMPT = "prompt" -DEFAULT_PROMPT = """This smart home is controlled by Home Assistant. - -An overview of the areas and the devices in this smart home: -{%- for area in areas() %} - {%- set area_info = namespace(printed=false) %} - {%- for device in area_devices(area) -%} - {%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") and device_attr(device, "name") %} - {%- if not area_info.printed %} - -{{ area_name(area) }}: - {%- set area_info.printed = true %} - {%- endif %} -- {{ device_attr(device, "name") }}{% if device_attr(device, "model") and (device_attr(device, "model") | string) not in (device_attr(device, "name") | string) %} ({{ device_attr(device, "model") }}){% endif %} - {%- endif %} - {%- endfor %} -{%- endfor %} -""" +DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point.""" CONF_CHAT_MODEL = "chat_model" DEFAULT_CHAT_MODEL = "gpt-4o" CONF_MAX_TOKENS = "max_tokens" diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index b7219aad608..7fe4ef6ac04 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -110,7 +110,6 @@ class OpenAIConversationEntity( ) tools = [_format_tool(tool) for tool in llm_api.async_get_tools()] - raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) @@ -122,10 +121,33 @@ class OpenAIConversationEntity( else: conversation_id = ulid.ulid_now() try: - prompt = self._async_generate_prompt( - raw_prompt, - llm_api, + prompt = template.Template( + self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass + ).async_render( + { + "ha_name": self.hass.config.location_name, + }, + parse_result=False, ) + + if llm_api: + empty_tool_input = llm.ToolInput( + tool_name="", + tool_args={}, + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=conversation.DOMAIN, + device_id=user_input.device_id, + ) + + prompt = ( + await llm_api.async_get_api_prompt(empty_tool_input) + + "\n" + + prompt + ) + except TemplateError as err: LOGGER.error("Error rendering prompt: %s", err) intent_response = intent.IntentResponse(language=user_input.language) @@ -136,6 +158,7 @@ class OpenAIConversationEntity( return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) + messages = [{"role": "system", "content": prompt}] messages.append({"role": "user", "content": user_input.text}) @@ -213,22 +236,3 @@ class OpenAIConversationEntity( return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) - - def _async_generate_prompt( - self, - raw_prompt: str, - llm_api: llm.API | None, - ) -> str: - """Generate a prompt for the user.""" - raw_prompt += "\n" - if llm_api: - raw_prompt += llm_api.prompt_template - else: - raw_prompt += llm.PROMPT_NO_API_CONFIGURED - - return template.Template(raw_prompt, self.hass).async_render( - { - "ha_name": self.hass.config.location_name, - }, - parse_result=False, - ) diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 081ac39e9d9..ec426b350d9 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -102,7 +102,11 @@ class API(ABC): hass: HomeAssistant id: str name: str - prompt_template: str + + @abstractmethod + async def async_get_api_prompt(self, tool_input: ToolInput) -> str: + """Return the prompt for the API.""" + raise NotImplementedError @abstractmethod @callback @@ -183,9 +187,12 @@ class AssistAPI(API): hass=hass, id=LLM_API_ASSIST, name="Assist", - prompt_template="Call the intent tools to control the system. Just pass the name to the intent.", ) + async def async_get_api_prompt(self, tool_input: ToolInput) -> str: + """Return the prompt for the API.""" + return "Call the intent tools to control Home Assistant. Just pass the name to the intent." + @callback def async_get_tools(self) -> list[Tool]: """Return a list of LLM tools.""" diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr index 24342bc0b1e..fe44c6a1608 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -23,22 +23,7 @@ dict({ 'history': list([ dict({ - 'parts': ''' - This smart home is controlled by Home Assistant. - - An overview of the areas and the devices in this smart home: - - Test Area: - - Test Device (Test Model) - - Test Area 2: - - Test Device 2 - - Test Device 3 (Test Model 3A) - - Test Device 4 - - 1 (3) - - Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. - ''', + 'parts': 'Answer in plain text. Keep it simple and to the point.', 'role': 'user', }), dict({ @@ -82,22 +67,7 @@ dict({ 'history': list([ dict({ - 'parts': ''' - This smart home is controlled by Home Assistant. - - An overview of the areas and the devices in this smart home: - - Test Area: - - Test Device (Test Model) - - Test Area 2: - - Test Device 2 - - Test Device 3 (Test Model 3A) - - Test Device 4 - - 1 (3) - - Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. - ''', + 'parts': 'Answer in plain text. Keep it simple and to the point.', 'role': 'user', }), dict({ @@ -142,20 +112,8 @@ 'history': list([ dict({ 'parts': ''' - This smart home is controlled by Home Assistant. - - An overview of the areas and the devices in this smart home: - - Test Area: - - Test Device (Test Model) - - Test Area 2: - - Test Device 2 - - Test Device 3 (Test Model 3A) - - Test Device 4 - - 1 (3) - - Call the intent tools to control the system. Just pass the name to the intent. + Call the intent tools to control Home Assistant. Just pass the name to the intent. + Answer in plain text. Keep it simple and to the point. ''', 'role': 'user', }), @@ -201,20 +159,8 @@ 'history': list([ dict({ 'parts': ''' - This smart home is controlled by Home Assistant. - - An overview of the areas and the devices in this smart home: - - Test Area: - - Test Device (Test Model) - - Test Area 2: - - Test Device 2 - - Test Device 3 (Test Model 3A) - - Test Device 4 - - 1 (3) - - Call the intent tools to control the system. Just pass the name to the intent. + Call the intent tools to control Home Assistant. Just pass the name to the intent. + Answer in plain text. Keep it simple and to the point. ''', 'role': 'user', }), diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py index a4972d03496..460d74734ae 100644 --- a/tests/components/google_generative_ai_conversation/test_config_flow.py +++ b/tests/components/google_generative_ai_conversation/test_config_flow.py @@ -13,7 +13,6 @@ from homeassistant.components.google_generative_ai_conversation.const import ( CONF_PROMPT, CONF_RECOMMENDED, CONF_TEMPERATURE, - CONF_TONE_PROMPT, CONF_TOP_K, CONF_TOP_P, DOMAIN, @@ -90,7 +89,7 @@ async def test_form(hass: HomeAssistant) -> None: assert result2["options"] == { CONF_RECOMMENDED: True, CONF_LLM_HASS_API: llm.LLM_API_ASSIST, - CONF_TONE_PROMPT: "", + CONF_PROMPT: "", } assert len(mock_setup_entry.mock_calls) == 1 @@ -102,7 +101,7 @@ async def test_form(hass: HomeAssistant) -> None: { CONF_RECOMMENDED: True, CONF_LLM_HASS_API: "none", - CONF_TONE_PROMPT: "bla", + CONF_PROMPT: "bla", }, { CONF_RECOMMENDED: False, @@ -132,12 +131,12 @@ async def test_form(hass: HomeAssistant) -> None: { CONF_RECOMMENDED: True, CONF_LLM_HASS_API: "assist", - CONF_TONE_PROMPT: "", + CONF_PROMPT: "", }, { CONF_RECOMMENDED: True, CONF_LLM_HASS_API: "assist", - CONF_TONE_PROMPT: "", + CONF_PROMPT: "", }, ), ], diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 431feb9d482..319295374a7 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -11,7 +11,6 @@ from openai.types.chat.chat_completion_message_tool_call import ( Function, ) from openai.types.completion_usage import CompletionUsage -import pytest from syrupy.assertion import SnapshotAssertion import voluptuous as vol @@ -19,148 +18,12 @@ from homeassistant.components import conversation from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import ( - area_registry as ar, - device_registry as dr, - intent, - llm, -) +from homeassistant.helpers import intent, llm from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry -@pytest.mark.parametrize("agent_id", [None, "conversation.openai"]) -@pytest.mark.parametrize( - "config_entry_options", [{}, {CONF_LLM_HASS_API: llm.LLM_API_ASSIST}] -) -async def test_default_prompt( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, - area_registry: ar.AreaRegistry, - device_registry: dr.DeviceRegistry, - snapshot: SnapshotAssertion, - agent_id: str, - config_entry_options: dict, -) -> None: - """Test that the default prompt works.""" - entry = MockConfigEntry(title=None) - entry.add_to_hass(hass) - for i in range(3): - area_registry.async_create(f"{i}Empty Area") - - if agent_id is None: - agent_id = mock_config_entry.entry_id - - hass.config_entries.async_update_entry( - mock_config_entry, - options={ - **mock_config_entry.options, - CONF_LLM_HASS_API: llm.LLM_API_ASSIST, - }, - ) - - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "1234")}, - name="Test Device", - manufacturer="Test Manufacturer", - model="Test Model", - suggested_area="Test Area", - ) - for i in range(3): - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", f"{i}abcd")}, - name="Test Service", - manufacturer="Test Manufacturer", - model="Test Model", - suggested_area="Test Area", - entry_type=dr.DeviceEntryType.SERVICE, - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "5678")}, - name="Test Device 2", - manufacturer="Test Manufacturer 2", - model="Device 2", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876")}, - name="Test Device 3", - manufacturer="Test Manufacturer 3", - model="Test Model 3A", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "qwer")}, - name="Test Device 4", - suggested_area="Test Area 2", - ) - device = device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-disabled")}, - name="Test Device 3", - manufacturer="Test Manufacturer 3", - model="Test Model 3A", - suggested_area="Test Area 2", - ) - device_registry.async_update_device( - device.id, disabled_by=dr.DeviceEntryDisabler.USER - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-no-name")}, - manufacturer="Test Manufacturer NoName", - model="Test Model NoName", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-integer-values")}, - name=1, - manufacturer=2, - model=3, - suggested_area="Test Area 2", - ) - with patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - return_value=ChatCompletion( - id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage( - content="Hello, how can I help you?", - role="assistant", - function_call=None, - tool_calls=None, - ), - ) - ], - created=1700000000, - model="gpt-3.5-turbo-0613", - object="chat.completion", - system_fingerprint=None, - usage=CompletionUsage( - completion_tokens=9, prompt_tokens=8, total_tokens=17 - ), - ), - ) as mock_create: - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=agent_id - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert mock_create.mock_calls[0][2]["messages"] == snapshot - - async def test_error_handling( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component ) -> None: diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 5dbb20ca86b..ca8edc507a0 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -20,11 +20,15 @@ async def test_register_api(hass: HomeAssistant) -> None: """Test registering an llm api.""" class MyAPI(llm.API): + async def async_get_api_prompt(self, tool_input: llm.ToolInput) -> str: + """Return a prompt for the tool.""" + return "" + def async_get_tools(self) -> list[llm.Tool]: """Return a list of tools.""" return [] - api = MyAPI(hass=hass, id="test", name="Test", prompt_template="") + api = MyAPI(hass=hass, id="test", name="Test") llm.async_register_api(hass, api) assert llm.async_get_api(hass, "test") is api