From 009c9e79ae7dbe8dca6222b1eb4f971b06760a07 Mon Sep 17 00:00:00 2001 From: Denis Shulyaka Date: Wed, 22 May 2024 04:24:46 +0300 Subject: [PATCH] LLM Tools: Add device_id (#117884) --- .../google_generative_ai_conversation/conversation.py | 1 + homeassistant/helpers/llm.py | 3 +++ .../google_generative_ai_conversation/test_conversation.py | 4 ++++ tests/helpers/test_llm.py | 3 +++ 4 files changed, 11 insertions(+) diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 8e16e8eaceb..bc21a1a524a 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -240,6 +240,7 @@ class GoogleGenerativeAIConversationEntity( user_prompt=user_input.text, language=user_input.language, assistant=conversation.DOMAIN, + device_id=user_input.device_id, ) LOGGER.debug( "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index a53d134276a..670f9eadda2 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -73,6 +73,7 @@ class ToolInput(ABC): user_prompt: str | None language: str | None assistant: str | None + device_id: str | None class Tool: @@ -125,6 +126,7 @@ class API(ABC): user_prompt=tool_input.user_prompt, language=tool_input.language, assistant=tool_input.assistant, + device_id=tool_input.device_id, ) return await tool.async_call(self.hass, _tool_input) @@ -160,6 +162,7 @@ class IntentTool(Tool): tool_input.context, tool_input.language, tool_input.assistant, + tool_input.device_id, ) return intent_response.as_dict() diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index eac97790420..76fe10a0d15 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -198,6 +198,7 @@ async def test_function_call( None, context, agent_id=agent_id, + device_id="test_device", ) assert result.response.response_type == intent.IntentResponseType.ACTION_DONE @@ -228,6 +229,7 @@ async def test_function_call( user_prompt="Please call the test function", language="en", assistant="conversation", + device_id="test_device", ), ) @@ -280,6 +282,7 @@ async def test_function_exception( None, context, agent_id=agent_id, + device_id="test_device", ) assert result.response.response_type == intent.IntentResponseType.ACTION_DONE @@ -310,6 +313,7 @@ async def test_function_exception( user_prompt="Please call the test function", language="en", assistant="conversation", + device_id="test_device", ), ) diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index b8f5755ae39..5dbb20ca86b 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -46,6 +46,7 @@ async def test_call_tool_no_existing(hass: HomeAssistant) -> None: None, None, None, + None, ), ) @@ -87,6 +88,7 @@ async def test_assist_api(hass: HomeAssistant) -> None: user_prompt="test_text", language="*", assistant="test_assistant", + device_id="test_device", ) with patch( @@ -106,6 +108,7 @@ async def test_assist_api(hass: HomeAssistant) -> None: test_context, "*", "test_assistant", + "test_device", ) assert response == { "card": {},