mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Return intent response from LLM chat log if available (#148522)
This commit is contained in:
parent
3449863eee
commit
1734b316d5
@ -6,7 +6,6 @@ from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from . import AnthropicConfigEntry
|
||||
@ -72,13 +71,4 @@ class AnthropicConversationEntity(
|
||||
|
||||
await self._async_handle_chat_log(chat_log)
|
||||
|
||||
response_content = chat_log.content[-1]
|
||||
if not isinstance(response_content, conversation.AssistantContent):
|
||||
raise TypeError("Last message must be an assistant message")
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_speech(response_content.content or "")
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=chat_log.conversation_id,
|
||||
continue_conversation=chat_log.continue_conversation,
|
||||
)
|
||||
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||
|
@ -61,6 +61,7 @@ from .entity import ConversationEntity
|
||||
from .http import async_setup as async_setup_conversation_http
|
||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||
from .util import async_get_result_from_chat_log
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
@ -83,6 +84,7 @@ __all__ = [
|
||||
"async_converse",
|
||||
"async_get_agent_info",
|
||||
"async_get_chat_log",
|
||||
"async_get_result_from_chat_log",
|
||||
"async_set_agent",
|
||||
"async_setup",
|
||||
"async_unset_agent",
|
||||
|
@ -196,6 +196,7 @@ class ChatLog:
|
||||
extra_system_prompt: str | None = None
|
||||
llm_api: llm.APIInstance | None = None
|
||||
delta_listener: Callable[[ChatLog, dict], None] | None = None
|
||||
llm_input_provided_index = 0
|
||||
|
||||
@property
|
||||
def continue_conversation(self) -> bool:
|
||||
@ -496,6 +497,7 @@ class ChatLog:
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
||||
self.llm_input_provided_index = len(self.content)
|
||||
self.llm_api = llm_api
|
||||
self.extra_system_prompt = extra_system_prompt
|
||||
self.content[0] = SystemContent(content=prompt)
|
||||
|
47
homeassistant/components/conversation/util.py
Normal file
47
homeassistant/components/conversation/util.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""Utility functions for conversation integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import intent, llm
|
||||
|
||||
from .chat_log import AssistantContent, ChatLog, ToolResultContent
|
||||
from .models import ConversationInput, ConversationResult
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_result_from_chat_log(
|
||||
user_input: ConversationInput, chat_log: ChatLog
|
||||
) -> ConversationResult:
|
||||
"""Get the result from the chat log."""
|
||||
tool_results = [
|
||||
content.tool_result
|
||||
for content in chat_log.content[chat_log.llm_input_provided_index :]
|
||||
if isinstance(content, ToolResultContent)
|
||||
and isinstance(content.tool_result, llm.IntentResponseDict)
|
||||
]
|
||||
|
||||
if tool_results:
|
||||
intent_response = tool_results[-1].original
|
||||
else:
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
|
||||
if not isinstance((last_content := chat_log.content[-1]), AssistantContent):
|
||||
_LOGGER.error(
|
||||
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
|
||||
last_content,
|
||||
)
|
||||
raise HomeAssistantError("Unable to get response")
|
||||
|
||||
intent_response.async_set_speech(last_content.content or "")
|
||||
|
||||
return ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=chat_log.conversation_id,
|
||||
continue_conversation=chat_log.continue_conversation,
|
||||
)
|
@ -8,12 +8,10 @@ from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from .const import CONF_PROMPT, DOMAIN, LOGGER
|
||||
from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity
|
||||
from .const import CONF_PROMPT, DOMAIN
|
||||
from .entity import GoogleGenerativeAILLMBaseEntity
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
@ -84,16 +82,4 @@ class GoogleGenerativeAIConversationEntity(
|
||||
|
||||
await self._async_handle_chat_log(chat_log)
|
||||
|
||||
response = intent.IntentResponse(language=user_input.language)
|
||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||
LOGGER.error(
|
||||
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
|
||||
chat_log.content[-1],
|
||||
)
|
||||
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
|
||||
response.async_set_speech(chat_log.content[-1].content or "")
|
||||
return conversation.ConversationResult(
|
||||
response=response,
|
||||
conversation_id=chat_log.conversation_id,
|
||||
continue_conversation=chat_log.continue_conversation,
|
||||
)
|
||||
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||
|
@ -8,7 +8,6 @@ from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from . import OllamaConfigEntry
|
||||
@ -84,15 +83,4 @@ class OllamaConversationEntity(
|
||||
|
||||
await self._async_handle_chat_log(chat_log)
|
||||
|
||||
# Create intent response
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||
raise TypeError(
|
||||
f"Unexpected last message type: {type(chat_log.content[-1])}"
|
||||
)
|
||||
intent_response.async_set_speech(chat_log.content[-1].content or "")
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=chat_log.conversation_id,
|
||||
continue_conversation=chat_log.continue_conversation,
|
||||
)
|
||||
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||
|
@ -15,7 +15,6 @@ from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
@ -131,11 +130,4 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
||||
)
|
||||
)
|
||||
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
assert type(chat_log.content[-1]) is conversation.AssistantContent
|
||||
intent_response.async_set_speech(chat_log.content[-1].content or "")
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=chat_log.conversation_id,
|
||||
continue_conversation=chat_log.continue_conversation,
|
||||
)
|
||||
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||
|
@ -6,7 +6,6 @@ from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from . import OpenAIConfigEntry
|
||||
@ -84,11 +83,4 @@ class OpenAIConversationEntity(
|
||||
|
||||
await self._async_handle_chat_log(chat_log)
|
||||
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
assert type(chat_log.content[-1]) is conversation.AssistantContent
|
||||
intent_response.async_set_speech(chat_log.content[-1].content or "")
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=chat_log.conversation_id,
|
||||
continue_conversation=chat_log.continue_conversation,
|
||||
)
|
||||
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||
|
@ -315,10 +315,23 @@ class IntentTool(Tool):
|
||||
assistant=llm_context.assistant,
|
||||
device_id=llm_context.device_id,
|
||||
)
|
||||
response = intent_response.as_dict()
|
||||
del response["language"]
|
||||
del response["card"]
|
||||
return response
|
||||
return IntentResponseDict(intent_response)
|
||||
|
||||
|
||||
class IntentResponseDict(dict):
|
||||
"""Dictionary to represent an intent response resulting from a tool call."""
|
||||
|
||||
def __init__(self, intent_response: Any) -> None:
|
||||
"""Initialize the dictionary."""
|
||||
if not isinstance(intent_response, intent.IntentResponse):
|
||||
super().__init__(intent_response)
|
||||
return
|
||||
|
||||
result = intent_response.as_dict()
|
||||
del result["language"]
|
||||
del result["card"]
|
||||
super().__init__(result)
|
||||
self.original = intent_response
|
||||
|
||||
|
||||
class NamespacedTool(Tool):
|
||||
|
@ -1,13 +1,14 @@
|
||||
"""Conversation test helpers."""
|
||||
|
||||
from unittest.mock import patch
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.shopping_list import intent as sl_intent
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import MockAgent
|
||||
@ -15,6 +16,14 @@ from . import MockAgent
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ulid() -> Generator[Mock]:
|
||||
"""Mock the ulid library."""
|
||||
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||
mock_ulid_now.return_value = "mock-ulid"
|
||||
yield mock_ulid_now
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent_support_all(hass: HomeAssistant) -> MockAgent:
|
||||
"""Mock agent that supports all languages."""
|
||||
@ -25,6 +34,19 @@ def mock_agent_support_all(hass: HomeAssistant) -> MockAgent:
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_input(hass: HomeAssistant) -> conversation.ConversationInput:
|
||||
"""Return a conversation input instance."""
|
||||
return conversation.ConversationInput(
|
||||
text="Hello",
|
||||
context=Context(),
|
||||
conversation_id=None,
|
||||
agent_id="mock-agent-id",
|
||||
device_id=None,
|
||||
language="en",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_shopping_list_io():
|
||||
"""Stub out the persistence."""
|
||||
|
@ -1,6 +1,5 @@
|
||||
"""Test the conversation session."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from dataclasses import asdict
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
@ -26,27 +25,6 @@ from homeassistant.util import dt as dt_util
|
||||
from tests.common import async_fire_time_changed
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
|
||||
"""Return a conversation input instance."""
|
||||
return ConversationInput(
|
||||
text="Hello",
|
||||
context=Context(),
|
||||
conversation_id=None,
|
||||
agent_id="mock-agent-id",
|
||||
device_id=None,
|
||||
language="en",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ulid() -> Generator[Mock]:
|
||||
"""Mock the ulid library."""
|
||||
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||
mock_ulid_now.return_value = "mock-ulid"
|
||||
yield mock_ulid_now
|
||||
|
||||
|
||||
async def test_cleanup(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
|
39
tests/components/conversation/test_util.py
Normal file
39
tests/components/conversation/test_util.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""Tests for conversation utility functions."""
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import chat_session, intent, llm
|
||||
|
||||
|
||||
async def test_async_get_result_from_chat_log(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: conversation.ConversationInput,
|
||||
) -> None:
|
||||
"""Test getting result from chat log."""
|
||||
intent_response = intent.IntentResponse(language="en")
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
conversation.async_get_chat_log(
|
||||
hass, session, mock_conversation_input
|
||||
) as chat_log,
|
||||
):
|
||||
chat_log.content.extend(
|
||||
[
|
||||
conversation.ToolResultContent(
|
||||
agent_id="mock-agent-id",
|
||||
tool_call_id="mock-tool-call-id",
|
||||
tool_name="mock-tool-name",
|
||||
tool_result=llm.IntentResponseDict(intent_response),
|
||||
),
|
||||
conversation.AssistantContent(
|
||||
agent_id="mock-agent-id",
|
||||
content="This is a response.",
|
||||
),
|
||||
]
|
||||
)
|
||||
result = conversation.async_get_result_from_chat_log(
|
||||
mock_conversation_input, chat_log
|
||||
)
|
||||
# Original intent response is returned with speech set
|
||||
assert result.response is intent_response
|
||||
assert result.response.speech["plain"]["speech"] == "This is a response."
|
@ -359,7 +359,7 @@ async def test_empty_response(
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||
ERROR_GETTING_RESPONSE
|
||||
"Unable to get response"
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user