Add shared history for conversation agents (#135903)

* Add shared history for conversation agents

* Remove unused code

* Add support for native history items

* Store all assistant responses as assistant in history

* Add history support to DefaultAgent.async_handle_intents

* Make local fallback work

* Add default agent history

* Add history cleanup

* Add tests

* ChatHistory -> ChatSession

* Address comments

* Update snapshots
This commit is contained in:
Paulus Schoutsen 2025-01-18 22:33:03 -05:00 committed by GitHub
parent 32d7a23bff
commit 754de6f998
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 744 additions and 276 deletions

View File

@ -1065,7 +1065,8 @@ class PipelineRun:
) )
processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT
conversation_result: conversation.ConversationResult | None = None agent_id = user_input.agent_id
intent_response: intent.IntentResponse | None = None
if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT: if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT:
# Sentence triggers override conversation agent # Sentence triggers override conversation agent
if ( if (
@ -1075,14 +1076,12 @@ class PipelineRun:
) )
) is not None: ) is not None:
# Sentence trigger matched # Sentence trigger matched
trigger_response = intent.IntentResponse( agent_id = "sentence_trigger"
intent_response = intent.IntentResponse(
self.pipeline.conversation_language self.pipeline.conversation_language
) )
trigger_response.async_set_speech(trigger_response_text) intent_response.async_set_speech(trigger_response_text)
conversation_result = conversation.ConversationResult(
response=trigger_response,
conversation_id=user_input.conversation_id,
)
# Try local intents first, if preferred. # Try local intents first, if preferred.
elif self.pipeline.prefer_local_intents and ( elif self.pipeline.prefer_local_intents and (
intent_response := await conversation.async_handle_intents( intent_response := await conversation.async_handle_intents(
@ -1090,13 +1089,31 @@ class PipelineRun:
) )
): ):
# Local intent matched # Local intent matched
conversation_result = conversation.ConversationResult( agent_id = conversation.HOME_ASSISTANT_AGENT
response=intent_response,
conversation_id=user_input.conversation_id,
)
processed_locally = True processed_locally = True
if conversation_result is None: # It was already handled, create response and add to chat history
if intent_response is not None:
async with conversation.async_get_chat_session(
self.hass, user_input
) as chat_session:
speech: str = intent_response.speech.get("plain", {}).get(
"speech", ""
)
chat_session.async_add_message(
conversation.ChatMessage(
role="assistant",
agent_id=agent_id,
content=speech,
native=intent_response,
)
)
conversation_result = conversation.ConversationResult(
response=intent_response,
conversation_id=chat_session.conversation_id,
)
else:
# Fall back to pipeline conversation agent # Fall back to pipeline conversation agent
conversation_result = await conversation.async_converse( conversation_result = await conversation.async_converse(
hass=self.hass, hass=self.hass,
@ -1107,6 +1124,10 @@ class PipelineRun:
language=user_input.language, language=user_input.language,
agent_id=user_input.agent_id, agent_id=user_input.agent_id,
) )
speech = conversation_result.response.speech.get("plain", {}).get(
"speech", ""
)
except Exception as src_error: except Exception as src_error:
_LOGGER.exception("Unexpected error during intent recognition") _LOGGER.exception("Unexpected error during intent recognition")
raise IntentRecognitionError( raise IntentRecognitionError(
@ -1126,10 +1147,6 @@ class PipelineRun:
) )
) )
speech: str = conversation_result.response.speech.get("plain", {}).get(
"speech", ""
)
return speech return speech
async def prepare_text_to_speech(self) -> None: async def prepare_text_to_speech(self) -> None:

View File

@ -48,20 +48,25 @@ from .default_agent import DefaultAgent, async_setup_default_agent
from .entity import ConversationEntity from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult from .models import AbstractConversationAgent, ConversationInput, ConversationResult
from .session import ChatMessage, ChatSession, ConverseError, async_get_chat_session
from .trace import ConversationTraceEventType, async_conversation_trace_append from .trace import ConversationTraceEventType, async_conversation_trace_append
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
"HOME_ASSISTANT_AGENT", "HOME_ASSISTANT_AGENT",
"OLD_HOME_ASSISTANT_AGENT", "OLD_HOME_ASSISTANT_AGENT",
"ChatMessage",
"ChatSession",
"ConversationEntity", "ConversationEntity",
"ConversationEntityFeature", "ConversationEntityFeature",
"ConversationInput", "ConversationInput",
"ConversationResult", "ConversationResult",
"ConversationTraceEventType", "ConversationTraceEventType",
"ConverseError",
"async_conversation_trace_append", "async_conversation_trace_append",
"async_converse", "async_converse",
"async_get_agent_info", "async_get_agent_info",
"async_get_chat_session",
"async_set_agent", "async_set_agent",
"async_setup", "async_setup",
"async_unset_agent", "async_unset_agent",

View File

@ -62,6 +62,7 @@ from .const import (
) )
from .entity import ConversationEntity from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult from .models import ConversationInput, ConversationResult
from .session import ChatMessage, async_get_chat_session
from .trace import ConversationTraceEventType, async_conversation_trace_append from .trace import ConversationTraceEventType, async_conversation_trace_append
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -346,35 +347,52 @@ class DefaultAgent(ConversationEntity):
async def async_process(self, user_input: ConversationInput) -> ConversationResult: async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence.""" """Process a sentence."""
response: intent.IntentResponse | None = None
async with async_get_chat_session(self.hass, user_input) as chat_session:
# Check if a trigger matched
if trigger_result := await self.async_recognize_sentence_trigger(
user_input
):
# Process callbacks and get response
response_text = await self._handle_trigger_result(
trigger_result, user_input
)
# Check if a trigger matched # Convert to conversation result
if trigger_result := await self.async_recognize_sentence_trigger(user_input): response = intent.IntentResponse(
# Process callbacks and get response language=user_input.language or self.hass.config.language
response_text = await self._handle_trigger_result( )
trigger_result, user_input response.response_type = intent.IntentResponseType.ACTION_DONE
response.async_set_speech(response_text)
if response is None:
# Match intents
intent_result = await self.async_recognize_intent(user_input)
response = await self._async_process_intent_result(
intent_result, user_input
)
speech: str = response.speech.get("plain", {}).get("speech", "")
chat_session.async_add_message(
ChatMessage(
role="assistant",
agent_id=user_input.agent_id,
content=speech,
native=response,
)
) )
# Convert to conversation result return ConversationResult(
response = intent.IntentResponse( response=response, conversation_id=chat_session.conversation_id
language=user_input.language or self.hass.config.language
) )
response.response_type = intent.IntentResponseType.ACTION_DONE
response.async_set_speech(response_text)
return ConversationResult(response=response)
# Match intents
intent_result = await self.async_recognize_intent(user_input)
return await self._async_process_intent_result(intent_result, user_input)
async def _async_process_intent_result( async def _async_process_intent_result(
self, self,
result: RecognizeResult | None, result: RecognizeResult | None,
user_input: ConversationInput, user_input: ConversationInput,
) -> ConversationResult: ) -> intent.IntentResponse:
"""Process user input with intents.""" """Process user input with intents."""
language = user_input.language or self.hass.config.language language = user_input.language or self.hass.config.language
conversation_id = None # Not supported
# Intent match or failure # Intent match or failure
lang_intents = await self.async_get_or_load_intents(language) lang_intents = await self.async_get_or_load_intents(language)
@ -386,7 +404,6 @@ class DefaultAgent(ConversationEntity):
language, language,
intent.IntentResponseErrorCode.NO_INTENT_MATCH, intent.IntentResponseErrorCode.NO_INTENT_MATCH,
self._get_error_text(ErrorKey.NO_INTENT, lang_intents), self._get_error_text(ErrorKey.NO_INTENT, lang_intents),
conversation_id,
) )
if result.unmatched_entities: if result.unmatched_entities:
@ -408,7 +425,6 @@ class DefaultAgent(ConversationEntity):
self._get_error_text( self._get_error_text(
error_response_type, lang_intents, **error_response_args error_response_type, lang_intents, **error_response_args
), ),
conversation_id,
) )
# Will never happen because result will be None when no intents are # Will never happen because result will be None when no intents are
@ -461,7 +477,6 @@ class DefaultAgent(ConversationEntity):
self._get_error_text( self._get_error_text(
error_response_type, lang_intents, **error_response_args error_response_type, lang_intents, **error_response_args
), ),
conversation_id,
) )
except intent.IntentHandleError as err: except intent.IntentHandleError as err:
# Intent was valid and entities matched constraints, but an error # Intent was valid and entities matched constraints, but an error
@ -473,7 +488,6 @@ class DefaultAgent(ConversationEntity):
self._get_error_text( self._get_error_text(
err.response_key or ErrorKey.HANDLE_ERROR, lang_intents err.response_key or ErrorKey.HANDLE_ERROR, lang_intents
), ),
conversation_id,
) )
except intent.IntentUnexpectedError: except intent.IntentUnexpectedError:
_LOGGER.exception("Unexpected intent error") _LOGGER.exception("Unexpected intent error")
@ -481,7 +495,6 @@ class DefaultAgent(ConversationEntity):
language, language,
intent.IntentResponseErrorCode.UNKNOWN, intent.IntentResponseErrorCode.UNKNOWN,
self._get_error_text(ErrorKey.HANDLE_ERROR, lang_intents), self._get_error_text(ErrorKey.HANDLE_ERROR, lang_intents),
conversation_id,
) )
if ( if (
@ -500,9 +513,7 @@ class DefaultAgent(ConversationEntity):
) )
intent_response.async_set_speech(speech) intent_response.async_set_speech(speech)
return ConversationResult( return intent_response
response=intent_response, conversation_id=conversation_id
)
def _recognize( def _recognize(
self, self,
@ -1346,22 +1357,18 @@ class DefaultAgent(ConversationEntity):
# No error message on failed match # No error message on failed match
return None return None
conversation_result = await self._async_process_intent_result( return await self._async_process_intent_result(result, user_input)
result, user_input
)
return conversation_result.response
def _make_error_result( def _make_error_result(
language: str, language: str,
error_code: intent.IntentResponseErrorCode, error_code: intent.IntentResponseErrorCode,
response_text: str, response_text: str,
conversation_id: str | None = None, ) -> intent.IntentResponse:
) -> ConversationResult:
"""Create conversation result with error code and text.""" """Create conversation result with error code and text."""
response = intent.IntentResponse(language=language) response = intent.IntentResponse(language=language)
response.async_set_error(error_code, response_text) response.async_set_error(error_code, response_text)
return ConversationResult(response, conversation_id) return response
def _get_unmatched_response(result: RecognizeResult) -> tuple[ErrorKey, dict[str, Any]]: def _get_unmatched_response(result: RecognizeResult) -> tuple[ErrorKey, dict[str, Any]]:

View File

@ -0,0 +1,327 @@
"""Conversation history."""
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta
import logging
from typing import Generic, Literal, TypeVar
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HassJob,
HassJobType,
HomeAssistant,
callback,
)
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import intent, llm, template
from homeassistant.helpers.event import async_call_later
from homeassistant.util import dt as dt_util, ulid
from homeassistant.util.hass_dict import HassKey
from .const import DOMAIN
from .models import ConversationInput, ConversationResult
DATA_CHAT_HISTORY: HassKey["dict[str, ChatSession]"] = HassKey(
"conversation_chat_session"
)
DATA_CHAT_HISTORY_CLEANUP: HassKey["SessionCleanup"] = HassKey(
"conversation_chat_session_cleanup"
)
LOGGER = logging.getLogger(__name__)
CONVERSATION_TIMEOUT = timedelta(minutes=5)
_NativeT = TypeVar("_NativeT")
class SessionCleanup:
"""Helper to clean up the history."""
unsub: CALLBACK_TYPE | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the history cleanup."""
self.hass = hass
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
self.cleanup_job = HassJob(
self._cleanup, "conversation_history_cleanup", job_type=HassJobType.Callback
)
@callback
def schedule(self) -> None:
"""Schedule the cleanup."""
if self.unsub:
return
self.unsub = async_call_later(
self.hass,
CONVERSATION_TIMEOUT.total_seconds() + 1,
self.cleanup_job,
)
@callback
def _on_hass_stop(self, event: Event) -> None:
"""Cancel the cleanup on shutdown."""
if self.unsub:
self.unsub()
self.unsub = None
@callback
def _cleanup(self, now: datetime) -> None:
"""Clean up the history and schedule follow-up if necessary."""
self.unsub = None
all_history = self.hass.data[DATA_CHAT_HISTORY]
# We mutate original object because current commands could be
# yielding history based on it.
for conversation_id, history in list(all_history.items()):
if history.last_updated + CONVERSATION_TIMEOUT < now:
del all_history[conversation_id]
# Still conversations left, check again in timeout time.
if all_history:
self.schedule()
@asynccontextmanager
async def async_get_chat_session(
hass: HomeAssistant,
user_input: ConversationInput,
) -> AsyncGenerator["ChatSession"]:
"""Return chat session."""
all_history = hass.data.get(DATA_CHAT_HISTORY)
if all_history is None:
all_history = {}
hass.data[DATA_CHAT_HISTORY] = all_history
hass.data[DATA_CHAT_HISTORY_CLEANUP] = SessionCleanup(hass)
history: ChatSession | None = None
if user_input.conversation_id is None:
conversation_id = ulid.ulid_now()
elif history := all_history.get(user_input.conversation_id):
conversation_id = user_input.conversation_id
else:
# Conversation IDs are ULIDs. We generate a new one if not provided.
# If an old OLID is passed in, we will generate a new one to indicate
# a new conversation was started. If the user picks their own, they
# want to track a conversation and we respect it.
try:
ulid.ulid_to_bytes(user_input.conversation_id)
conversation_id = ulid.ulid_now()
except ValueError:
conversation_id = user_input.conversation_id
if history:
history = replace(history, messages=history.messages.copy())
else:
history = ChatSession(hass, conversation_id)
message: ChatMessage = ChatMessage(
role="user",
agent_id=user_input.agent_id,
content=user_input.text,
)
history.async_add_message(message)
yield history
if history.messages[-1] is message:
LOGGER.debug(
"History opened but no assistant message was added, ignoring update"
)
return
history.last_updated = dt_util.utcnow()
all_history[conversation_id] = history
hass.data[DATA_CHAT_HISTORY_CLEANUP].schedule()
class ConverseError(HomeAssistantError):
"""Error during initialization of conversation.
Will not be stored in the history.
"""
def __init__(
self, message: str, conversation_id: str, response: intent.IntentResponse
) -> None:
"""Initialize the error."""
super().__init__(message)
self.conversation_id = conversation_id
self.response = response
def as_converstation_result(self) -> ConversationResult:
"""Return the error as a conversation result."""
return ConversationResult(
response=self.response,
conversation_id=self.conversation_id,
)
@dataclass
class ChatMessage(Generic[_NativeT]):
"""Base class for chat messages.
When role is native, the content is to be ignored and message
is only meant for storing the native object.
"""
role: Literal["system", "assistant", "user", "native"]
agent_id: str | None
content: str
native: _NativeT | None = field(default=None)
# Validate in post-init that if role is native, there is no content and a native object exists
def __post_init__(self) -> None:
"""Validate native message."""
if self.role == "native" and self.native is None:
raise ValueError("Native message must have a native object")
@dataclass
class ChatSession(Generic[_NativeT]):
"""Class holding all information for a specific conversation."""
hass: HomeAssistant
conversation_id: str
user_name: str | None = None
messages: list[ChatMessage[_NativeT]] = field(
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")]
)
extra_system_prompt: str | None = None
llm_api: llm.APIInstance | None = None
last_updated: datetime = field(default_factory=dt_util.utcnow)
@callback
def async_add_message(self, message: ChatMessage[_NativeT]) -> None:
"""Process intent."""
if message.role == "system":
raise ValueError("Cannot add system messages to history")
if message.role != "native" and self.messages[-1].role == message.role:
raise ValueError("Cannot add two assistant or user messages in a row")
self.messages.append(message)
@callback
def async_get_messages(self, agent_id: str | None) -> list[ChatMessage[_NativeT]]:
"""Get messages for a specific agent ID.
This will filter out any native message tied to other agent IDs.
It can still include assistant/user messages generated by other agents.
"""
return [
message
for message in self.messages
if message.role != "native" or message.agent_id == agent_id
]
async def async_process_llm_message(
self,
conversing_domain: str,
user_input: ConversationInput,
user_llm_hass_api: str | None = None,
user_llm_prompt: str | None = None,
) -> None:
"""Process an incoming message for an LLM."""
llm_context = llm.LLMContext(
platform=conversing_domain,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=DOMAIN,
device_id=user_input.device_id,
)
llm_api: llm.APIInstance | None = None
if user_llm_hass_api:
try:
llm_api = await llm.async_get_api(
self.hass,
user_llm_hass_api,
llm_context,
)
except HomeAssistantError as err:
LOGGER.error(
"Error getting LLM API %s for %s: %s",
user_llm_hass_api,
conversing_domain,
err,
)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Error preparing LLM API",
)
raise ConverseError(
f"Error getting LLM API {user_llm_hass_api}",
conversation_id=self.conversation_id,
response=intent_response,
) from err
user_name: str | None = None
if (
user_input.context
and user_input.context.user_id
and (
user := await self.hass.auth.async_get_user(user_input.context.user_id)
)
):
user_name = user.name
try:
prompt_parts = [
template.Template(
llm.BASE_PROMPT
+ (user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT),
self.hass,
).async_render(
{
"ha_name": self.hass.config.location_name,
"user_name": user_name,
"llm_context": llm_context,
},
parse_result=False,
)
]
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem with my template",
)
raise ConverseError(
"Error rendering prompt",
conversation_id=self.conversation_id,
response=intent_response,
) from err
if llm_api:
prompt_parts.append(llm_api.api_prompt)
extra_system_prompt = (
# Take new system prompt if one was given
user_input.extra_system_prompt or self.extra_system_prompt
)
if extra_system_prompt:
prompt_parts.append(extra_system_prompt)
prompt = "\n".join(prompt_parts)
self.llm_api = llm_api
self.user_name = user_name
self.extra_system_prompt = extra_system_prompt
self.messages[0] = ChatMessage(
role="system",
agent_id=user_input.agent_id,
content=prompt,
)

View File

@ -1,9 +1,8 @@
"""Conversation support for OpenAI.""" """Conversation support for OpenAI."""
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass, field
import json import json
from typing import Any, Literal from typing import Any, Literal, cast
import openai import openai
from openai._types import NOT_GIVEN from openai._types import NOT_GIVEN
@ -12,10 +11,8 @@ from openai.types.chat import (
ChatCompletionMessage, ChatCompletionMessage,
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam, ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam, ChatCompletionToolMessageParam,
ChatCompletionToolParam, ChatCompletionToolParam,
ChatCompletionUserMessageParam,
) )
from openai.types.chat.chat_completion_message_tool_call_param import Function from openai.types.chat.chat_completion_message_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition from openai.types.shared_params import FunctionDefinition
@ -27,10 +24,9 @@ from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, intent, llm, template from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
from . import OpenAIConfigEntry from . import OpenAIConfigEntry
from .const import ( from .const import (
@ -74,12 +70,28 @@ def _format_tool(
return ChatCompletionToolParam(type="function", function=tool_spec) return ChatCompletionToolParam(type="function", function=tool_spec)
@dataclass def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
class ChatHistory: """Convert from class to TypedDict."""
"""Class holding the chat history.""" tool_calls: list[ChatCompletionMessageToolCallParam] = []
if message.tool_calls:
extra_system_prompt: str | None = None tool_calls = [
messages: list[ChatCompletionMessageParam] = field(default_factory=list) ChatCompletionMessageToolCallParam(
id=tool_call.id,
function=Function(
arguments=tool_call.function.arguments,
name=tool_call.function.name,
),
type=tool_call.type,
)
for tool_call in message.tool_calls
]
param = ChatCompletionAssistantMessageParam(
role=message.role,
content=message.content,
)
if tool_calls:
param["tool_calls"] = tool_calls
return param
class OpenAIConversationEntity( class OpenAIConversationEntity(
@ -93,7 +105,6 @@ class OpenAIConversationEntity(
def __init__(self, entry: OpenAIConfigEntry) -> None: def __init__(self, entry: OpenAIConfigEntry) -> None:
"""Initialize the agent.""" """Initialize the agent."""
self.entry = entry self.entry = entry
self.history: dict[str, ChatHistory] = {}
self._attr_unique_id = entry.entry_id self._attr_unique_id = entry.entry_id
self._attr_device_info = dr.DeviceInfo( self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)}, identifiers={(DOMAIN, entry.entry_id)},
@ -132,127 +143,56 @@ class OpenAIConversationEntity(
self, user_input: conversation.ConversationInput self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
async with conversation.async_get_chat_session(
self.hass, user_input
) as session:
return await self._async_call_api(user_input, session)
async def _async_call_api(
self,
user_input: conversation.ConversationInput,
session: conversation.ChatSession[ChatCompletionMessageParam],
) -> conversation.ConversationResult:
"""Call the API."""
options = self.entry.options options = self.entry.options
intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.APIInstance | None = None
tools: list[ChatCompletionToolParam] | None = None
user_name: str | None = None
llm_context = llm.LLMContext(
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
if options.get(CONF_LLM_HASS_API):
try:
llm_api = await llm.async_get_api(
self.hass,
options[CONF_LLM_HASS_API],
llm_context,
)
except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Error preparing LLM API",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
tools = [
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
]
history: ChatHistory | None = None
if user_input.conversation_id is None:
conversation_id = ulid.ulid_now()
elif user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
history = self.history.get(conversation_id)
else:
# Conversation IDs are ULIDs. We generate a new one if not provided.
# If an old OLID is passed in, we will generate a new one to indicate
# a new conversation was started. If the user picks their own, they
# want to track a conversation and we respect it.
try:
ulid.ulid_to_bytes(user_input.conversation_id)
conversation_id = ulid.ulid_now()
except ValueError:
conversation_id = user_input.conversation_id
if history is None:
history = ChatHistory(user_input.extra_system_prompt)
if (
user_input.context
and user_input.context.user_id
and (
user := await self.hass.auth.async_get_user(user_input.context.user_id)
)
):
user_name = user.name
try: try:
prompt_parts = [ await session.async_process_llm_message(
template.Template( DOMAIN,
llm.BASE_PROMPT user_input,
+ options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT), options.get(CONF_LLM_HASS_API),
self.hass, options.get(CONF_PROMPT),
).async_render( )
{ except conversation.ConverseError as err:
"ha_name": self.hass.config.location_name, return err.as_converstation_result()
"user_name": user_name,
"llm_context": llm_context, tools: list[ChatCompletionToolParam] | None = None
}, if session.llm_api:
parse_result=False, tools = [
) _format_tool(tool, session.llm_api.custom_serializer)
for tool in session.llm_api.tools
] ]
except TemplateError as err: messages: list[ChatCompletionMessageParam] = []
LOGGER.error("Error rendering prompt: %s", err) for message in session.async_get_messages(user_input.agent_id):
intent_response = intent.IntentResponse(language=user_input.language) if message.native is not None and message.agent_id == user_input.agent_id:
intent_response.async_set_error( messages.append(message.native)
intent.IntentResponseErrorCode.UNKNOWN, else:
"Sorry, I had a problem with my template", messages.append(
) cast(
return conversation.ConversationResult( ChatCompletionMessageParam,
response=intent_response, conversation_id=conversation_id {"role": message.role, "content": message.content},
) )
)
if llm_api: LOGGER.debug("Prompt: %s", messages)
prompt_parts.append(llm_api.api_prompt)
extra_system_prompt = (
# Take new system prompt if one was given
user_input.extra_system_prompt or history.extra_system_prompt
)
if extra_system_prompt:
prompt_parts.append(extra_system_prompt)
prompt = "\n".join(prompt_parts)
# Create a copy of the variable because we attach it to the trace
history = ChatHistory(
extra_system_prompt,
[
ChatCompletionSystemMessageParam(role="system", content=prompt),
*history.messages[1:],
ChatCompletionUserMessageParam(role="user", content=user_input.text),
],
)
LOGGER.debug("Prompt: %s", history.messages)
LOGGER.debug("Tools: %s", tools) LOGGER.debug("Tools: %s", tools)
trace.async_conversation_trace_append( trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, trace.ConversationTraceEventType.AGENT_DETAIL,
{"messages": history.messages, "tools": llm_api.tools if llm_api else None}, {
"messages": session.messages,
"tools": session.llm_api.tools if session.llm_api else None,
},
) )
client = self.entry.runtime_data client = self.entry.runtime_data
@ -262,12 +202,12 @@ class OpenAIConversationEntity(
try: try:
result = await client.chat.completions.create( result = await client.chat.completions.create(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=history.messages, messages=messages,
tools=tools or NOT_GIVEN, tools=tools or NOT_GIVEN,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P), top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
user=conversation_id, user=session.conversation_id,
) )
except openai.OpenAIError as err: except openai.OpenAIError as err:
LOGGER.error("Error talking to OpenAI: %s", err) LOGGER.error("Error talking to OpenAI: %s", err)
@ -277,44 +217,26 @@ class OpenAIConversationEntity(
"Sorry, I had a problem talking to OpenAI", "Sorry, I had a problem talking to OpenAI",
) )
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=session.conversation_id
) )
LOGGER.debug("Response %s", result) LOGGER.debug("Response %s", result)
response = result.choices[0].message response = result.choices[0].message
messages.append(_message_convert(response))
def message_convert( session.async_add_message(
message: ChatCompletionMessage, conversation.ChatMessage(
) -> ChatCompletionMessageParam: role=response.role,
"""Convert from class to TypedDict.""" agent_id=user_input.agent_id,
tool_calls: list[ChatCompletionMessageToolCallParam] = [] content=response.content or "",
if message.tool_calls: native=messages[-1],
tool_calls = [ ),
ChatCompletionMessageToolCallParam( )
id=tool_call.id,
function=Function(
arguments=tool_call.function.arguments,
name=tool_call.function.name,
),
type=tool_call.type,
)
for tool_call in message.tool_calls
]
param = ChatCompletionAssistantMessageParam(
role=message.role,
content=message.content,
)
if tool_calls:
param["tool_calls"] = tool_calls
return param
history.messages.append(message_convert(response)) if not response.tool_calls or not session.llm_api:
tool_calls = response.tool_calls
if not tool_calls or not llm_api:
break break
for tool_call in tool_calls: for tool_call in response.tool_calls:
tool_input = llm.ToolInput( tool_input = llm.ToolInput(
tool_name=tool_call.function.name, tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments), tool_args=json.loads(tool_call.function.arguments),
@ -324,27 +246,33 @@ class OpenAIConversationEntity(
) )
try: try:
tool_response = await llm_api.async_call_tool(tool_input) tool_response = await session.llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e: except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__} tool_response = {"error": type(e).__name__}
if str(e): if str(e):
tool_response["error_text"] = str(e) tool_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", tool_response) LOGGER.debug("Tool response: %s", tool_response)
history.messages.append( messages.append(
ChatCompletionToolMessageParam( ChatCompletionToolMessageParam(
role="tool", role="tool",
tool_call_id=tool_call.id, tool_call_id=tool_call.id,
content=json.dumps(tool_response), content=json.dumps(tool_response),
) )
) )
session.async_add_message(
self.history[conversation_id] = history conversation.ChatMessage(
role="native",
agent_id=user_input.agent_id,
content="",
native=messages[-1],
)
)
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content or "") intent_response.async_set_speech(response.content or "")
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=session.conversation_id
) )
async def _async_entry_update_listener( async def _async_entry_update_listener(

View File

@ -15,10 +15,6 @@ import voluptuous as vol
from voluptuous_openapi import UNSUPPORTED, convert from voluptuous_openapi import UNSUPPORTED, convert
from homeassistant.components.climate import INTENT_GET_TEMPERATURE from homeassistant.components.climate import INTENT_GET_TEMPERATURE
from homeassistant.components.conversation import (
ConversationTraceEventType,
async_conversation_trace_append,
)
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
from homeassistant.components.homeassistant import async_should_expose from homeassistant.components.homeassistant import async_should_expose
from homeassistant.components.intent import async_device_supports_timers from homeassistant.components.intent import async_device_supports_timers
@ -171,6 +167,12 @@ class APIInstance:
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response.""" """Call a LLM tool, validate args and return the response."""
# pylint: disable=import-outside-toplevel
from homeassistant.components.conversation import (
ConversationTraceEventType,
async_conversation_trace_append,
)
async_conversation_trace_append( async_conversation_trace_append(
ConversationTraceEventType.TOOL_CALL, ConversationTraceEventType.TOOL_CALL,
{"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args}, {"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},

View File

@ -44,7 +44,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -135,7 +135,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -226,7 +226,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -341,7 +341,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -446,7 +446,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -497,7 +497,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -548,7 +548,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),

View File

@ -42,7 +42,7 @@
# name: test_audio_pipeline.4 # name: test_audio_pipeline.4
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -125,7 +125,7 @@
# name: test_audio_pipeline_debug.4 # name: test_audio_pipeline_debug.4
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -220,7 +220,7 @@
# name: test_audio_pipeline_with_enhancements.4 # name: test_audio_pipeline_with_enhancements.4
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -325,7 +325,7 @@
# name: test_audio_pipeline_with_wake_word_no_timeout.6 # name: test_audio_pipeline_with_wake_word_no_timeout.6
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -585,7 +585,7 @@
# name: test_pipeline_empty_tts_output.2 # name: test_pipeline_empty_tts_output.2
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -698,7 +698,7 @@
# name: test_text_only_pipeline[extra_msg0].2 # name: test_text_only_pipeline[extra_msg0].2
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -744,7 +744,7 @@
# name: test_text_only_pipeline[extra_msg1].2 # name: test_text_only_pipeline[extra_msg1].2
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),

View File

@ -1,7 +1,7 @@
# serializer version: 1 # serializer version: 1
# name: test_custom_sentences # name: test_custom_sentences
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -26,7 +26,7 @@
# --- # ---
# name: test_custom_sentences.1 # name: test_custom_sentences.1
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -51,7 +51,7 @@
# --- # ---
# name: test_custom_sentences_config # name: test_custom_sentences_config
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -76,7 +76,7 @@
# --- # ---
# name: test_intent_alias_added_removed # name: test_intent_alias_added_removed
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -106,7 +106,7 @@
# --- # ---
# name: test_intent_alias_added_removed.1 # name: test_intent_alias_added_removed.1
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -136,7 +136,7 @@
# --- # ---
# name: test_intent_alias_added_removed.2 # name: test_intent_alias_added_removed.2
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -156,7 +156,7 @@
# --- # ---
# name: test_intent_conversion_not_expose_new # name: test_intent_conversion_not_expose_new
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -176,7 +176,7 @@
# --- # ---
# name: test_intent_conversion_not_expose_new.1 # name: test_intent_conversion_not_expose_new.1
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -206,7 +206,7 @@
# --- # ---
# name: test_intent_entity_added_removed # name: test_intent_entity_added_removed
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -236,7 +236,7 @@
# --- # ---
# name: test_intent_entity_added_removed.1 # name: test_intent_entity_added_removed.1
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -266,7 +266,7 @@
# --- # ---
# name: test_intent_entity_added_removed.2 # name: test_intent_entity_added_removed.2
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -296,7 +296,7 @@
# --- # ---
# name: test_intent_entity_added_removed.3 # name: test_intent_entity_added_removed.3
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -316,7 +316,7 @@
# --- # ---
# name: test_intent_entity_exposed # name: test_intent_entity_exposed
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -346,7 +346,7 @@
# --- # ---
# name: test_intent_entity_fail_if_unexposed # name: test_intent_entity_fail_if_unexposed
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -366,7 +366,7 @@
# --- # ---
# name: test_intent_entity_remove_custom_name # name: test_intent_entity_remove_custom_name
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -386,7 +386,7 @@
# --- # ---
# name: test_intent_entity_remove_custom_name.1 # name: test_intent_entity_remove_custom_name.1
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -416,7 +416,7 @@
# --- # ---
# name: test_intent_entity_remove_custom_name.2 # name: test_intent_entity_remove_custom_name.2
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -436,7 +436,7 @@
# --- # ---
# name: test_intent_entity_renamed # name: test_intent_entity_renamed
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -466,7 +466,7 @@
# --- # ---
# name: test_intent_entity_renamed.1 # name: test_intent_entity_renamed.1
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),

View File

@ -201,7 +201,7 @@
# --- # ---
# name: test_http_api_handle_failure # name: test_http_api_handle_failure
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -221,7 +221,7 @@
# --- # ---
# name: test_http_api_no_match # name: test_http_api_no_match
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -241,7 +241,7 @@
# --- # ---
# name: test_http_api_unexpected_failure # name: test_http_api_unexpected_failure
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -261,7 +261,7 @@
# --- # ---
# name: test_http_processing_intent[None] # name: test_http_processing_intent[None]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -291,7 +291,7 @@
# --- # ---
# name: test_http_processing_intent[conversation.home_assistant] # name: test_http_processing_intent[conversation.home_assistant]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -321,7 +321,7 @@
# --- # ---
# name: test_http_processing_intent[homeassistant] # name: test_http_processing_intent[homeassistant]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -351,7 +351,7 @@
# --- # ---
# name: test_ws_api[payload0] # name: test_ws_api[payload0]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -371,7 +371,7 @@
# --- # ---
# name: test_ws_api[payload1] # name: test_ws_api[payload1]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -391,7 +391,7 @@
# --- # ---
# name: test_ws_api[payload2] # name: test_ws_api[payload2]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -411,7 +411,7 @@
# --- # ---
# name: test_ws_api[payload3] # name: test_ws_api[payload3]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -431,7 +431,7 @@
# --- # ---
# name: test_ws_api[payload4] # name: test_ws_api[payload4]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -451,7 +451,7 @@
# --- # ---
# name: test_ws_api[payload5] # name: test_ws_api[payload5]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),

View File

@ -1,7 +1,7 @@
# serializer version: 1 # serializer version: 1
# name: test_custom_agent # name: test_custom_agent
dict({ dict({
'conversation_id': 'test-conv-id', 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -44,7 +44,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn kitchen on-None] # name: test_turn_on_intent[None-turn kitchen on-None]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -74,7 +74,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant] # name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -104,7 +104,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn kitchen on-homeassistant] # name: test_turn_on_intent[None-turn kitchen on-homeassistant]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -134,7 +134,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn on kitchen-None] # name: test_turn_on_intent[None-turn on kitchen-None]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -164,7 +164,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant] # name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -194,7 +194,7 @@
# --- # ---
# name: test_turn_on_intent[None-turn on kitchen-homeassistant] # name: test_turn_on_intent[None-turn on kitchen-homeassistant]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -224,7 +224,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-None] # name: test_turn_on_intent[my_new_conversation-turn kitchen on-None]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -254,7 +254,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant] # name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -284,7 +284,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant] # name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -314,7 +314,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-None] # name: test_turn_on_intent[my_new_conversation-turn on kitchen-None]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -344,7 +344,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant] # name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),
@ -374,7 +374,7 @@
# --- # ---
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant] # name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant]
dict({ dict({
'conversation_id': None, 'conversation_id': <ANY>,
'response': dict({ 'response': dict({
'card': dict({ 'card': dict({
}), }),

View File

@ -0,0 +1,171 @@
"""Test the conversation session."""
from collections.abc import Generator
from datetime import timedelta
from unittest.mock import Mock, patch
import pytest
from homeassistant.components.conversation import ConversationInput, session
from homeassistant.core import Context, HomeAssistant
from homeassistant.util import dt as dt_util
from tests.common import async_fire_time_changed
@pytest.fixture
def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
"""Return a conversation input instance."""
return ConversationInput(
text="Hello",
context=Context(),
conversation_id=None,
agent_id="mock-agent-id",
device_id=None,
language="en",
)
@pytest.fixture
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now
@pytest.mark.parametrize(
("start_id", "given_id"),
[
(None, "mock-ulid"),
# This ULID is not known as a session
("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"),
("not-a-ulid", "not-a-ulid"),
],
)
async def test_conversation_id(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,
mock_ulid: Mock,
start_id: str | None,
given_id: str,
) -> None:
"""Test conversation ID generation."""
mock_conversation_input.conversation_id = start_id
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
assert chat_session.conversation_id == given_id
async def test_cleanup(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,
) -> None:
"""Mock cleanup of the conversation session."""
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
assert len(chat_session.messages) == 2
conversation_id = chat_session.conversation_id
# Generate session entry.
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
# Because we didn't add a message to the session in the last block,
# the conversation was not be persisted and we get a new ID
assert chat_session.conversation_id != conversation_id
conversation_id = chat_session.conversation_id
chat_session.async_add_message(
session.ChatMessage(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
)
)
assert len(chat_session.messages) == 3
# Reuse conversation ID to ensure we can chat with same session
mock_conversation_input.conversation_id = conversation_id
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
assert len(chat_session.messages) == 4
assert chat_session.conversation_id == conversation_id
async_fire_time_changed(
hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT + timedelta(seconds=1)
)
# It should be cleaned up now and we start a new conversation
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
assert chat_session.conversation_id != conversation_id
assert len(chat_session.messages) == 2
async def test_message_filtering(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
async with session.async_get_chat_session(
hass, mock_conversation_input
) as chat_session:
messages = chat_session.async_get_messages(agent_id=None)
assert len(messages) == 2
assert messages[0] == session.ChatMessage(
role="system",
agent_id=None,
content="",
)
assert messages[1] == session.ChatMessage(
role="user",
agent_id=mock_conversation_input.agent_id,
content=mock_conversation_input.text,
)
# Cannot add a second user message in a row
with pytest.raises(ValueError):
chat_session.async_add_message(
session.ChatMessage(
role="user",
agent_id="mock-agent-id",
content="Hey!",
)
)
chat_session.async_add_message(
session.ChatMessage(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
native="assistant-reply-native",
)
)
# Different agent, will be filtered out.
chat_session.async_add_message(
session.ChatMessage(
role="native", agent_id="another-mock-agent-id", content="", native=1
)
)
chat_session.async_add_message(
session.ChatMessage(
role="native", agent_id="mock-agent-id", content="", native=1
)
)
assert len(chat_session.messages) == 5
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 4
assert messages[2] == session.ChatMessage(
role="assistant",
agent_id="mock-agent-id",
content="Hey!",
native="assistant-reply-native",
)
assert messages[3] == session.ChatMessage(
role="native", agent_id="mock-agent-id", content="", native=1
)

View File

@ -1,7 +1,7 @@
# serializer version: 1 # serializer version: 1
# name: test_unknown_hass_api # name: test_unknown_hass_api
dict({ dict({
'conversation_id': None, 'conversation_id': 'my-conversation-id',
'response': IntentResponse( 'response': IntentResponse(
card=dict({ card=dict({
}), }),

View File

@ -625,7 +625,11 @@ async def test_unknown_hass_api(
await hass.async_block_till_done() await hass.async_block_till_done()
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id hass,
"hello",
"my-conversation-id",
Context(),
agent_id=mock_config_entry.entry_id,
) )
assert result == snapshot assert result == snapshot

View File

@ -109,6 +109,8 @@ class HomeAssistantSnapshotSerializer(AmberDataSerializer):
serializable_data = cls._serializable_issue_registry_entry(data) serializable_data = cls._serializable_issue_registry_entry(data)
elif isinstance(data, dict) and "flow_id" in data and "handler" in data: elif isinstance(data, dict) and "flow_id" in data and "handler" in data:
serializable_data = cls._serializable_flow_result(data) serializable_data = cls._serializable_flow_result(data)
elif isinstance(data, dict) and set(data) == {"conversation_id", "response"}:
serializable_data = cls._serializable_conversation_result(data)
elif isinstance(data, vol.Schema): elif isinstance(data, vol.Schema):
serializable_data = voluptuous_serialize.convert(data) serializable_data = voluptuous_serialize.convert(data)
elif isinstance(data, ConfigEntry): elif isinstance(data, ConfigEntry):
@ -200,6 +202,11 @@ class HomeAssistantSnapshotSerializer(AmberDataSerializer):
"""Prepare a Home Assistant flow result for serialization.""" """Prepare a Home Assistant flow result for serialization."""
return FlowResultSnapshot(data | {"flow_id": ANY}) return FlowResultSnapshot(data | {"flow_id": ANY})
@classmethod
def _serializable_conversation_result(cls, data: dict) -> SerializableData:
"""Prepare a Home Assistant conversation result for serialization."""
return data | {"conversation_id": ANY}
@classmethod @classmethod
def _serializable_issue_registry_entry( def _serializable_issue_registry_entry(
cls, data: ir.IssueEntry cls, data: ir.IssueEntry