Return intent response from LLM chat log if available (#148522)

This commit is contained in:
Paulus Schoutsen 2025-07-16 16:16:01 +02:00 committed by GitHub
parent 3449863eee
commit 1734b316d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 139 additions and 88 deletions

View File

@ -6,7 +6,6 @@ from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import AnthropicConfigEntry from . import AnthropicConfigEntry
@ -72,13 +71,4 @@ class AnthropicConversationEntity(
await self._async_handle_chat_log(chat_log) await self._async_handle_chat_log(chat_log)
response_content = chat_log.content[-1] return conversation.async_get_result_from_chat_log(user_input, chat_log)
if not isinstance(response_content, conversation.AssistantContent):
raise TypeError("Last message must be an assistant message")
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response_content.content or "")
return conversation.ConversationResult(
response=intent_response,
conversation_id=chat_log.conversation_id,
continue_conversation=chat_log.continue_conversation,
)

View File

@ -61,6 +61,7 @@ from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult from .models import AbstractConversationAgent, ConversationInput, ConversationResult
from .trace import ConversationTraceEventType, async_conversation_trace_append from .trace import ConversationTraceEventType, async_conversation_trace_append
from .util import async_get_result_from_chat_log
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
@ -83,6 +84,7 @@ __all__ = [
"async_converse", "async_converse",
"async_get_agent_info", "async_get_agent_info",
"async_get_chat_log", "async_get_chat_log",
"async_get_result_from_chat_log",
"async_set_agent", "async_set_agent",
"async_setup", "async_setup",
"async_unset_agent", "async_unset_agent",

View File

@ -196,6 +196,7 @@ class ChatLog:
extra_system_prompt: str | None = None extra_system_prompt: str | None = None
llm_api: llm.APIInstance | None = None llm_api: llm.APIInstance | None = None
delta_listener: Callable[[ChatLog, dict], None] | None = None delta_listener: Callable[[ChatLog, dict], None] | None = None
llm_input_provided_index = 0
@property @property
def continue_conversation(self) -> bool: def continue_conversation(self) -> bool:
@ -496,6 +497,7 @@ class ChatLog:
prompt = "\n".join(prompt_parts) prompt = "\n".join(prompt_parts)
self.llm_input_provided_index = len(self.content)
self.llm_api = llm_api self.llm_api = llm_api
self.extra_system_prompt = extra_system_prompt self.extra_system_prompt = extra_system_prompt
self.content[0] = SystemContent(content=prompt) self.content[0] = SystemContent(content=prompt)

View File

@ -0,0 +1,47 @@
"""Utility functions for conversation integration."""
from __future__ import annotations
import logging
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm
from .chat_log import AssistantContent, ChatLog, ToolResultContent
from .models import ConversationInput, ConversationResult
_LOGGER = logging.getLogger(__name__)
@callback
def async_get_result_from_chat_log(
user_input: ConversationInput, chat_log: ChatLog
) -> ConversationResult:
"""Get the result from the chat log."""
tool_results = [
content.tool_result
for content in chat_log.content[chat_log.llm_input_provided_index :]
if isinstance(content, ToolResultContent)
and isinstance(content.tool_result, llm.IntentResponseDict)
]
if tool_results:
intent_response = tool_results[-1].original
else:
intent_response = intent.IntentResponse(language=user_input.language)
if not isinstance((last_content := chat_log.content[-1]), AssistantContent):
_LOGGER.error(
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
last_content,
)
raise HomeAssistantError("Unable to get response")
intent_response.async_set_speech(last_content.content or "")
return ConversationResult(
response=intent_response,
conversation_id=chat_log.conversation_id,
continue_conversation=chat_log.continue_conversation,
)

View File

@ -8,12 +8,10 @@ from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import CONF_PROMPT, DOMAIN, LOGGER from .const import CONF_PROMPT, DOMAIN
from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity from .entity import GoogleGenerativeAILLMBaseEntity
async def async_setup_entry( async def async_setup_entry(
@ -84,16 +82,4 @@ class GoogleGenerativeAIConversationEntity(
await self._async_handle_chat_log(chat_log) await self._async_handle_chat_log(chat_log)
response = intent.IntentResponse(language=user_input.language) return conversation.async_get_result_from_chat_log(user_input, chat_log)
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error(
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
chat_log.content[-1],
)
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
response.async_set_speech(chat_log.content[-1].content or "")
return conversation.ConversationResult(
response=response,
conversation_id=chat_log.conversation_id,
continue_conversation=chat_log.continue_conversation,
)

View File

@ -8,7 +8,6 @@ from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OllamaConfigEntry from . import OllamaConfigEntry
@ -84,15 +83,4 @@ class OllamaConversationEntity(
await self._async_handle_chat_log(chat_log) await self._async_handle_chat_log(chat_log)
# Create intent response return conversation.async_get_result_from_chat_log(user_input, chat_log)
intent_response = intent.IntentResponse(language=user_input.language)
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
raise TypeError(
f"Unexpected last message type: {type(chat_log.content[-1])}"
)
intent_response.async_set_speech(chat_log.content[-1].content or "")
return conversation.ConversationResult(
response=intent_response,
conversation_id=chat_log.conversation_id,
continue_conversation=chat_log.continue_conversation,
)

View File

@ -15,7 +15,6 @@ from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
@ -131,11 +130,4 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
) )
) )
intent_response = intent.IntentResponse(language=user_input.language) return conversation.async_get_result_from_chat_log(user_input, chat_log)
assert type(chat_log.content[-1]) is conversation.AssistantContent
intent_response.async_set_speech(chat_log.content[-1].content or "")
return conversation.ConversationResult(
response=intent_response,
conversation_id=chat_log.conversation_id,
continue_conversation=chat_log.continue_conversation,
)

View File

@ -6,7 +6,6 @@ from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OpenAIConfigEntry from . import OpenAIConfigEntry
@ -84,11 +83,4 @@ class OpenAIConversationEntity(
await self._async_handle_chat_log(chat_log) await self._async_handle_chat_log(chat_log)
intent_response = intent.IntentResponse(language=user_input.language) return conversation.async_get_result_from_chat_log(user_input, chat_log)
assert type(chat_log.content[-1]) is conversation.AssistantContent
intent_response.async_set_speech(chat_log.content[-1].content or "")
return conversation.ConversationResult(
response=intent_response,
conversation_id=chat_log.conversation_id,
continue_conversation=chat_log.continue_conversation,
)

View File

@ -315,10 +315,23 @@ class IntentTool(Tool):
assistant=llm_context.assistant, assistant=llm_context.assistant,
device_id=llm_context.device_id, device_id=llm_context.device_id,
) )
response = intent_response.as_dict() return IntentResponseDict(intent_response)
del response["language"]
del response["card"]
return response class IntentResponseDict(dict):
"""Dictionary to represent an intent response resulting from a tool call."""
def __init__(self, intent_response: Any) -> None:
"""Initialize the dictionary."""
if not isinstance(intent_response, intent.IntentResponse):
super().__init__(intent_response)
return
result = intent_response.as_dict()
del result["language"]
del result["card"]
super().__init__(result)
self.original = intent_response
class NamespacedTool(Tool): class NamespacedTool(Tool):

View File

@ -1,13 +1,14 @@
"""Conversation test helpers.""" """Conversation test helpers."""
from unittest.mock import patch from collections.abc import Generator
from unittest.mock import Mock, patch
import pytest import pytest
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.shopping_list import intent as sl_intent from homeassistant.components.shopping_list import intent as sl_intent
from homeassistant.const import MATCH_ALL from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import MockAgent from . import MockAgent
@ -15,6 +16,14 @@ from . import MockAgent
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@pytest.fixture
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now
@pytest.fixture @pytest.fixture
def mock_agent_support_all(hass: HomeAssistant) -> MockAgent: def mock_agent_support_all(hass: HomeAssistant) -> MockAgent:
"""Mock agent that supports all languages.""" """Mock agent that supports all languages."""
@ -25,6 +34,19 @@ def mock_agent_support_all(hass: HomeAssistant) -> MockAgent:
return agent return agent
@pytest.fixture
def mock_conversation_input(hass: HomeAssistant) -> conversation.ConversationInput:
"""Return a conversation input instance."""
return conversation.ConversationInput(
text="Hello",
context=Context(),
conversation_id=None,
agent_id="mock-agent-id",
device_id=None,
language="en",
)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_shopping_list_io(): def mock_shopping_list_io():
"""Stub out the persistence.""" """Stub out the persistence."""

View File

@ -1,6 +1,5 @@
"""Test the conversation session.""" """Test the conversation session."""
from collections.abc import Generator
from dataclasses import asdict from dataclasses import asdict
from datetime import timedelta from datetime import timedelta
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
@ -26,27 +25,6 @@ from homeassistant.util import dt as dt_util
from tests.common import async_fire_time_changed from tests.common import async_fire_time_changed
@pytest.fixture
def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
"""Return a conversation input instance."""
return ConversationInput(
text="Hello",
context=Context(),
conversation_id=None,
agent_id="mock-agent-id",
device_id=None,
language="en",
)
@pytest.fixture
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now
async def test_cleanup( async def test_cleanup(
hass: HomeAssistant, hass: HomeAssistant,
mock_conversation_input: ConversationInput, mock_conversation_input: ConversationInput,

View File

@ -0,0 +1,39 @@
"""Tests for conversation utility functions."""
from homeassistant.components import conversation
from homeassistant.core import HomeAssistant
from homeassistant.helpers import chat_session, intent, llm
async def test_async_get_result_from_chat_log(
hass: HomeAssistant,
mock_conversation_input: conversation.ConversationInput,
) -> None:
"""Test getting result from chat log."""
intent_response = intent.IntentResponse(language="en")
with (
chat_session.async_get_chat_session(hass) as session,
conversation.async_get_chat_log(
hass, session, mock_conversation_input
) as chat_log,
):
chat_log.content.extend(
[
conversation.ToolResultContent(
agent_id="mock-agent-id",
tool_call_id="mock-tool-call-id",
tool_name="mock-tool-name",
tool_result=llm.IntentResponseDict(intent_response),
),
conversation.AssistantContent(
agent_id="mock-agent-id",
content="This is a response.",
),
]
)
result = conversation.async_get_result_from_chat_log(
mock_conversation_input, chat_log
)
# Original intent response is returned with speech set
assert result.response is intent_response
assert result.response.speech["plain"]["speech"] == "This is a response."

View File

@ -359,7 +359,7 @@ async def test_empty_response(
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == ( assert result.response.as_dict()["speech"]["plain"]["speech"] == (
ERROR_GETTING_RESPONSE "Unable to get response"
) )