LLM Tools: Add device_id (#117884)

This commit is contained in:
Denis Shulyaka 2024-05-22 04:24:46 +03:00 committed by GitHub
parent 4ed45a322c
commit 009c9e79ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 11 additions and 0 deletions

View File

@ -240,6 +240,7 @@ class GoogleGenerativeAIConversationEntity(
user_prompt=user_input.text, user_prompt=user_input.text,
language=user_input.language, language=user_input.language,
assistant=conversation.DOMAIN, assistant=conversation.DOMAIN,
device_id=user_input.device_id,
) )
LOGGER.debug( LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args

View File

@ -73,6 +73,7 @@ class ToolInput(ABC):
user_prompt: str | None user_prompt: str | None
language: str | None language: str | None
assistant: str | None assistant: str | None
device_id: str | None
class Tool: class Tool:
@ -125,6 +126,7 @@ class API(ABC):
user_prompt=tool_input.user_prompt, user_prompt=tool_input.user_prompt,
language=tool_input.language, language=tool_input.language,
assistant=tool_input.assistant, assistant=tool_input.assistant,
device_id=tool_input.device_id,
) )
return await tool.async_call(self.hass, _tool_input) return await tool.async_call(self.hass, _tool_input)
@ -160,6 +162,7 @@ class IntentTool(Tool):
tool_input.context, tool_input.context,
tool_input.language, tool_input.language,
tool_input.assistant, tool_input.assistant,
tool_input.device_id,
) )
return intent_response.as_dict() return intent_response.as_dict()

View File

@ -198,6 +198,7 @@ async def test_function_call(
None, None,
context, context,
agent_id=agent_id, agent_id=agent_id,
device_id="test_device",
) )
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
@ -228,6 +229,7 @@ async def test_function_call(
user_prompt="Please call the test function", user_prompt="Please call the test function",
language="en", language="en",
assistant="conversation", assistant="conversation",
device_id="test_device",
), ),
) )
@ -280,6 +282,7 @@ async def test_function_exception(
None, None,
context, context,
agent_id=agent_id, agent_id=agent_id,
device_id="test_device",
) )
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
@ -310,6 +313,7 @@ async def test_function_exception(
user_prompt="Please call the test function", user_prompt="Please call the test function",
language="en", language="en",
assistant="conversation", assistant="conversation",
device_id="test_device",
), ),
) )

View File

@ -46,6 +46,7 @@ async def test_call_tool_no_existing(hass: HomeAssistant) -> None:
None, None,
None, None,
None, None,
None,
), ),
) )
@ -87,6 +88,7 @@ async def test_assist_api(hass: HomeAssistant) -> None:
user_prompt="test_text", user_prompt="test_text",
language="*", language="*",
assistant="test_assistant", assistant="test_assistant",
device_id="test_device",
) )
with patch( with patch(
@ -106,6 +108,7 @@ async def test_assist_api(hass: HomeAssistant) -> None:
test_context, test_context,
"*", "*",
"test_assistant", "test_assistant",
"test_device",
) )
assert response == { assert response == {
"card": {}, "card": {},