mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 17:57:55 +00:00
Add shared history for conversation agents (#135903)
* Add shared history for conversation agents * Remove unused code * Add support for native history items * Store all assistant responses as assistant in history * Add history support to DefaultAgent.async_handle_intents * Make local fallback work * Add default agent history * Add history cleanup * Add tests * ChatHistory -> ChatSession * Address comments * Update snapshots
This commit is contained in:
parent
32d7a23bff
commit
754de6f998
@ -1065,7 +1065,8 @@ class PipelineRun:
|
||||
)
|
||||
processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT
|
||||
|
||||
conversation_result: conversation.ConversationResult | None = None
|
||||
agent_id = user_input.agent_id
|
||||
intent_response: intent.IntentResponse | None = None
|
||||
if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT:
|
||||
# Sentence triggers override conversation agent
|
||||
if (
|
||||
@ -1075,14 +1076,12 @@ class PipelineRun:
|
||||
)
|
||||
) is not None:
|
||||
# Sentence trigger matched
|
||||
trigger_response = intent.IntentResponse(
|
||||
agent_id = "sentence_trigger"
|
||||
intent_response = intent.IntentResponse(
|
||||
self.pipeline.conversation_language
|
||||
)
|
||||
trigger_response.async_set_speech(trigger_response_text)
|
||||
conversation_result = conversation.ConversationResult(
|
||||
response=trigger_response,
|
||||
conversation_id=user_input.conversation_id,
|
||||
)
|
||||
intent_response.async_set_speech(trigger_response_text)
|
||||
|
||||
# Try local intents first, if preferred.
|
||||
elif self.pipeline.prefer_local_intents and (
|
||||
intent_response := await conversation.async_handle_intents(
|
||||
@ -1090,13 +1089,31 @@ class PipelineRun:
|
||||
)
|
||||
):
|
||||
# Local intent matched
|
||||
conversation_result = conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=user_input.conversation_id,
|
||||
)
|
||||
agent_id = conversation.HOME_ASSISTANT_AGENT
|
||||
processed_locally = True
|
||||
|
||||
if conversation_result is None:
|
||||
# It was already handled, create response and add to chat history
|
||||
if intent_response is not None:
|
||||
async with conversation.async_get_chat_session(
|
||||
self.hass, user_input
|
||||
) as chat_session:
|
||||
speech: str = intent_response.speech.get("plain", {}).get(
|
||||
"speech", ""
|
||||
)
|
||||
chat_session.async_add_message(
|
||||
conversation.ChatMessage(
|
||||
role="assistant",
|
||||
agent_id=agent_id,
|
||||
content=speech,
|
||||
native=intent_response,
|
||||
)
|
||||
)
|
||||
conversation_result = conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=chat_session.conversation_id,
|
||||
)
|
||||
|
||||
else:
|
||||
# Fall back to pipeline conversation agent
|
||||
conversation_result = await conversation.async_converse(
|
||||
hass=self.hass,
|
||||
@ -1107,6 +1124,10 @@ class PipelineRun:
|
||||
language=user_input.language,
|
||||
agent_id=user_input.agent_id,
|
||||
)
|
||||
speech = conversation_result.response.speech.get("plain", {}).get(
|
||||
"speech", ""
|
||||
)
|
||||
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during intent recognition")
|
||||
raise IntentRecognitionError(
|
||||
@ -1126,10 +1147,6 @@ class PipelineRun:
|
||||
)
|
||||
)
|
||||
|
||||
speech: str = conversation_result.response.speech.get("plain", {}).get(
|
||||
"speech", ""
|
||||
)
|
||||
|
||||
return speech
|
||||
|
||||
async def prepare_text_to_speech(self) -> None:
|
||||
|
@ -48,20 +48,25 @@ from .default_agent import DefaultAgent, async_setup_default_agent
|
||||
from .entity import ConversationEntity
|
||||
from .http import async_setup as async_setup_conversation_http
|
||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .session import ChatMessage, ChatSession, ConverseError, async_get_chat_session
|
||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"HOME_ASSISTANT_AGENT",
|
||||
"OLD_HOME_ASSISTANT_AGENT",
|
||||
"ChatMessage",
|
||||
"ChatSession",
|
||||
"ConversationEntity",
|
||||
"ConversationEntityFeature",
|
||||
"ConversationInput",
|
||||
"ConversationResult",
|
||||
"ConversationTraceEventType",
|
||||
"ConverseError",
|
||||
"async_conversation_trace_append",
|
||||
"async_converse",
|
||||
"async_get_agent_info",
|
||||
"async_get_chat_session",
|
||||
"async_set_agent",
|
||||
"async_setup",
|
||||
"async_unset_agent",
|
||||
|
@ -62,6 +62,7 @@ from .const import (
|
||||
)
|
||||
from .entity import ConversationEntity
|
||||
from .models import ConversationInput, ConversationResult
|
||||
from .session import ChatMessage, async_get_chat_session
|
||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -346,35 +347,52 @@ class DefaultAgent(ConversationEntity):
|
||||
|
||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||
"""Process a sentence."""
|
||||
response: intent.IntentResponse | None = None
|
||||
async with async_get_chat_session(self.hass, user_input) as chat_session:
|
||||
# Check if a trigger matched
|
||||
if trigger_result := await self.async_recognize_sentence_trigger(
|
||||
user_input
|
||||
):
|
||||
# Process callbacks and get response
|
||||
response_text = await self._handle_trigger_result(
|
||||
trigger_result, user_input
|
||||
)
|
||||
|
||||
# Check if a trigger matched
|
||||
if trigger_result := await self.async_recognize_sentence_trigger(user_input):
|
||||
# Process callbacks and get response
|
||||
response_text = await self._handle_trigger_result(
|
||||
trigger_result, user_input
|
||||
# Convert to conversation result
|
||||
response = intent.IntentResponse(
|
||||
language=user_input.language or self.hass.config.language
|
||||
)
|
||||
response.response_type = intent.IntentResponseType.ACTION_DONE
|
||||
response.async_set_speech(response_text)
|
||||
|
||||
if response is None:
|
||||
# Match intents
|
||||
intent_result = await self.async_recognize_intent(user_input)
|
||||
response = await self._async_process_intent_result(
|
||||
intent_result, user_input
|
||||
)
|
||||
|
||||
speech: str = response.speech.get("plain", {}).get("speech", "")
|
||||
chat_session.async_add_message(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
agent_id=user_input.agent_id,
|
||||
content=speech,
|
||||
native=response,
|
||||
)
|
||||
)
|
||||
|
||||
# Convert to conversation result
|
||||
response = intent.IntentResponse(
|
||||
language=user_input.language or self.hass.config.language
|
||||
return ConversationResult(
|
||||
response=response, conversation_id=chat_session.conversation_id
|
||||
)
|
||||
response.response_type = intent.IntentResponseType.ACTION_DONE
|
||||
response.async_set_speech(response_text)
|
||||
|
||||
return ConversationResult(response=response)
|
||||
|
||||
# Match intents
|
||||
intent_result = await self.async_recognize_intent(user_input)
|
||||
return await self._async_process_intent_result(intent_result, user_input)
|
||||
|
||||
async def _async_process_intent_result(
|
||||
self,
|
||||
result: RecognizeResult | None,
|
||||
user_input: ConversationInput,
|
||||
) -> ConversationResult:
|
||||
) -> intent.IntentResponse:
|
||||
"""Process user input with intents."""
|
||||
language = user_input.language or self.hass.config.language
|
||||
conversation_id = None # Not supported
|
||||
|
||||
# Intent match or failure
|
||||
lang_intents = await self.async_get_or_load_intents(language)
|
||||
@ -386,7 +404,6 @@ class DefaultAgent(ConversationEntity):
|
||||
language,
|
||||
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
|
||||
self._get_error_text(ErrorKey.NO_INTENT, lang_intents),
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
if result.unmatched_entities:
|
||||
@ -408,7 +425,6 @@ class DefaultAgent(ConversationEntity):
|
||||
self._get_error_text(
|
||||
error_response_type, lang_intents, **error_response_args
|
||||
),
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
# Will never happen because result will be None when no intents are
|
||||
@ -461,7 +477,6 @@ class DefaultAgent(ConversationEntity):
|
||||
self._get_error_text(
|
||||
error_response_type, lang_intents, **error_response_args
|
||||
),
|
||||
conversation_id,
|
||||
)
|
||||
except intent.IntentHandleError as err:
|
||||
# Intent was valid and entities matched constraints, but an error
|
||||
@ -473,7 +488,6 @@ class DefaultAgent(ConversationEntity):
|
||||
self._get_error_text(
|
||||
err.response_key or ErrorKey.HANDLE_ERROR, lang_intents
|
||||
),
|
||||
conversation_id,
|
||||
)
|
||||
except intent.IntentUnexpectedError:
|
||||
_LOGGER.exception("Unexpected intent error")
|
||||
@ -481,7 +495,6 @@ class DefaultAgent(ConversationEntity):
|
||||
language,
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
self._get_error_text(ErrorKey.HANDLE_ERROR, lang_intents),
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
if (
|
||||
@ -500,9 +513,7 @@ class DefaultAgent(ConversationEntity):
|
||||
)
|
||||
intent_response.async_set_speech(speech)
|
||||
|
||||
return ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
return intent_response
|
||||
|
||||
def _recognize(
|
||||
self,
|
||||
@ -1346,22 +1357,18 @@ class DefaultAgent(ConversationEntity):
|
||||
# No error message on failed match
|
||||
return None
|
||||
|
||||
conversation_result = await self._async_process_intent_result(
|
||||
result, user_input
|
||||
)
|
||||
return conversation_result.response
|
||||
return await self._async_process_intent_result(result, user_input)
|
||||
|
||||
|
||||
def _make_error_result(
|
||||
language: str,
|
||||
error_code: intent.IntentResponseErrorCode,
|
||||
response_text: str,
|
||||
conversation_id: str | None = None,
|
||||
) -> ConversationResult:
|
||||
) -> intent.IntentResponse:
|
||||
"""Create conversation result with error code and text."""
|
||||
response = intent.IntentResponse(language=language)
|
||||
response.async_set_error(error_code, response_text)
|
||||
return ConversationResult(response, conversation_id)
|
||||
return response
|
||||
|
||||
|
||||
def _get_unmatched_response(result: RecognizeResult) -> tuple[ErrorKey, dict[str, Any]]:
|
||||
|
327
homeassistant/components/conversation/session.py
Normal file
327
homeassistant/components/conversation/session.py
Normal file
@ -0,0 +1,327 @@
|
||||
"""Conversation history."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field, replace
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
Event,
|
||||
HassJob,
|
||||
HassJobType,
|
||||
HomeAssistant,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.helpers import intent, llm, template
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.util import dt as dt_util, ulid
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from .const import DOMAIN
|
||||
from .models import ConversationInput, ConversationResult
|
||||
|
||||
DATA_CHAT_HISTORY: HassKey["dict[str, ChatSession]"] = HassKey(
|
||||
"conversation_chat_session"
|
||||
)
|
||||
DATA_CHAT_HISTORY_CLEANUP: HassKey["SessionCleanup"] = HassKey(
|
||||
"conversation_chat_session_cleanup"
|
||||
)
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
CONVERSATION_TIMEOUT = timedelta(minutes=5)
|
||||
_NativeT = TypeVar("_NativeT")
|
||||
|
||||
|
||||
class SessionCleanup:
|
||||
"""Helper to clean up the history."""
|
||||
|
||||
unsub: CALLBACK_TYPE | None = None
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the history cleanup."""
|
||||
self.hass = hass
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
|
||||
self.cleanup_job = HassJob(
|
||||
self._cleanup, "conversation_history_cleanup", job_type=HassJobType.Callback
|
||||
)
|
||||
|
||||
@callback
|
||||
def schedule(self) -> None:
|
||||
"""Schedule the cleanup."""
|
||||
if self.unsub:
|
||||
return
|
||||
self.unsub = async_call_later(
|
||||
self.hass,
|
||||
CONVERSATION_TIMEOUT.total_seconds() + 1,
|
||||
self.cleanup_job,
|
||||
)
|
||||
|
||||
@callback
|
||||
def _on_hass_stop(self, event: Event) -> None:
|
||||
"""Cancel the cleanup on shutdown."""
|
||||
if self.unsub:
|
||||
self.unsub()
|
||||
self.unsub = None
|
||||
|
||||
@callback
|
||||
def _cleanup(self, now: datetime) -> None:
|
||||
"""Clean up the history and schedule follow-up if necessary."""
|
||||
self.unsub = None
|
||||
all_history = self.hass.data[DATA_CHAT_HISTORY]
|
||||
|
||||
# We mutate original object because current commands could be
|
||||
# yielding history based on it.
|
||||
for conversation_id, history in list(all_history.items()):
|
||||
if history.last_updated + CONVERSATION_TIMEOUT < now:
|
||||
del all_history[conversation_id]
|
||||
|
||||
# Still conversations left, check again in timeout time.
|
||||
if all_history:
|
||||
self.schedule()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_get_chat_session(
|
||||
hass: HomeAssistant,
|
||||
user_input: ConversationInput,
|
||||
) -> AsyncGenerator["ChatSession"]:
|
||||
"""Return chat session."""
|
||||
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
||||
if all_history is None:
|
||||
all_history = {}
|
||||
hass.data[DATA_CHAT_HISTORY] = all_history
|
||||
hass.data[DATA_CHAT_HISTORY_CLEANUP] = SessionCleanup(hass)
|
||||
|
||||
history: ChatSession | None = None
|
||||
|
||||
if user_input.conversation_id is None:
|
||||
conversation_id = ulid.ulid_now()
|
||||
|
||||
elif history := all_history.get(user_input.conversation_id):
|
||||
conversation_id = user_input.conversation_id
|
||||
|
||||
else:
|
||||
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
||||
# If an old OLID is passed in, we will generate a new one to indicate
|
||||
# a new conversation was started. If the user picks their own, they
|
||||
# want to track a conversation and we respect it.
|
||||
try:
|
||||
ulid.ulid_to_bytes(user_input.conversation_id)
|
||||
conversation_id = ulid.ulid_now()
|
||||
except ValueError:
|
||||
conversation_id = user_input.conversation_id
|
||||
|
||||
if history:
|
||||
history = replace(history, messages=history.messages.copy())
|
||||
else:
|
||||
history = ChatSession(hass, conversation_id)
|
||||
|
||||
message: ChatMessage = ChatMessage(
|
||||
role="user",
|
||||
agent_id=user_input.agent_id,
|
||||
content=user_input.text,
|
||||
)
|
||||
history.async_add_message(message)
|
||||
|
||||
yield history
|
||||
|
||||
if history.messages[-1] is message:
|
||||
LOGGER.debug(
|
||||
"History opened but no assistant message was added, ignoring update"
|
||||
)
|
||||
return
|
||||
|
||||
history.last_updated = dt_util.utcnow()
|
||||
all_history[conversation_id] = history
|
||||
hass.data[DATA_CHAT_HISTORY_CLEANUP].schedule()
|
||||
|
||||
|
||||
class ConverseError(HomeAssistantError):
|
||||
"""Error during initialization of conversation.
|
||||
|
||||
Will not be stored in the history.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, message: str, conversation_id: str, response: intent.IntentResponse
|
||||
) -> None:
|
||||
"""Initialize the error."""
|
||||
super().__init__(message)
|
||||
self.conversation_id = conversation_id
|
||||
self.response = response
|
||||
|
||||
def as_converstation_result(self) -> ConversationResult:
|
||||
"""Return the error as a conversation result."""
|
||||
return ConversationResult(
|
||||
response=self.response,
|
||||
conversation_id=self.conversation_id,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage(Generic[_NativeT]):
|
||||
"""Base class for chat messages.
|
||||
|
||||
When role is native, the content is to be ignored and message
|
||||
is only meant for storing the native object.
|
||||
"""
|
||||
|
||||
role: Literal["system", "assistant", "user", "native"]
|
||||
agent_id: str | None
|
||||
content: str
|
||||
native: _NativeT | None = field(default=None)
|
||||
|
||||
# Validate in post-init that if role is native, there is no content and a native object exists
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate native message."""
|
||||
if self.role == "native" and self.native is None:
|
||||
raise ValueError("Native message must have a native object")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSession(Generic[_NativeT]):
|
||||
"""Class holding all information for a specific conversation."""
|
||||
|
||||
hass: HomeAssistant
|
||||
conversation_id: str
|
||||
user_name: str | None = None
|
||||
messages: list[ChatMessage[_NativeT]] = field(
|
||||
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")]
|
||||
)
|
||||
extra_system_prompt: str | None = None
|
||||
llm_api: llm.APIInstance | None = None
|
||||
last_updated: datetime = field(default_factory=dt_util.utcnow)
|
||||
|
||||
@callback
|
||||
def async_add_message(self, message: ChatMessage[_NativeT]) -> None:
|
||||
"""Process intent."""
|
||||
if message.role == "system":
|
||||
raise ValueError("Cannot add system messages to history")
|
||||
if message.role != "native" and self.messages[-1].role == message.role:
|
||||
raise ValueError("Cannot add two assistant or user messages in a row")
|
||||
|
||||
self.messages.append(message)
|
||||
|
||||
@callback
|
||||
def async_get_messages(self, agent_id: str | None) -> list[ChatMessage[_NativeT]]:
|
||||
"""Get messages for a specific agent ID.
|
||||
|
||||
This will filter out any native message tied to other agent IDs.
|
||||
It can still include assistant/user messages generated by other agents.
|
||||
"""
|
||||
return [
|
||||
message
|
||||
for message in self.messages
|
||||
if message.role != "native" or message.agent_id == agent_id
|
||||
]
|
||||
|
||||
async def async_process_llm_message(
|
||||
self,
|
||||
conversing_domain: str,
|
||||
user_input: ConversationInput,
|
||||
user_llm_hass_api: str | None = None,
|
||||
user_llm_prompt: str | None = None,
|
||||
) -> None:
|
||||
"""Process an incoming message for an LLM."""
|
||||
llm_context = llm.LLMContext(
|
||||
platform=conversing_domain,
|
||||
context=user_input.context,
|
||||
user_prompt=user_input.text,
|
||||
language=user_input.language,
|
||||
assistant=DOMAIN,
|
||||
device_id=user_input.device_id,
|
||||
)
|
||||
|
||||
llm_api: llm.APIInstance | None = None
|
||||
|
||||
if user_llm_hass_api:
|
||||
try:
|
||||
llm_api = await llm.async_get_api(
|
||||
self.hass,
|
||||
user_llm_hass_api,
|
||||
llm_context,
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
LOGGER.error(
|
||||
"Error getting LLM API %s for %s: %s",
|
||||
user_llm_hass_api,
|
||||
conversing_domain,
|
||||
err,
|
||||
)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
"Error preparing LLM API",
|
||||
)
|
||||
raise ConverseError(
|
||||
f"Error getting LLM API {user_llm_hass_api}",
|
||||
conversation_id=self.conversation_id,
|
||||
response=intent_response,
|
||||
) from err
|
||||
|
||||
user_name: str | None = None
|
||||
|
||||
if (
|
||||
user_input.context
|
||||
and user_input.context.user_id
|
||||
and (
|
||||
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
||||
)
|
||||
):
|
||||
user_name = user.name
|
||||
|
||||
try:
|
||||
prompt_parts = [
|
||||
template.Template(
|
||||
llm.BASE_PROMPT
|
||||
+ (user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
||||
self.hass,
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"user_name": user_name,
|
||||
"llm_context": llm_context,
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
]
|
||||
|
||||
except TemplateError as err:
|
||||
LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
"Sorry, I had a problem with my template",
|
||||
)
|
||||
raise ConverseError(
|
||||
"Error rendering prompt",
|
||||
conversation_id=self.conversation_id,
|
||||
response=intent_response,
|
||||
) from err
|
||||
|
||||
if llm_api:
|
||||
prompt_parts.append(llm_api.api_prompt)
|
||||
|
||||
extra_system_prompt = (
|
||||
# Take new system prompt if one was given
|
||||
user_input.extra_system_prompt or self.extra_system_prompt
|
||||
)
|
||||
|
||||
if extra_system_prompt:
|
||||
prompt_parts.append(extra_system_prompt)
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
||||
self.llm_api = llm_api
|
||||
self.user_name = user_name
|
||||
self.extra_system_prompt = extra_system_prompt
|
||||
self.messages[0] = ChatMessage(
|
||||
role="system",
|
||||
agent_id=user_input.agent_id,
|
||||
content=prompt,
|
||||
)
|
@ -1,9 +1,8 @@
|
||||
"""Conversation support for OpenAI."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import openai
|
||||
from openai._types import NOT_GIVEN
|
||||
@ -12,10 +11,8 @@ from openai.types.chat import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||
from openai.types.shared_params import FunctionDefinition
|
||||
@ -27,10 +24,9 @@ from homeassistant.components.conversation import trace
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.helpers import device_registry as dr, intent, llm, template
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr, intent, llm
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from . import OpenAIConfigEntry
|
||||
from .const import (
|
||||
@ -74,12 +70,28 @@ def _format_tool(
|
||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatHistory:
|
||||
"""Class holding the chat history."""
|
||||
|
||||
extra_system_prompt: str | None = None
|
||||
messages: list[ChatCompletionMessageParam] = field(default_factory=list)
|
||||
def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||
"""Convert from class to TypedDict."""
|
||||
tool_calls: list[ChatCompletionMessageToolCallParam] = []
|
||||
if message.tool_calls:
|
||||
tool_calls = [
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=tool_call.id,
|
||||
function=Function(
|
||||
arguments=tool_call.function.arguments,
|
||||
name=tool_call.function.name,
|
||||
),
|
||||
type=tool_call.type,
|
||||
)
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
param = ChatCompletionAssistantMessageParam(
|
||||
role=message.role,
|
||||
content=message.content,
|
||||
)
|
||||
if tool_calls:
|
||||
param["tool_calls"] = tool_calls
|
||||
return param
|
||||
|
||||
|
||||
class OpenAIConversationEntity(
|
||||
@ -93,7 +105,6 @@ class OpenAIConversationEntity(
|
||||
def __init__(self, entry: OpenAIConfigEntry) -> None:
|
||||
"""Initialize the agent."""
|
||||
self.entry = entry
|
||||
self.history: dict[str, ChatHistory] = {}
|
||||
self._attr_unique_id = entry.entry_id
|
||||
self._attr_device_info = dr.DeviceInfo(
|
||||
identifiers={(DOMAIN, entry.entry_id)},
|
||||
@ -132,127 +143,56 @@ class OpenAIConversationEntity(
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
async with conversation.async_get_chat_session(
|
||||
self.hass, user_input
|
||||
) as session:
|
||||
return await self._async_call_api(user_input, session)
|
||||
|
||||
async def _async_call_api(
|
||||
self,
|
||||
user_input: conversation.ConversationInput,
|
||||
session: conversation.ChatSession[ChatCompletionMessageParam],
|
||||
) -> conversation.ConversationResult:
|
||||
"""Call the API."""
|
||||
options = self.entry.options
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
llm_api: llm.APIInstance | None = None
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
user_name: str | None = None
|
||||
llm_context = llm.LLMContext(
|
||||
platform=DOMAIN,
|
||||
context=user_input.context,
|
||||
user_prompt=user_input.text,
|
||||
language=user_input.language,
|
||||
assistant=conversation.DOMAIN,
|
||||
device_id=user_input.device_id,
|
||||
)
|
||||
|
||||
if options.get(CONF_LLM_HASS_API):
|
||||
try:
|
||||
llm_api = await llm.async_get_api(
|
||||
self.hass,
|
||||
options[CONF_LLM_HASS_API],
|
||||
llm_context,
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
LOGGER.error("Error getting LLM API: %s", err)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
"Error preparing LLM API",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
tools = [
|
||||
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
||||
]
|
||||
|
||||
history: ChatHistory | None = None
|
||||
|
||||
if user_input.conversation_id is None:
|
||||
conversation_id = ulid.ulid_now()
|
||||
|
||||
elif user_input.conversation_id in self.history:
|
||||
conversation_id = user_input.conversation_id
|
||||
history = self.history.get(conversation_id)
|
||||
|
||||
else:
|
||||
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
||||
# If an old OLID is passed in, we will generate a new one to indicate
|
||||
# a new conversation was started. If the user picks their own, they
|
||||
# want to track a conversation and we respect it.
|
||||
try:
|
||||
ulid.ulid_to_bytes(user_input.conversation_id)
|
||||
conversation_id = ulid.ulid_now()
|
||||
except ValueError:
|
||||
conversation_id = user_input.conversation_id
|
||||
|
||||
if history is None:
|
||||
history = ChatHistory(user_input.extra_system_prompt)
|
||||
|
||||
if (
|
||||
user_input.context
|
||||
and user_input.context.user_id
|
||||
and (
|
||||
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
||||
)
|
||||
):
|
||||
user_name = user.name
|
||||
|
||||
try:
|
||||
prompt_parts = [
|
||||
template.Template(
|
||||
llm.BASE_PROMPT
|
||||
+ options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
||||
self.hass,
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"user_name": user_name,
|
||||
"llm_context": llm_context,
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
await session.async_process_llm_message(
|
||||
DOMAIN,
|
||||
user_input,
|
||||
options.get(CONF_LLM_HASS_API),
|
||||
options.get(CONF_PROMPT),
|
||||
)
|
||||
except conversation.ConverseError as err:
|
||||
return err.as_converstation_result()
|
||||
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
if session.llm_api:
|
||||
tools = [
|
||||
_format_tool(tool, session.llm_api.custom_serializer)
|
||||
for tool in session.llm_api.tools
|
||||
]
|
||||
|
||||
except TemplateError as err:
|
||||
LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
"Sorry, I had a problem with my template",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
messages: list[ChatCompletionMessageParam] = []
|
||||
for message in session.async_get_messages(user_input.agent_id):
|
||||
if message.native is not None and message.agent_id == user_input.agent_id:
|
||||
messages.append(message.native)
|
||||
else:
|
||||
messages.append(
|
||||
cast(
|
||||
ChatCompletionMessageParam,
|
||||
{"role": message.role, "content": message.content},
|
||||
)
|
||||
)
|
||||
|
||||
if llm_api:
|
||||
prompt_parts.append(llm_api.api_prompt)
|
||||
|
||||
extra_system_prompt = (
|
||||
# Take new system prompt if one was given
|
||||
user_input.extra_system_prompt or history.extra_system_prompt
|
||||
)
|
||||
|
||||
if extra_system_prompt:
|
||||
prompt_parts.append(extra_system_prompt)
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
||||
# Create a copy of the variable because we attach it to the trace
|
||||
history = ChatHistory(
|
||||
extra_system_prompt,
|
||||
[
|
||||
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
||||
*history.messages[1:],
|
||||
ChatCompletionUserMessageParam(role="user", content=user_input.text),
|
||||
],
|
||||
)
|
||||
|
||||
LOGGER.debug("Prompt: %s", history.messages)
|
||||
LOGGER.debug("Prompt: %s", messages)
|
||||
LOGGER.debug("Tools: %s", tools)
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
{"messages": history.messages, "tools": llm_api.tools if llm_api else None},
|
||||
{
|
||||
"messages": session.messages,
|
||||
"tools": session.llm_api.tools if session.llm_api else None,
|
||||
},
|
||||
)
|
||||
|
||||
client = self.entry.runtime_data
|
||||
@ -262,12 +202,12 @@ class OpenAIConversationEntity(
|
||||
try:
|
||||
result = await client.chat.completions.create(
|
||||
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||
messages=history.messages,
|
||||
messages=messages,
|
||||
tools=tools or NOT_GIVEN,
|
||||
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||
user=conversation_id,
|
||||
user=session.conversation_id,
|
||||
)
|
||||
except openai.OpenAIError as err:
|
||||
LOGGER.error("Error talking to OpenAI: %s", err)
|
||||
@ -277,44 +217,26 @@ class OpenAIConversationEntity(
|
||||
"Sorry, I had a problem talking to OpenAI",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
response=intent_response, conversation_id=session.conversation_id
|
||||
)
|
||||
|
||||
LOGGER.debug("Response %s", result)
|
||||
response = result.choices[0].message
|
||||
messages.append(_message_convert(response))
|
||||
|
||||
def message_convert(
|
||||
message: ChatCompletionMessage,
|
||||
) -> ChatCompletionMessageParam:
|
||||
"""Convert from class to TypedDict."""
|
||||
tool_calls: list[ChatCompletionMessageToolCallParam] = []
|
||||
if message.tool_calls:
|
||||
tool_calls = [
|
||||
ChatCompletionMessageToolCallParam(
|
||||
id=tool_call.id,
|
||||
function=Function(
|
||||
arguments=tool_call.function.arguments,
|
||||
name=tool_call.function.name,
|
||||
),
|
||||
type=tool_call.type,
|
||||
)
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
param = ChatCompletionAssistantMessageParam(
|
||||
role=message.role,
|
||||
content=message.content,
|
||||
)
|
||||
if tool_calls:
|
||||
param["tool_calls"] = tool_calls
|
||||
return param
|
||||
session.async_add_message(
|
||||
conversation.ChatMessage(
|
||||
role=response.role,
|
||||
agent_id=user_input.agent_id,
|
||||
content=response.content or "",
|
||||
native=messages[-1],
|
||||
),
|
||||
)
|
||||
|
||||
history.messages.append(message_convert(response))
|
||||
tool_calls = response.tool_calls
|
||||
|
||||
if not tool_calls or not llm_api:
|
||||
if not response.tool_calls or not session.llm_api:
|
||||
break
|
||||
|
||||
for tool_call in tool_calls:
|
||||
for tool_call in response.tool_calls:
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=tool_call.function.name,
|
||||
tool_args=json.loads(tool_call.function.arguments),
|
||||
@ -324,27 +246,33 @@ class OpenAIConversationEntity(
|
||||
)
|
||||
|
||||
try:
|
||||
tool_response = await llm_api.async_call_tool(tool_input)
|
||||
tool_response = await session.llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
tool_response = {"error": type(e).__name__}
|
||||
if str(e):
|
||||
tool_response["error_text"] = str(e)
|
||||
|
||||
LOGGER.debug("Tool response: %s", tool_response)
|
||||
history.messages.append(
|
||||
messages.append(
|
||||
ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
tool_call_id=tool_call.id,
|
||||
content=json.dumps(tool_response),
|
||||
)
|
||||
)
|
||||
|
||||
self.history[conversation_id] = history
|
||||
session.async_add_message(
|
||||
conversation.ChatMessage(
|
||||
role="native",
|
||||
agent_id=user_input.agent_id,
|
||||
content="",
|
||||
native=messages[-1],
|
||||
)
|
||||
)
|
||||
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_speech(response.content or "")
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
response=intent_response, conversation_id=session.conversation_id
|
||||
)
|
||||
|
||||
async def _async_entry_update_listener(
|
||||
|
@ -15,10 +15,6 @@ import voluptuous as vol
|
||||
from voluptuous_openapi import UNSUPPORTED, convert
|
||||
|
||||
from homeassistant.components.climate import INTENT_GET_TEMPERATURE
|
||||
from homeassistant.components.conversation import (
|
||||
ConversationTraceEventType,
|
||||
async_conversation_trace_append,
|
||||
)
|
||||
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
||||
from homeassistant.components.homeassistant import async_should_expose
|
||||
from homeassistant.components.intent import async_device_supports_timers
|
||||
@ -171,6 +167,12 @@ class APIInstance:
|
||||
|
||||
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
||||
"""Call a LLM tool, validate args and return the response."""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from homeassistant.components.conversation import (
|
||||
ConversationTraceEventType,
|
||||
async_conversation_trace_append,
|
||||
)
|
||||
|
||||
async_conversation_trace_append(
|
||||
ConversationTraceEventType.TOOL_CALL,
|
||||
{"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
|
||||
|
@ -44,7 +44,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -135,7 +135,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -226,7 +226,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -341,7 +341,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -446,7 +446,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -497,7 +497,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -548,7 +548,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
|
@ -42,7 +42,7 @@
|
||||
# name: test_audio_pipeline.4
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -125,7 +125,7 @@
|
||||
# name: test_audio_pipeline_debug.4
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -220,7 +220,7 @@
|
||||
# name: test_audio_pipeline_with_enhancements.4
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -325,7 +325,7 @@
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.6
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -585,7 +585,7 @@
|
||||
# name: test_pipeline_empty_tts_output.2
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -698,7 +698,7 @@
|
||||
# name: test_text_only_pipeline[extra_msg0].2
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -744,7 +744,7 @@
|
||||
# name: test_text_only_pipeline[extra_msg1].2
|
||||
dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
|
@ -1,7 +1,7 @@
|
||||
# serializer version: 1
|
||||
# name: test_custom_sentences
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -26,7 +26,7 @@
|
||||
# ---
|
||||
# name: test_custom_sentences.1
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -51,7 +51,7 @@
|
||||
# ---
|
||||
# name: test_custom_sentences_config
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -76,7 +76,7 @@
|
||||
# ---
|
||||
# name: test_intent_alias_added_removed
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -106,7 +106,7 @@
|
||||
# ---
|
||||
# name: test_intent_alias_added_removed.1
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -136,7 +136,7 @@
|
||||
# ---
|
||||
# name: test_intent_alias_added_removed.2
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -156,7 +156,7 @@
|
||||
# ---
|
||||
# name: test_intent_conversion_not_expose_new
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -176,7 +176,7 @@
|
||||
# ---
|
||||
# name: test_intent_conversion_not_expose_new.1
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -206,7 +206,7 @@
|
||||
# ---
|
||||
# name: test_intent_entity_added_removed
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -236,7 +236,7 @@
|
||||
# ---
|
||||
# name: test_intent_entity_added_removed.1
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -266,7 +266,7 @@
|
||||
# ---
|
||||
# name: test_intent_entity_added_removed.2
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -296,7 +296,7 @@
|
||||
# ---
|
||||
# name: test_intent_entity_added_removed.3
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -316,7 +316,7 @@
|
||||
# ---
|
||||
# name: test_intent_entity_exposed
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -346,7 +346,7 @@
|
||||
# ---
|
||||
# name: test_intent_entity_fail_if_unexposed
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -366,7 +366,7 @@
|
||||
# ---
|
||||
# name: test_intent_entity_remove_custom_name
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -386,7 +386,7 @@
|
||||
# ---
|
||||
# name: test_intent_entity_remove_custom_name.1
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -416,7 +416,7 @@
|
||||
# ---
|
||||
# name: test_intent_entity_remove_custom_name.2
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -436,7 +436,7 @@
|
||||
# ---
|
||||
# name: test_intent_entity_renamed
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -466,7 +466,7 @@
|
||||
# ---
|
||||
# name: test_intent_entity_renamed.1
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
|
@ -201,7 +201,7 @@
|
||||
# ---
|
||||
# name: test_http_api_handle_failure
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -221,7 +221,7 @@
|
||||
# ---
|
||||
# name: test_http_api_no_match
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -241,7 +241,7 @@
|
||||
# ---
|
||||
# name: test_http_api_unexpected_failure
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -261,7 +261,7 @@
|
||||
# ---
|
||||
# name: test_http_processing_intent[None]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -291,7 +291,7 @@
|
||||
# ---
|
||||
# name: test_http_processing_intent[conversation.home_assistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -321,7 +321,7 @@
|
||||
# ---
|
||||
# name: test_http_processing_intent[homeassistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -351,7 +351,7 @@
|
||||
# ---
|
||||
# name: test_ws_api[payload0]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -371,7 +371,7 @@
|
||||
# ---
|
||||
# name: test_ws_api[payload1]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -391,7 +391,7 @@
|
||||
# ---
|
||||
# name: test_ws_api[payload2]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -411,7 +411,7 @@
|
||||
# ---
|
||||
# name: test_ws_api[payload3]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -431,7 +431,7 @@
|
||||
# ---
|
||||
# name: test_ws_api[payload4]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -451,7 +451,7 @@
|
||||
# ---
|
||||
# name: test_ws_api[payload5]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
|
@ -1,7 +1,7 @@
|
||||
# serializer version: 1
|
||||
# name: test_custom_agent
|
||||
dict({
|
||||
'conversation_id': 'test-conv-id',
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -44,7 +44,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[None-turn kitchen on-None]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -74,7 +74,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -104,7 +104,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[None-turn kitchen on-homeassistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -134,7 +134,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[None-turn on kitchen-None]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -164,7 +164,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -194,7 +194,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[None-turn on kitchen-homeassistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -224,7 +224,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-None]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -254,7 +254,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -284,7 +284,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -314,7 +314,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-None]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -344,7 +344,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
@ -374,7 +374,7 @@
|
||||
# ---
|
||||
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant]
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
|
171
tests/components/conversation/test_session.py
Normal file
171
tests/components/conversation/test_session.py
Normal file
@ -0,0 +1,171 @@
|
||||
"""Test the conversation session."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.conversation import ConversationInput, session
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
|
||||
"""Return a conversation input instance."""
|
||||
return ConversationInput(
|
||||
text="Hello",
|
||||
context=Context(),
|
||||
conversation_id=None,
|
||||
agent_id="mock-agent-id",
|
||||
device_id=None,
|
||||
language="en",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ulid() -> Generator[Mock]:
|
||||
"""Mock the ulid library."""
|
||||
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
|
||||
mock_ulid_now.return_value = "mock-ulid"
|
||||
yield mock_ulid_now
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("start_id", "given_id"),
|
||||
[
|
||||
(None, "mock-ulid"),
|
||||
# This ULID is not known as a session
|
||||
("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"),
|
||||
("not-a-ulid", "not-a-ulid"),
|
||||
],
|
||||
)
|
||||
async def test_conversation_id(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
mock_ulid: Mock,
|
||||
start_id: str | None,
|
||||
given_id: str,
|
||||
) -> None:
|
||||
"""Test conversation ID generation."""
|
||||
mock_conversation_input.conversation_id = start_id
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
assert chat_session.conversation_id == given_id
|
||||
|
||||
|
||||
async def test_cleanup(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
) -> None:
|
||||
"""Mock cleanup of the conversation session."""
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
assert len(chat_session.messages) == 2
|
||||
conversation_id = chat_session.conversation_id
|
||||
|
||||
# Generate session entry.
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
# Because we didn't add a message to the session in the last block,
|
||||
# the conversation was not be persisted and we get a new ID
|
||||
assert chat_session.conversation_id != conversation_id
|
||||
conversation_id = chat_session.conversation_id
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
)
|
||||
assert len(chat_session.messages) == 3
|
||||
|
||||
# Reuse conversation ID to ensure we can chat with same session
|
||||
mock_conversation_input.conversation_id = conversation_id
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
assert len(chat_session.messages) == 4
|
||||
assert chat_session.conversation_id == conversation_id
|
||||
|
||||
async_fire_time_changed(
|
||||
hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT + timedelta(seconds=1)
|
||||
)
|
||||
|
||||
# It should be cleaned up now and we start a new conversation
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
assert chat_session.conversation_id != conversation_id
|
||||
assert len(chat_session.messages) == 2
|
||||
|
||||
|
||||
async def test_message_filtering(
|
||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||
) -> None:
|
||||
"""Test filtering of messages."""
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
messages = chat_session.async_get_messages(agent_id=None)
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == session.ChatMessage(
|
||||
role="system",
|
||||
agent_id=None,
|
||||
content="",
|
||||
)
|
||||
assert messages[1] == session.ChatMessage(
|
||||
role="user",
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
content=mock_conversation_input.text,
|
||||
)
|
||||
# Cannot add a second user message in a row
|
||||
with pytest.raises(ValueError):
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
role="user",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
)
|
||||
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
native="assistant-reply-native",
|
||||
)
|
||||
)
|
||||
# Different agent, will be filtered out.
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
role="native", agent_id="another-mock-agent-id", content="", native=1
|
||||
)
|
||||
)
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
role="native", agent_id="mock-agent-id", content="", native=1
|
||||
)
|
||||
)
|
||||
|
||||
assert len(chat_session.messages) == 5
|
||||
|
||||
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
|
||||
assert len(messages) == 4
|
||||
|
||||
assert messages[2] == session.ChatMessage(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
native="assistant-reply-native",
|
||||
)
|
||||
assert messages[3] == session.ChatMessage(
|
||||
role="native", agent_id="mock-agent-id", content="", native=1
|
||||
)
|
@ -1,7 +1,7 @@
|
||||
# serializer version: 1
|
||||
# name: test_unknown_hass_api
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'my-conversation-id',
|
||||
'response': IntentResponse(
|
||||
card=dict({
|
||||
}),
|
||||
|
@ -625,7 +625,11 @@ async def test_unknown_hass_api(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
hass,
|
||||
"hello",
|
||||
"my-conversation-id",
|
||||
Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
|
||||
assert result == snapshot
|
||||
|
@ -109,6 +109,8 @@ class HomeAssistantSnapshotSerializer(AmberDataSerializer):
|
||||
serializable_data = cls._serializable_issue_registry_entry(data)
|
||||
elif isinstance(data, dict) and "flow_id" in data and "handler" in data:
|
||||
serializable_data = cls._serializable_flow_result(data)
|
||||
elif isinstance(data, dict) and set(data) == {"conversation_id", "response"}:
|
||||
serializable_data = cls._serializable_conversation_result(data)
|
||||
elif isinstance(data, vol.Schema):
|
||||
serializable_data = voluptuous_serialize.convert(data)
|
||||
elif isinstance(data, ConfigEntry):
|
||||
@ -200,6 +202,11 @@ class HomeAssistantSnapshotSerializer(AmberDataSerializer):
|
||||
"""Prepare a Home Assistant flow result for serialization."""
|
||||
return FlowResultSnapshot(data | {"flow_id": ANY})
|
||||
|
||||
@classmethod
|
||||
def _serializable_conversation_result(cls, data: dict) -> SerializableData:
|
||||
"""Prepare a Home Assistant conversation result for serialization."""
|
||||
return data | {"conversation_id": ANY}
|
||||
|
||||
@classmethod
|
||||
def _serializable_issue_registry_entry(
|
||||
cls, data: ir.IssueEntry
|
||||
|
Loading…
x
Reference in New Issue
Block a user