mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Return intent response from LLM chat log if available (#148522)
This commit is contained in:
parent
3449863eee
commit
1734b316d5
@ -6,7 +6,6 @@ from homeassistant.components import conversation
|
|||||||
from homeassistant.config_entries import ConfigSubentry
|
from homeassistant.config_entries import ConfigSubentry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import intent
|
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
from . import AnthropicConfigEntry
|
from . import AnthropicConfigEntry
|
||||||
@ -72,13 +71,4 @@ class AnthropicConversationEntity(
|
|||||||
|
|
||||||
await self._async_handle_chat_log(chat_log)
|
await self._async_handle_chat_log(chat_log)
|
||||||
|
|
||||||
response_content = chat_log.content[-1]
|
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||||
if not isinstance(response_content, conversation.AssistantContent):
|
|
||||||
raise TypeError("Last message must be an assistant message")
|
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
|
||||||
intent_response.async_set_speech(response_content.content or "")
|
|
||||||
return conversation.ConversationResult(
|
|
||||||
response=intent_response,
|
|
||||||
conversation_id=chat_log.conversation_id,
|
|
||||||
continue_conversation=chat_log.continue_conversation,
|
|
||||||
)
|
|
||||||
|
@ -61,6 +61,7 @@ from .entity import ConversationEntity
|
|||||||
from .http import async_setup as async_setup_conversation_http
|
from .http import async_setup as async_setup_conversation_http
|
||||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||||
|
from .util import async_get_result_from_chat_log
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
@ -83,6 +84,7 @@ __all__ = [
|
|||||||
"async_converse",
|
"async_converse",
|
||||||
"async_get_agent_info",
|
"async_get_agent_info",
|
||||||
"async_get_chat_log",
|
"async_get_chat_log",
|
||||||
|
"async_get_result_from_chat_log",
|
||||||
"async_set_agent",
|
"async_set_agent",
|
||||||
"async_setup",
|
"async_setup",
|
||||||
"async_unset_agent",
|
"async_unset_agent",
|
||||||
|
@ -196,6 +196,7 @@ class ChatLog:
|
|||||||
extra_system_prompt: str | None = None
|
extra_system_prompt: str | None = None
|
||||||
llm_api: llm.APIInstance | None = None
|
llm_api: llm.APIInstance | None = None
|
||||||
delta_listener: Callable[[ChatLog, dict], None] | None = None
|
delta_listener: Callable[[ChatLog, dict], None] | None = None
|
||||||
|
llm_input_provided_index = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def continue_conversation(self) -> bool:
|
def continue_conversation(self) -> bool:
|
||||||
@ -496,6 +497,7 @@ class ChatLog:
|
|||||||
|
|
||||||
prompt = "\n".join(prompt_parts)
|
prompt = "\n".join(prompt_parts)
|
||||||
|
|
||||||
|
self.llm_input_provided_index = len(self.content)
|
||||||
self.llm_api = llm_api
|
self.llm_api = llm_api
|
||||||
self.extra_system_prompt = extra_system_prompt
|
self.extra_system_prompt = extra_system_prompt
|
||||||
self.content[0] = SystemContent(content=prompt)
|
self.content[0] = SystemContent(content=prompt)
|
||||||
|
47
homeassistant/components/conversation/util.py
Normal file
47
homeassistant/components/conversation/util.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
"""Utility functions for conversation integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers import intent, llm
|
||||||
|
|
||||||
|
from .chat_log import AssistantContent, ChatLog, ToolResultContent
|
||||||
|
from .models import ConversationInput, ConversationResult
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_result_from_chat_log(
|
||||||
|
user_input: ConversationInput, chat_log: ChatLog
|
||||||
|
) -> ConversationResult:
|
||||||
|
"""Get the result from the chat log."""
|
||||||
|
tool_results = [
|
||||||
|
content.tool_result
|
||||||
|
for content in chat_log.content[chat_log.llm_input_provided_index :]
|
||||||
|
if isinstance(content, ToolResultContent)
|
||||||
|
and isinstance(content.tool_result, llm.IntentResponseDict)
|
||||||
|
]
|
||||||
|
|
||||||
|
if tool_results:
|
||||||
|
intent_response = tool_results[-1].original
|
||||||
|
else:
|
||||||
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
|
|
||||||
|
if not isinstance((last_content := chat_log.content[-1]), AssistantContent):
|
||||||
|
_LOGGER.error(
|
||||||
|
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
|
||||||
|
last_content,
|
||||||
|
)
|
||||||
|
raise HomeAssistantError("Unable to get response")
|
||||||
|
|
||||||
|
intent_response.async_set_speech(last_content.content or "")
|
||||||
|
|
||||||
|
return ConversationResult(
|
||||||
|
response=intent_response,
|
||||||
|
conversation_id=chat_log.conversation_id,
|
||||||
|
continue_conversation=chat_log.continue_conversation,
|
||||||
|
)
|
@ -8,12 +8,10 @@ from homeassistant.components import conversation
|
|||||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
|
||||||
from homeassistant.helpers import intent
|
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
from .const import CONF_PROMPT, DOMAIN, LOGGER
|
from .const import CONF_PROMPT, DOMAIN
|
||||||
from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity
|
from .entity import GoogleGenerativeAILLMBaseEntity
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
@ -84,16 +82,4 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
|
|
||||||
await self._async_handle_chat_log(chat_log)
|
await self._async_handle_chat_log(chat_log)
|
||||||
|
|
||||||
response = intent.IntentResponse(language=user_input.language)
|
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
|
||||||
LOGGER.error(
|
|
||||||
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
|
|
||||||
chat_log.content[-1],
|
|
||||||
)
|
|
||||||
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
|
|
||||||
response.async_set_speech(chat_log.content[-1].content or "")
|
|
||||||
return conversation.ConversationResult(
|
|
||||||
response=response,
|
|
||||||
conversation_id=chat_log.conversation_id,
|
|
||||||
continue_conversation=chat_log.continue_conversation,
|
|
||||||
)
|
|
||||||
|
@ -8,7 +8,6 @@ from homeassistant.components import conversation
|
|||||||
from homeassistant.config_entries import ConfigSubentry
|
from homeassistant.config_entries import ConfigSubentry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import intent
|
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
from . import OllamaConfigEntry
|
from . import OllamaConfigEntry
|
||||||
@ -84,15 +83,4 @@ class OllamaConversationEntity(
|
|||||||
|
|
||||||
await self._async_handle_chat_log(chat_log)
|
await self._async_handle_chat_log(chat_log)
|
||||||
|
|
||||||
# Create intent response
|
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
|
||||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
|
||||||
raise TypeError(
|
|
||||||
f"Unexpected last message type: {type(chat_log.content[-1])}"
|
|
||||||
)
|
|
||||||
intent_response.async_set_speech(chat_log.content[-1].content or "")
|
|
||||||
return conversation.ConversationResult(
|
|
||||||
response=intent_response,
|
|
||||||
conversation_id=chat_log.conversation_id,
|
|
||||||
continue_conversation=chat_log.continue_conversation,
|
|
||||||
)
|
|
||||||
|
@ -15,7 +15,6 @@ from homeassistant.config_entries import ConfigSubentry
|
|||||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import intent
|
|
||||||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
@ -131,11 +130,4 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||||
assert type(chat_log.content[-1]) is conversation.AssistantContent
|
|
||||||
intent_response.async_set_speech(chat_log.content[-1].content or "")
|
|
||||||
return conversation.ConversationResult(
|
|
||||||
response=intent_response,
|
|
||||||
conversation_id=chat_log.conversation_id,
|
|
||||||
continue_conversation=chat_log.continue_conversation,
|
|
||||||
)
|
|
||||||
|
@ -6,7 +6,6 @@ from homeassistant.components import conversation
|
|||||||
from homeassistant.config_entries import ConfigSubentry
|
from homeassistant.config_entries import ConfigSubentry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import intent
|
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
from . import OpenAIConfigEntry
|
from . import OpenAIConfigEntry
|
||||||
@ -84,11 +83,4 @@ class OpenAIConversationEntity(
|
|||||||
|
|
||||||
await self._async_handle_chat_log(chat_log)
|
await self._async_handle_chat_log(chat_log)
|
||||||
|
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||||
assert type(chat_log.content[-1]) is conversation.AssistantContent
|
|
||||||
intent_response.async_set_speech(chat_log.content[-1].content or "")
|
|
||||||
return conversation.ConversationResult(
|
|
||||||
response=intent_response,
|
|
||||||
conversation_id=chat_log.conversation_id,
|
|
||||||
continue_conversation=chat_log.continue_conversation,
|
|
||||||
)
|
|
||||||
|
@ -315,10 +315,23 @@ class IntentTool(Tool):
|
|||||||
assistant=llm_context.assistant,
|
assistant=llm_context.assistant,
|
||||||
device_id=llm_context.device_id,
|
device_id=llm_context.device_id,
|
||||||
)
|
)
|
||||||
response = intent_response.as_dict()
|
return IntentResponseDict(intent_response)
|
||||||
del response["language"]
|
|
||||||
del response["card"]
|
|
||||||
return response
|
class IntentResponseDict(dict):
|
||||||
|
"""Dictionary to represent an intent response resulting from a tool call."""
|
||||||
|
|
||||||
|
def __init__(self, intent_response: Any) -> None:
|
||||||
|
"""Initialize the dictionary."""
|
||||||
|
if not isinstance(intent_response, intent.IntentResponse):
|
||||||
|
super().__init__(intent_response)
|
||||||
|
return
|
||||||
|
|
||||||
|
result = intent_response.as_dict()
|
||||||
|
del result["language"]
|
||||||
|
del result["card"]
|
||||||
|
super().__init__(result)
|
||||||
|
self.original = intent_response
|
||||||
|
|
||||||
|
|
||||||
class NamespacedTool(Tool):
|
class NamespacedTool(Tool):
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
"""Conversation test helpers."""
|
"""Conversation test helpers."""
|
||||||
|
|
||||||
from unittest.mock import patch
|
from collections.abc import Generator
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.shopping_list import intent as sl_intent
|
from homeassistant.components.shopping_list import intent as sl_intent
|
||||||
from homeassistant.const import MATCH_ALL
|
from homeassistant.const import MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import MockAgent
|
from . import MockAgent
|
||||||
@ -15,6 +16,14 @@ from . import MockAgent
|
|||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ulid() -> Generator[Mock]:
|
||||||
|
"""Mock the ulid library."""
|
||||||
|
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||||
|
mock_ulid_now.return_value = "mock-ulid"
|
||||||
|
yield mock_ulid_now
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_agent_support_all(hass: HomeAssistant) -> MockAgent:
|
def mock_agent_support_all(hass: HomeAssistant) -> MockAgent:
|
||||||
"""Mock agent that supports all languages."""
|
"""Mock agent that supports all languages."""
|
||||||
@ -25,6 +34,19 @@ def mock_agent_support_all(hass: HomeAssistant) -> MockAgent:
|
|||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_conversation_input(hass: HomeAssistant) -> conversation.ConversationInput:
|
||||||
|
"""Return a conversation input instance."""
|
||||||
|
return conversation.ConversationInput(
|
||||||
|
text="Hello",
|
||||||
|
context=Context(),
|
||||||
|
conversation_id=None,
|
||||||
|
agent_id="mock-agent-id",
|
||||||
|
device_id=None,
|
||||||
|
language="en",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_shopping_list_io():
|
def mock_shopping_list_io():
|
||||||
"""Stub out the persistence."""
|
"""Stub out the persistence."""
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
"""Test the conversation session."""
|
"""Test the conversation session."""
|
||||||
|
|
||||||
from collections.abc import Generator
|
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
@ -26,27 +25,6 @@ from homeassistant.util import dt as dt_util
|
|||||||
from tests.common import async_fire_time_changed
|
from tests.common import async_fire_time_changed
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
|
|
||||||
"""Return a conversation input instance."""
|
|
||||||
return ConversationInput(
|
|
||||||
text="Hello",
|
|
||||||
context=Context(),
|
|
||||||
conversation_id=None,
|
|
||||||
agent_id="mock-agent-id",
|
|
||||||
device_id=None,
|
|
||||||
language="en",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_ulid() -> Generator[Mock]:
|
|
||||||
"""Mock the ulid library."""
|
|
||||||
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
|
||||||
mock_ulid_now.return_value = "mock-ulid"
|
|
||||||
yield mock_ulid_now
|
|
||||||
|
|
||||||
|
|
||||||
async def test_cleanup(
|
async def test_cleanup(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_conversation_input: ConversationInput,
|
mock_conversation_input: ConversationInput,
|
||||||
|
39
tests/components/conversation/test_util.py
Normal file
39
tests/components/conversation/test_util.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
"""Tests for conversation utility functions."""
|
||||||
|
|
||||||
|
from homeassistant.components import conversation
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import chat_session, intent, llm
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_get_result_from_chat_log(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_conversation_input: conversation.ConversationInput,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting result from chat log."""
|
||||||
|
intent_response = intent.IntentResponse(language="en")
|
||||||
|
with (
|
||||||
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
|
conversation.async_get_chat_log(
|
||||||
|
hass, session, mock_conversation_input
|
||||||
|
) as chat_log,
|
||||||
|
):
|
||||||
|
chat_log.content.extend(
|
||||||
|
[
|
||||||
|
conversation.ToolResultContent(
|
||||||
|
agent_id="mock-agent-id",
|
||||||
|
tool_call_id="mock-tool-call-id",
|
||||||
|
tool_name="mock-tool-name",
|
||||||
|
tool_result=llm.IntentResponseDict(intent_response),
|
||||||
|
),
|
||||||
|
conversation.AssistantContent(
|
||||||
|
agent_id="mock-agent-id",
|
||||||
|
content="This is a response.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
result = conversation.async_get_result_from_chat_log(
|
||||||
|
mock_conversation_input, chat_log
|
||||||
|
)
|
||||||
|
# Original intent response is returned with speech set
|
||||||
|
assert result.response is intent_response
|
||||||
|
assert result.response.speech["plain"]["speech"] == "This is a response."
|
@ -359,7 +359,7 @@ async def test_empty_response(
|
|||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||||
ERROR_GETTING_RESPONSE
|
"Unable to get response"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user