mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Add shared history for conversation agents (#135903)
* Add shared history for conversation agents * Remove unused code * Add support for native history items * Store all assistant responses as assistant in history * Add history support to DefaultAgent.async_handle_intents * Make local fallback work * Add default agent history * Add history cleanup * Add tests * ChatHistory -> ChatSession * Address comments * Update snapshots
This commit is contained in:
parent
32d7a23bff
commit
754de6f998
@ -1065,7 +1065,8 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT
|
processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT
|
||||||
|
|
||||||
conversation_result: conversation.ConversationResult | None = None
|
agent_id = user_input.agent_id
|
||||||
|
intent_response: intent.IntentResponse | None = None
|
||||||
if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT:
|
if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT:
|
||||||
# Sentence triggers override conversation agent
|
# Sentence triggers override conversation agent
|
||||||
if (
|
if (
|
||||||
@ -1075,14 +1076,12 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
) is not None:
|
) is not None:
|
||||||
# Sentence trigger matched
|
# Sentence trigger matched
|
||||||
trigger_response = intent.IntentResponse(
|
agent_id = "sentence_trigger"
|
||||||
|
intent_response = intent.IntentResponse(
|
||||||
self.pipeline.conversation_language
|
self.pipeline.conversation_language
|
||||||
)
|
)
|
||||||
trigger_response.async_set_speech(trigger_response_text)
|
intent_response.async_set_speech(trigger_response_text)
|
||||||
conversation_result = conversation.ConversationResult(
|
|
||||||
response=trigger_response,
|
|
||||||
conversation_id=user_input.conversation_id,
|
|
||||||
)
|
|
||||||
# Try local intents first, if preferred.
|
# Try local intents first, if preferred.
|
||||||
elif self.pipeline.prefer_local_intents and (
|
elif self.pipeline.prefer_local_intents and (
|
||||||
intent_response := await conversation.async_handle_intents(
|
intent_response := await conversation.async_handle_intents(
|
||||||
@ -1090,13 +1089,31 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
# Local intent matched
|
# Local intent matched
|
||||||
conversation_result = conversation.ConversationResult(
|
agent_id = conversation.HOME_ASSISTANT_AGENT
|
||||||
response=intent_response,
|
|
||||||
conversation_id=user_input.conversation_id,
|
|
||||||
)
|
|
||||||
processed_locally = True
|
processed_locally = True
|
||||||
|
|
||||||
if conversation_result is None:
|
# It was already handled, create response and add to chat history
|
||||||
|
if intent_response is not None:
|
||||||
|
async with conversation.async_get_chat_session(
|
||||||
|
self.hass, user_input
|
||||||
|
) as chat_session:
|
||||||
|
speech: str = intent_response.speech.get("plain", {}).get(
|
||||||
|
"speech", ""
|
||||||
|
)
|
||||||
|
chat_session.async_add_message(
|
||||||
|
conversation.ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
agent_id=agent_id,
|
||||||
|
content=speech,
|
||||||
|
native=intent_response,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conversation_result = conversation.ConversationResult(
|
||||||
|
response=intent_response,
|
||||||
|
conversation_id=chat_session.conversation_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
# Fall back to pipeline conversation agent
|
# Fall back to pipeline conversation agent
|
||||||
conversation_result = await conversation.async_converse(
|
conversation_result = await conversation.async_converse(
|
||||||
hass=self.hass,
|
hass=self.hass,
|
||||||
@ -1107,6 +1124,10 @@ class PipelineRun:
|
|||||||
language=user_input.language,
|
language=user_input.language,
|
||||||
agent_id=user_input.agent_id,
|
agent_id=user_input.agent_id,
|
||||||
)
|
)
|
||||||
|
speech = conversation_result.response.speech.get("plain", {}).get(
|
||||||
|
"speech", ""
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as src_error:
|
except Exception as src_error:
|
||||||
_LOGGER.exception("Unexpected error during intent recognition")
|
_LOGGER.exception("Unexpected error during intent recognition")
|
||||||
raise IntentRecognitionError(
|
raise IntentRecognitionError(
|
||||||
@ -1126,10 +1147,6 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
speech: str = conversation_result.response.speech.get("plain", {}).get(
|
|
||||||
"speech", ""
|
|
||||||
)
|
|
||||||
|
|
||||||
return speech
|
return speech
|
||||||
|
|
||||||
async def prepare_text_to_speech(self) -> None:
|
async def prepare_text_to_speech(self) -> None:
|
||||||
|
@ -48,20 +48,25 @@ from .default_agent import DefaultAgent, async_setup_default_agent
|
|||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .http import async_setup as async_setup_conversation_http
|
from .http import async_setup as async_setup_conversation_http
|
||||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||||
|
from .session import ChatMessage, ChatSession, ConverseError, async_get_chat_session
|
||||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
"HOME_ASSISTANT_AGENT",
|
"HOME_ASSISTANT_AGENT",
|
||||||
"OLD_HOME_ASSISTANT_AGENT",
|
"OLD_HOME_ASSISTANT_AGENT",
|
||||||
|
"ChatMessage",
|
||||||
|
"ChatSession",
|
||||||
"ConversationEntity",
|
"ConversationEntity",
|
||||||
"ConversationEntityFeature",
|
"ConversationEntityFeature",
|
||||||
"ConversationInput",
|
"ConversationInput",
|
||||||
"ConversationResult",
|
"ConversationResult",
|
||||||
"ConversationTraceEventType",
|
"ConversationTraceEventType",
|
||||||
|
"ConverseError",
|
||||||
"async_conversation_trace_append",
|
"async_conversation_trace_append",
|
||||||
"async_converse",
|
"async_converse",
|
||||||
"async_get_agent_info",
|
"async_get_agent_info",
|
||||||
|
"async_get_chat_session",
|
||||||
"async_set_agent",
|
"async_set_agent",
|
||||||
"async_setup",
|
"async_setup",
|
||||||
"async_unset_agent",
|
"async_unset_agent",
|
||||||
|
@ -62,6 +62,7 @@ from .const import (
|
|||||||
)
|
)
|
||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .models import ConversationInput, ConversationResult
|
from .models import ConversationInput, ConversationResult
|
||||||
|
from .session import ChatMessage, async_get_chat_session
|
||||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@ -346,35 +347,52 @@ class DefaultAgent(ConversationEntity):
|
|||||||
|
|
||||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
|
response: intent.IntentResponse | None = None
|
||||||
|
async with async_get_chat_session(self.hass, user_input) as chat_session:
|
||||||
|
# Check if a trigger matched
|
||||||
|
if trigger_result := await self.async_recognize_sentence_trigger(
|
||||||
|
user_input
|
||||||
|
):
|
||||||
|
# Process callbacks and get response
|
||||||
|
response_text = await self._handle_trigger_result(
|
||||||
|
trigger_result, user_input
|
||||||
|
)
|
||||||
|
|
||||||
# Check if a trigger matched
|
# Convert to conversation result
|
||||||
if trigger_result := await self.async_recognize_sentence_trigger(user_input):
|
response = intent.IntentResponse(
|
||||||
# Process callbacks and get response
|
language=user_input.language or self.hass.config.language
|
||||||
response_text = await self._handle_trigger_result(
|
)
|
||||||
trigger_result, user_input
|
response.response_type = intent.IntentResponseType.ACTION_DONE
|
||||||
|
response.async_set_speech(response_text)
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
# Match intents
|
||||||
|
intent_result = await self.async_recognize_intent(user_input)
|
||||||
|
response = await self._async_process_intent_result(
|
||||||
|
intent_result, user_input
|
||||||
|
)
|
||||||
|
|
||||||
|
speech: str = response.speech.get("plain", {}).get("speech", "")
|
||||||
|
chat_session.async_add_message(
|
||||||
|
ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
agent_id=user_input.agent_id,
|
||||||
|
content=speech,
|
||||||
|
native=response,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert to conversation result
|
return ConversationResult(
|
||||||
response = intent.IntentResponse(
|
response=response, conversation_id=chat_session.conversation_id
|
||||||
language=user_input.language or self.hass.config.language
|
|
||||||
)
|
)
|
||||||
response.response_type = intent.IntentResponseType.ACTION_DONE
|
|
||||||
response.async_set_speech(response_text)
|
|
||||||
|
|
||||||
return ConversationResult(response=response)
|
|
||||||
|
|
||||||
# Match intents
|
|
||||||
intent_result = await self.async_recognize_intent(user_input)
|
|
||||||
return await self._async_process_intent_result(intent_result, user_input)
|
|
||||||
|
|
||||||
async def _async_process_intent_result(
|
async def _async_process_intent_result(
|
||||||
self,
|
self,
|
||||||
result: RecognizeResult | None,
|
result: RecognizeResult | None,
|
||||||
user_input: ConversationInput,
|
user_input: ConversationInput,
|
||||||
) -> ConversationResult:
|
) -> intent.IntentResponse:
|
||||||
"""Process user input with intents."""
|
"""Process user input with intents."""
|
||||||
language = user_input.language or self.hass.config.language
|
language = user_input.language or self.hass.config.language
|
||||||
conversation_id = None # Not supported
|
|
||||||
|
|
||||||
# Intent match or failure
|
# Intent match or failure
|
||||||
lang_intents = await self.async_get_or_load_intents(language)
|
lang_intents = await self.async_get_or_load_intents(language)
|
||||||
@ -386,7 +404,6 @@ class DefaultAgent(ConversationEntity):
|
|||||||
language,
|
language,
|
||||||
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
|
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
|
||||||
self._get_error_text(ErrorKey.NO_INTENT, lang_intents),
|
self._get_error_text(ErrorKey.NO_INTENT, lang_intents),
|
||||||
conversation_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.unmatched_entities:
|
if result.unmatched_entities:
|
||||||
@ -408,7 +425,6 @@ class DefaultAgent(ConversationEntity):
|
|||||||
self._get_error_text(
|
self._get_error_text(
|
||||||
error_response_type, lang_intents, **error_response_args
|
error_response_type, lang_intents, **error_response_args
|
||||||
),
|
),
|
||||||
conversation_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Will never happen because result will be None when no intents are
|
# Will never happen because result will be None when no intents are
|
||||||
@ -461,7 +477,6 @@ class DefaultAgent(ConversationEntity):
|
|||||||
self._get_error_text(
|
self._get_error_text(
|
||||||
error_response_type, lang_intents, **error_response_args
|
error_response_type, lang_intents, **error_response_args
|
||||||
),
|
),
|
||||||
conversation_id,
|
|
||||||
)
|
)
|
||||||
except intent.IntentHandleError as err:
|
except intent.IntentHandleError as err:
|
||||||
# Intent was valid and entities matched constraints, but an error
|
# Intent was valid and entities matched constraints, but an error
|
||||||
@ -473,7 +488,6 @@ class DefaultAgent(ConversationEntity):
|
|||||||
self._get_error_text(
|
self._get_error_text(
|
||||||
err.response_key or ErrorKey.HANDLE_ERROR, lang_intents
|
err.response_key or ErrorKey.HANDLE_ERROR, lang_intents
|
||||||
),
|
),
|
||||||
conversation_id,
|
|
||||||
)
|
)
|
||||||
except intent.IntentUnexpectedError:
|
except intent.IntentUnexpectedError:
|
||||||
_LOGGER.exception("Unexpected intent error")
|
_LOGGER.exception("Unexpected intent error")
|
||||||
@ -481,7 +495,6 @@ class DefaultAgent(ConversationEntity):
|
|||||||
language,
|
language,
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
self._get_error_text(ErrorKey.HANDLE_ERROR, lang_intents),
|
self._get_error_text(ErrorKey.HANDLE_ERROR, lang_intents),
|
||||||
conversation_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -500,9 +513,7 @@ class DefaultAgent(ConversationEntity):
|
|||||||
)
|
)
|
||||||
intent_response.async_set_speech(speech)
|
intent_response.async_set_speech(speech)
|
||||||
|
|
||||||
return ConversationResult(
|
return intent_response
|
||||||
response=intent_response, conversation_id=conversation_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def _recognize(
|
def _recognize(
|
||||||
self,
|
self,
|
||||||
@ -1346,22 +1357,18 @@ class DefaultAgent(ConversationEntity):
|
|||||||
# No error message on failed match
|
# No error message on failed match
|
||||||
return None
|
return None
|
||||||
|
|
||||||
conversation_result = await self._async_process_intent_result(
|
return await self._async_process_intent_result(result, user_input)
|
||||||
result, user_input
|
|
||||||
)
|
|
||||||
return conversation_result.response
|
|
||||||
|
|
||||||
|
|
||||||
def _make_error_result(
|
def _make_error_result(
|
||||||
language: str,
|
language: str,
|
||||||
error_code: intent.IntentResponseErrorCode,
|
error_code: intent.IntentResponseErrorCode,
|
||||||
response_text: str,
|
response_text: str,
|
||||||
conversation_id: str | None = None,
|
) -> intent.IntentResponse:
|
||||||
) -> ConversationResult:
|
|
||||||
"""Create conversation result with error code and text."""
|
"""Create conversation result with error code and text."""
|
||||||
response = intent.IntentResponse(language=language)
|
response = intent.IntentResponse(language=language)
|
||||||
response.async_set_error(error_code, response_text)
|
response.async_set_error(error_code, response_text)
|
||||||
return ConversationResult(response, conversation_id)
|
return response
|
||||||
|
|
||||||
|
|
||||||
def _get_unmatched_response(result: RecognizeResult) -> tuple[ErrorKey, dict[str, Any]]:
|
def _get_unmatched_response(result: RecognizeResult) -> tuple[ErrorKey, dict[str, Any]]:
|
||||||
|
327
homeassistant/components/conversation/session.py
Normal file
327
homeassistant/components/conversation/session.py
Normal file
@ -0,0 +1,327 @@
|
|||||||
|
"""Conversation history."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from dataclasses import dataclass, field, replace
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import logging
|
||||||
|
from typing import Generic, Literal, TypeVar
|
||||||
|
|
||||||
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||||
|
from homeassistant.core import (
|
||||||
|
CALLBACK_TYPE,
|
||||||
|
Event,
|
||||||
|
HassJob,
|
||||||
|
HassJobType,
|
||||||
|
HomeAssistant,
|
||||||
|
callback,
|
||||||
|
)
|
||||||
|
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||||
|
from homeassistant.helpers import intent, llm, template
|
||||||
|
from homeassistant.helpers.event import async_call_later
|
||||||
|
from homeassistant.util import dt as dt_util, ulid
|
||||||
|
from homeassistant.util.hass_dict import HassKey
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .models import ConversationInput, ConversationResult
|
||||||
|
|
||||||
|
DATA_CHAT_HISTORY: HassKey["dict[str, ChatSession]"] = HassKey(
|
||||||
|
"conversation_chat_session"
|
||||||
|
)
|
||||||
|
DATA_CHAT_HISTORY_CLEANUP: HassKey["SessionCleanup"] = HassKey(
|
||||||
|
"conversation_chat_session_cleanup"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
CONVERSATION_TIMEOUT = timedelta(minutes=5)
|
||||||
|
_NativeT = TypeVar("_NativeT")
|
||||||
|
|
||||||
|
|
||||||
|
class SessionCleanup:
|
||||||
|
"""Helper to clean up the history."""
|
||||||
|
|
||||||
|
unsub: CALLBACK_TYPE | None = None
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
"""Initialize the history cleanup."""
|
||||||
|
self.hass = hass
|
||||||
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
|
||||||
|
self.cleanup_job = HassJob(
|
||||||
|
self._cleanup, "conversation_history_cleanup", job_type=HassJobType.Callback
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def schedule(self) -> None:
|
||||||
|
"""Schedule the cleanup."""
|
||||||
|
if self.unsub:
|
||||||
|
return
|
||||||
|
self.unsub = async_call_later(
|
||||||
|
self.hass,
|
||||||
|
CONVERSATION_TIMEOUT.total_seconds() + 1,
|
||||||
|
self.cleanup_job,
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _on_hass_stop(self, event: Event) -> None:
|
||||||
|
"""Cancel the cleanup on shutdown."""
|
||||||
|
if self.unsub:
|
||||||
|
self.unsub()
|
||||||
|
self.unsub = None
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _cleanup(self, now: datetime) -> None:
|
||||||
|
"""Clean up the history and schedule follow-up if necessary."""
|
||||||
|
self.unsub = None
|
||||||
|
all_history = self.hass.data[DATA_CHAT_HISTORY]
|
||||||
|
|
||||||
|
# We mutate original object because current commands could be
|
||||||
|
# yielding history based on it.
|
||||||
|
for conversation_id, history in list(all_history.items()):
|
||||||
|
if history.last_updated + CONVERSATION_TIMEOUT < now:
|
||||||
|
del all_history[conversation_id]
|
||||||
|
|
||||||
|
# Still conversations left, check again in timeout time.
|
||||||
|
if all_history:
|
||||||
|
self.schedule()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def async_get_chat_session(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
user_input: ConversationInput,
|
||||||
|
) -> AsyncGenerator["ChatSession"]:
|
||||||
|
"""Return chat session."""
|
||||||
|
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
||||||
|
if all_history is None:
|
||||||
|
all_history = {}
|
||||||
|
hass.data[DATA_CHAT_HISTORY] = all_history
|
||||||
|
hass.data[DATA_CHAT_HISTORY_CLEANUP] = SessionCleanup(hass)
|
||||||
|
|
||||||
|
history: ChatSession | None = None
|
||||||
|
|
||||||
|
if user_input.conversation_id is None:
|
||||||
|
conversation_id = ulid.ulid_now()
|
||||||
|
|
||||||
|
elif history := all_history.get(user_input.conversation_id):
|
||||||
|
conversation_id = user_input.conversation_id
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
||||||
|
# If an old OLID is passed in, we will generate a new one to indicate
|
||||||
|
# a new conversation was started. If the user picks their own, they
|
||||||
|
# want to track a conversation and we respect it.
|
||||||
|
try:
|
||||||
|
ulid.ulid_to_bytes(user_input.conversation_id)
|
||||||
|
conversation_id = ulid.ulid_now()
|
||||||
|
except ValueError:
|
||||||
|
conversation_id = user_input.conversation_id
|
||||||
|
|
||||||
|
if history:
|
||||||
|
history = replace(history, messages=history.messages.copy())
|
||||||
|
else:
|
||||||
|
history = ChatSession(hass, conversation_id)
|
||||||
|
|
||||||
|
message: ChatMessage = ChatMessage(
|
||||||
|
role="user",
|
||||||
|
agent_id=user_input.agent_id,
|
||||||
|
content=user_input.text,
|
||||||
|
)
|
||||||
|
history.async_add_message(message)
|
||||||
|
|
||||||
|
yield history
|
||||||
|
|
||||||
|
if history.messages[-1] is message:
|
||||||
|
LOGGER.debug(
|
||||||
|
"History opened but no assistant message was added, ignoring update"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
history.last_updated = dt_util.utcnow()
|
||||||
|
all_history[conversation_id] = history
|
||||||
|
hass.data[DATA_CHAT_HISTORY_CLEANUP].schedule()
|
||||||
|
|
||||||
|
|
||||||
|
class ConverseError(HomeAssistantError):
|
||||||
|
"""Error during initialization of conversation.
|
||||||
|
|
||||||
|
Will not be stored in the history.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, message: str, conversation_id: str, response: intent.IntentResponse
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the error."""
|
||||||
|
super().__init__(message)
|
||||||
|
self.conversation_id = conversation_id
|
||||||
|
self.response = response
|
||||||
|
|
||||||
|
def as_converstation_result(self) -> ConversationResult:
|
||||||
|
"""Return the error as a conversation result."""
|
||||||
|
return ConversationResult(
|
||||||
|
response=self.response,
|
||||||
|
conversation_id=self.conversation_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage(Generic[_NativeT]):
|
||||||
|
"""Base class for chat messages.
|
||||||
|
|
||||||
|
When role is native, the content is to be ignored and message
|
||||||
|
is only meant for storing the native object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
role: Literal["system", "assistant", "user", "native"]
|
||||||
|
agent_id: str | None
|
||||||
|
content: str
|
||||||
|
native: _NativeT | None = field(default=None)
|
||||||
|
|
||||||
|
# Validate in post-init that if role is native, there is no content and a native object exists
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Validate native message."""
|
||||||
|
if self.role == "native" and self.native is None:
|
||||||
|
raise ValueError("Native message must have a native object")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatSession(Generic[_NativeT]):
|
||||||
|
"""Class holding all information for a specific conversation."""
|
||||||
|
|
||||||
|
hass: HomeAssistant
|
||||||
|
conversation_id: str
|
||||||
|
user_name: str | None = None
|
||||||
|
messages: list[ChatMessage[_NativeT]] = field(
|
||||||
|
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")]
|
||||||
|
)
|
||||||
|
extra_system_prompt: str | None = None
|
||||||
|
llm_api: llm.APIInstance | None = None
|
||||||
|
last_updated: datetime = field(default_factory=dt_util.utcnow)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_add_message(self, message: ChatMessage[_NativeT]) -> None:
|
||||||
|
"""Process intent."""
|
||||||
|
if message.role == "system":
|
||||||
|
raise ValueError("Cannot add system messages to history")
|
||||||
|
if message.role != "native" and self.messages[-1].role == message.role:
|
||||||
|
raise ValueError("Cannot add two assistant or user messages in a row")
|
||||||
|
|
||||||
|
self.messages.append(message)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_messages(self, agent_id: str | None) -> list[ChatMessage[_NativeT]]:
|
||||||
|
"""Get messages for a specific agent ID.
|
||||||
|
|
||||||
|
This will filter out any native message tied to other agent IDs.
|
||||||
|
It can still include assistant/user messages generated by other agents.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
message
|
||||||
|
for message in self.messages
|
||||||
|
if message.role != "native" or message.agent_id == agent_id
|
||||||
|
]
|
||||||
|
|
||||||
|
async def async_process_llm_message(
|
||||||
|
self,
|
||||||
|
conversing_domain: str,
|
||||||
|
user_input: ConversationInput,
|
||||||
|
user_llm_hass_api: str | None = None,
|
||||||
|
user_llm_prompt: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Process an incoming message for an LLM."""
|
||||||
|
llm_context = llm.LLMContext(
|
||||||
|
platform=conversing_domain,
|
||||||
|
context=user_input.context,
|
||||||
|
user_prompt=user_input.text,
|
||||||
|
language=user_input.language,
|
||||||
|
assistant=DOMAIN,
|
||||||
|
device_id=user_input.device_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_api: llm.APIInstance | None = None
|
||||||
|
|
||||||
|
if user_llm_hass_api:
|
||||||
|
try:
|
||||||
|
llm_api = await llm.async_get_api(
|
||||||
|
self.hass,
|
||||||
|
user_llm_hass_api,
|
||||||
|
llm_context,
|
||||||
|
)
|
||||||
|
except HomeAssistantError as err:
|
||||||
|
LOGGER.error(
|
||||||
|
"Error getting LLM API %s for %s: %s",
|
||||||
|
user_llm_hass_api,
|
||||||
|
conversing_domain,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
|
intent_response.async_set_error(
|
||||||
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
|
"Error preparing LLM API",
|
||||||
|
)
|
||||||
|
raise ConverseError(
|
||||||
|
f"Error getting LLM API {user_llm_hass_api}",
|
||||||
|
conversation_id=self.conversation_id,
|
||||||
|
response=intent_response,
|
||||||
|
) from err
|
||||||
|
|
||||||
|
user_name: str | None = None
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_input.context
|
||||||
|
and user_input.context.user_id
|
||||||
|
and (
|
||||||
|
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
user_name = user.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
prompt_parts = [
|
||||||
|
template.Template(
|
||||||
|
llm.BASE_PROMPT
|
||||||
|
+ (user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
||||||
|
self.hass,
|
||||||
|
).async_render(
|
||||||
|
{
|
||||||
|
"ha_name": self.hass.config.location_name,
|
||||||
|
"user_name": user_name,
|
||||||
|
"llm_context": llm_context,
|
||||||
|
},
|
||||||
|
parse_result=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
except TemplateError as err:
|
||||||
|
LOGGER.error("Error rendering prompt: %s", err)
|
||||||
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
|
intent_response.async_set_error(
|
||||||
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
|
"Sorry, I had a problem with my template",
|
||||||
|
)
|
||||||
|
raise ConverseError(
|
||||||
|
"Error rendering prompt",
|
||||||
|
conversation_id=self.conversation_id,
|
||||||
|
response=intent_response,
|
||||||
|
) from err
|
||||||
|
|
||||||
|
if llm_api:
|
||||||
|
prompt_parts.append(llm_api.api_prompt)
|
||||||
|
|
||||||
|
extra_system_prompt = (
|
||||||
|
# Take new system prompt if one was given
|
||||||
|
user_input.extra_system_prompt or self.extra_system_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
if extra_system_prompt:
|
||||||
|
prompt_parts.append(extra_system_prompt)
|
||||||
|
|
||||||
|
prompt = "\n".join(prompt_parts)
|
||||||
|
|
||||||
|
self.llm_api = llm_api
|
||||||
|
self.user_name = user_name
|
||||||
|
self.extra_system_prompt = extra_system_prompt
|
||||||
|
self.messages[0] = ChatMessage(
|
||||||
|
role="system",
|
||||||
|
agent_id=user_input.agent_id,
|
||||||
|
content=prompt,
|
||||||
|
)
|
@ -1,9 +1,8 @@
|
|||||||
"""Conversation support for OpenAI."""
|
"""Conversation support for OpenAI."""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from openai._types import NOT_GIVEN
|
from openai._types import NOT_GIVEN
|
||||||
@ -12,10 +11,8 @@ from openai.types.chat import (
|
|||||||
ChatCompletionMessage,
|
ChatCompletionMessage,
|
||||||
ChatCompletionMessageParam,
|
ChatCompletionMessageParam,
|
||||||
ChatCompletionMessageToolCallParam,
|
ChatCompletionMessageToolCallParam,
|
||||||
ChatCompletionSystemMessageParam,
|
|
||||||
ChatCompletionToolMessageParam,
|
ChatCompletionToolMessageParam,
|
||||||
ChatCompletionToolParam,
|
ChatCompletionToolParam,
|
||||||
ChatCompletionUserMessageParam,
|
|
||||||
)
|
)
|
||||||
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||||
from openai.types.shared_params import FunctionDefinition
|
from openai.types.shared_params import FunctionDefinition
|
||||||
@ -27,10 +24,9 @@ from homeassistant.components.conversation import trace
|
|||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import device_registry as dr, intent, llm, template
|
from homeassistant.helpers import device_registry as dr, intent, llm
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.util import ulid
|
|
||||||
|
|
||||||
from . import OpenAIConfigEntry
|
from . import OpenAIConfigEntry
|
||||||
from .const import (
|
from .const import (
|
||||||
@ -74,12 +70,28 @@ def _format_tool(
|
|||||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
|
||||||
class ChatHistory:
|
"""Convert from class to TypedDict."""
|
||||||
"""Class holding the chat history."""
|
tool_calls: list[ChatCompletionMessageToolCallParam] = []
|
||||||
|
if message.tool_calls:
|
||||||
extra_system_prompt: str | None = None
|
tool_calls = [
|
||||||
messages: list[ChatCompletionMessageParam] = field(default_factory=list)
|
ChatCompletionMessageToolCallParam(
|
||||||
|
id=tool_call.id,
|
||||||
|
function=Function(
|
||||||
|
arguments=tool_call.function.arguments,
|
||||||
|
name=tool_call.function.name,
|
||||||
|
),
|
||||||
|
type=tool_call.type,
|
||||||
|
)
|
||||||
|
for tool_call in message.tool_calls
|
||||||
|
]
|
||||||
|
param = ChatCompletionAssistantMessageParam(
|
||||||
|
role=message.role,
|
||||||
|
content=message.content,
|
||||||
|
)
|
||||||
|
if tool_calls:
|
||||||
|
param["tool_calls"] = tool_calls
|
||||||
|
return param
|
||||||
|
|
||||||
|
|
||||||
class OpenAIConversationEntity(
|
class OpenAIConversationEntity(
|
||||||
@ -93,7 +105,6 @@ class OpenAIConversationEntity(
|
|||||||
def __init__(self, entry: OpenAIConfigEntry) -> None:
|
def __init__(self, entry: OpenAIConfigEntry) -> None:
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
self.entry = entry
|
self.entry = entry
|
||||||
self.history: dict[str, ChatHistory] = {}
|
|
||||||
self._attr_unique_id = entry.entry_id
|
self._attr_unique_id = entry.entry_id
|
||||||
self._attr_device_info = dr.DeviceInfo(
|
self._attr_device_info = dr.DeviceInfo(
|
||||||
identifiers={(DOMAIN, entry.entry_id)},
|
identifiers={(DOMAIN, entry.entry_id)},
|
||||||
@ -132,127 +143,56 @@ class OpenAIConversationEntity(
|
|||||||
self, user_input: conversation.ConversationInput
|
self, user_input: conversation.ConversationInput
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
|
async with conversation.async_get_chat_session(
|
||||||
|
self.hass, user_input
|
||||||
|
) as session:
|
||||||
|
return await self._async_call_api(user_input, session)
|
||||||
|
|
||||||
|
async def _async_call_api(
|
||||||
|
self,
|
||||||
|
user_input: conversation.ConversationInput,
|
||||||
|
session: conversation.ChatSession[ChatCompletionMessageParam],
|
||||||
|
) -> conversation.ConversationResult:
|
||||||
|
"""Call the API."""
|
||||||
options = self.entry.options
|
options = self.entry.options
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
|
||||||
llm_api: llm.APIInstance | None = None
|
|
||||||
tools: list[ChatCompletionToolParam] | None = None
|
|
||||||
user_name: str | None = None
|
|
||||||
llm_context = llm.LLMContext(
|
|
||||||
platform=DOMAIN,
|
|
||||||
context=user_input.context,
|
|
||||||
user_prompt=user_input.text,
|
|
||||||
language=user_input.language,
|
|
||||||
assistant=conversation.DOMAIN,
|
|
||||||
device_id=user_input.device_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if options.get(CONF_LLM_HASS_API):
|
|
||||||
try:
|
|
||||||
llm_api = await llm.async_get_api(
|
|
||||||
self.hass,
|
|
||||||
options[CONF_LLM_HASS_API],
|
|
||||||
llm_context,
|
|
||||||
)
|
|
||||||
except HomeAssistantError as err:
|
|
||||||
LOGGER.error("Error getting LLM API: %s", err)
|
|
||||||
intent_response.async_set_error(
|
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
|
||||||
"Error preparing LLM API",
|
|
||||||
)
|
|
||||||
return conversation.ConversationResult(
|
|
||||||
response=intent_response, conversation_id=user_input.conversation_id
|
|
||||||
)
|
|
||||||
tools = [
|
|
||||||
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
|
||||||
]
|
|
||||||
|
|
||||||
history: ChatHistory | None = None
|
|
||||||
|
|
||||||
if user_input.conversation_id is None:
|
|
||||||
conversation_id = ulid.ulid_now()
|
|
||||||
|
|
||||||
elif user_input.conversation_id in self.history:
|
|
||||||
conversation_id = user_input.conversation_id
|
|
||||||
history = self.history.get(conversation_id)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
|
||||||
# If an old OLID is passed in, we will generate a new one to indicate
|
|
||||||
# a new conversation was started. If the user picks their own, they
|
|
||||||
# want to track a conversation and we respect it.
|
|
||||||
try:
|
|
||||||
ulid.ulid_to_bytes(user_input.conversation_id)
|
|
||||||
conversation_id = ulid.ulid_now()
|
|
||||||
except ValueError:
|
|
||||||
conversation_id = user_input.conversation_id
|
|
||||||
|
|
||||||
if history is None:
|
|
||||||
history = ChatHistory(user_input.extra_system_prompt)
|
|
||||||
|
|
||||||
if (
|
|
||||||
user_input.context
|
|
||||||
and user_input.context.user_id
|
|
||||||
and (
|
|
||||||
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
user_name = user.name
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
prompt_parts = [
|
await session.async_process_llm_message(
|
||||||
template.Template(
|
DOMAIN,
|
||||||
llm.BASE_PROMPT
|
user_input,
|
||||||
+ options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
options.get(CONF_LLM_HASS_API),
|
||||||
self.hass,
|
options.get(CONF_PROMPT),
|
||||||
).async_render(
|
)
|
||||||
{
|
except conversation.ConverseError as err:
|
||||||
"ha_name": self.hass.config.location_name,
|
return err.as_converstation_result()
|
||||||
"user_name": user_name,
|
|
||||||
"llm_context": llm_context,
|
tools: list[ChatCompletionToolParam] | None = None
|
||||||
},
|
if session.llm_api:
|
||||||
parse_result=False,
|
tools = [
|
||||||
)
|
_format_tool(tool, session.llm_api.custom_serializer)
|
||||||
|
for tool in session.llm_api.tools
|
||||||
]
|
]
|
||||||
|
|
||||||
except TemplateError as err:
|
messages: list[ChatCompletionMessageParam] = []
|
||||||
LOGGER.error("Error rendering prompt: %s", err)
|
for message in session.async_get_messages(user_input.agent_id):
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
if message.native is not None and message.agent_id == user_input.agent_id:
|
||||||
intent_response.async_set_error(
|
messages.append(message.native)
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
else:
|
||||||
"Sorry, I had a problem with my template",
|
messages.append(
|
||||||
)
|
cast(
|
||||||
return conversation.ConversationResult(
|
ChatCompletionMessageParam,
|
||||||
response=intent_response, conversation_id=conversation_id
|
{"role": message.role, "content": message.content},
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if llm_api:
|
LOGGER.debug("Prompt: %s", messages)
|
||||||
prompt_parts.append(llm_api.api_prompt)
|
|
||||||
|
|
||||||
extra_system_prompt = (
|
|
||||||
# Take new system prompt if one was given
|
|
||||||
user_input.extra_system_prompt or history.extra_system_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
if extra_system_prompt:
|
|
||||||
prompt_parts.append(extra_system_prompt)
|
|
||||||
|
|
||||||
prompt = "\n".join(prompt_parts)
|
|
||||||
|
|
||||||
# Create a copy of the variable because we attach it to the trace
|
|
||||||
history = ChatHistory(
|
|
||||||
extra_system_prompt,
|
|
||||||
[
|
|
||||||
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
|
||||||
*history.messages[1:],
|
|
||||||
ChatCompletionUserMessageParam(role="user", content=user_input.text),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
LOGGER.debug("Prompt: %s", history.messages)
|
|
||||||
LOGGER.debug("Tools: %s", tools)
|
LOGGER.debug("Tools: %s", tools)
|
||||||
trace.async_conversation_trace_append(
|
trace.async_conversation_trace_append(
|
||||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||||
{"messages": history.messages, "tools": llm_api.tools if llm_api else None},
|
{
|
||||||
|
"messages": session.messages,
|
||||||
|
"tools": session.llm_api.tools if session.llm_api else None,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
client = self.entry.runtime_data
|
client = self.entry.runtime_data
|
||||||
@ -262,12 +202,12 @@ class OpenAIConversationEntity(
|
|||||||
try:
|
try:
|
||||||
result = await client.chat.completions.create(
|
result = await client.chat.completions.create(
|
||||||
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||||
messages=history.messages,
|
messages=messages,
|
||||||
tools=tools or NOT_GIVEN,
|
tools=tools or NOT_GIVEN,
|
||||||
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||||
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||||
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||||
user=conversation_id,
|
user=session.conversation_id,
|
||||||
)
|
)
|
||||||
except openai.OpenAIError as err:
|
except openai.OpenAIError as err:
|
||||||
LOGGER.error("Error talking to OpenAI: %s", err)
|
LOGGER.error("Error talking to OpenAI: %s", err)
|
||||||
@ -277,44 +217,26 @@ class OpenAIConversationEntity(
|
|||||||
"Sorry, I had a problem talking to OpenAI",
|
"Sorry, I had a problem talking to OpenAI",
|
||||||
)
|
)
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response, conversation_id=conversation_id
|
response=intent_response, conversation_id=session.conversation_id
|
||||||
)
|
)
|
||||||
|
|
||||||
LOGGER.debug("Response %s", result)
|
LOGGER.debug("Response %s", result)
|
||||||
response = result.choices[0].message
|
response = result.choices[0].message
|
||||||
|
messages.append(_message_convert(response))
|
||||||
|
|
||||||
def message_convert(
|
session.async_add_message(
|
||||||
message: ChatCompletionMessage,
|
conversation.ChatMessage(
|
||||||
) -> ChatCompletionMessageParam:
|
role=response.role,
|
||||||
"""Convert from class to TypedDict."""
|
agent_id=user_input.agent_id,
|
||||||
tool_calls: list[ChatCompletionMessageToolCallParam] = []
|
content=response.content or "",
|
||||||
if message.tool_calls:
|
native=messages[-1],
|
||||||
tool_calls = [
|
),
|
||||||
ChatCompletionMessageToolCallParam(
|
)
|
||||||
id=tool_call.id,
|
|
||||||
function=Function(
|
|
||||||
arguments=tool_call.function.arguments,
|
|
||||||
name=tool_call.function.name,
|
|
||||||
),
|
|
||||||
type=tool_call.type,
|
|
||||||
)
|
|
||||||
for tool_call in message.tool_calls
|
|
||||||
]
|
|
||||||
param = ChatCompletionAssistantMessageParam(
|
|
||||||
role=message.role,
|
|
||||||
content=message.content,
|
|
||||||
)
|
|
||||||
if tool_calls:
|
|
||||||
param["tool_calls"] = tool_calls
|
|
||||||
return param
|
|
||||||
|
|
||||||
history.messages.append(message_convert(response))
|
if not response.tool_calls or not session.llm_api:
|
||||||
tool_calls = response.tool_calls
|
|
||||||
|
|
||||||
if not tool_calls or not llm_api:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
for tool_call in tool_calls:
|
for tool_call in response.tool_calls:
|
||||||
tool_input = llm.ToolInput(
|
tool_input = llm.ToolInput(
|
||||||
tool_name=tool_call.function.name,
|
tool_name=tool_call.function.name,
|
||||||
tool_args=json.loads(tool_call.function.arguments),
|
tool_args=json.loads(tool_call.function.arguments),
|
||||||
@ -324,27 +246,33 @@ class OpenAIConversationEntity(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_response = await llm_api.async_call_tool(tool_input)
|
tool_response = await session.llm_api.async_call_tool(tool_input)
|
||||||
except (HomeAssistantError, vol.Invalid) as e:
|
except (HomeAssistantError, vol.Invalid) as e:
|
||||||
tool_response = {"error": type(e).__name__}
|
tool_response = {"error": type(e).__name__}
|
||||||
if str(e):
|
if str(e):
|
||||||
tool_response["error_text"] = str(e)
|
tool_response["error_text"] = str(e)
|
||||||
|
|
||||||
LOGGER.debug("Tool response: %s", tool_response)
|
LOGGER.debug("Tool response: %s", tool_response)
|
||||||
history.messages.append(
|
messages.append(
|
||||||
ChatCompletionToolMessageParam(
|
ChatCompletionToolMessageParam(
|
||||||
role="tool",
|
role="tool",
|
||||||
tool_call_id=tool_call.id,
|
tool_call_id=tool_call.id,
|
||||||
content=json.dumps(tool_response),
|
content=json.dumps(tool_response),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
session.async_add_message(
|
||||||
self.history[conversation_id] = history
|
conversation.ChatMessage(
|
||||||
|
role="native",
|
||||||
|
agent_id=user_input.agent_id,
|
||||||
|
content="",
|
||||||
|
native=messages[-1],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
intent_response.async_set_speech(response.content or "")
|
intent_response.async_set_speech(response.content or "")
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response, conversation_id=conversation_id
|
response=intent_response, conversation_id=session.conversation_id
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _async_entry_update_listener(
|
async def _async_entry_update_listener(
|
||||||
|
@ -15,10 +15,6 @@ import voluptuous as vol
|
|||||||
from voluptuous_openapi import UNSUPPORTED, convert
|
from voluptuous_openapi import UNSUPPORTED, convert
|
||||||
|
|
||||||
from homeassistant.components.climate import INTENT_GET_TEMPERATURE
|
from homeassistant.components.climate import INTENT_GET_TEMPERATURE
|
||||||
from homeassistant.components.conversation import (
|
|
||||||
ConversationTraceEventType,
|
|
||||||
async_conversation_trace_append,
|
|
||||||
)
|
|
||||||
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
||||||
from homeassistant.components.homeassistant import async_should_expose
|
from homeassistant.components.homeassistant import async_should_expose
|
||||||
from homeassistant.components.intent import async_device_supports_timers
|
from homeassistant.components.intent import async_device_supports_timers
|
||||||
@ -171,6 +167,12 @@ class APIInstance:
|
|||||||
|
|
||||||
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
||||||
"""Call a LLM tool, validate args and return the response."""
|
"""Call a LLM tool, validate args and return the response."""
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
from homeassistant.components.conversation import (
|
||||||
|
ConversationTraceEventType,
|
||||||
|
async_conversation_trace_append,
|
||||||
|
)
|
||||||
|
|
||||||
async_conversation_trace_append(
|
async_conversation_trace_append(
|
||||||
ConversationTraceEventType.TOOL_CALL,
|
ConversationTraceEventType.TOOL_CALL,
|
||||||
{"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
|
{"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
|
||||||
|
@ -44,7 +44,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -135,7 +135,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -226,7 +226,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -341,7 +341,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -446,7 +446,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -497,7 +497,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -548,7 +548,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
|
@ -42,7 +42,7 @@
|
|||||||
# name: test_audio_pipeline.4
|
# name: test_audio_pipeline.4
|
||||||
dict({
|
dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -125,7 +125,7 @@
|
|||||||
# name: test_audio_pipeline_debug.4
|
# name: test_audio_pipeline_debug.4
|
||||||
dict({
|
dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -220,7 +220,7 @@
|
|||||||
# name: test_audio_pipeline_with_enhancements.4
|
# name: test_audio_pipeline_with_enhancements.4
|
||||||
dict({
|
dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -325,7 +325,7 @@
|
|||||||
# name: test_audio_pipeline_with_wake_word_no_timeout.6
|
# name: test_audio_pipeline_with_wake_word_no_timeout.6
|
||||||
dict({
|
dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -585,7 +585,7 @@
|
|||||||
# name: test_pipeline_empty_tts_output.2
|
# name: test_pipeline_empty_tts_output.2
|
||||||
dict({
|
dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -698,7 +698,7 @@
|
|||||||
# name: test_text_only_pipeline[extra_msg0].2
|
# name: test_text_only_pipeline[extra_msg0].2
|
||||||
dict({
|
dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -744,7 +744,7 @@
|
|||||||
# name: test_text_only_pipeline[extra_msg1].2
|
# name: test_text_only_pipeline[extra_msg1].2
|
||||||
dict({
|
dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_custom_sentences
|
# name: test_custom_sentences
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -26,7 +26,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_custom_sentences.1
|
# name: test_custom_sentences.1
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -51,7 +51,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_custom_sentences_config
|
# name: test_custom_sentences_config
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -76,7 +76,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_alias_added_removed
|
# name: test_intent_alias_added_removed
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -106,7 +106,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_alias_added_removed.1
|
# name: test_intent_alias_added_removed.1
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -136,7 +136,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_alias_added_removed.2
|
# name: test_intent_alias_added_removed.2
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -156,7 +156,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_conversion_not_expose_new
|
# name: test_intent_conversion_not_expose_new
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -176,7 +176,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_conversion_not_expose_new.1
|
# name: test_intent_conversion_not_expose_new.1
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -206,7 +206,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_entity_added_removed
|
# name: test_intent_entity_added_removed
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -236,7 +236,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_entity_added_removed.1
|
# name: test_intent_entity_added_removed.1
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -266,7 +266,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_entity_added_removed.2
|
# name: test_intent_entity_added_removed.2
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -296,7 +296,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_entity_added_removed.3
|
# name: test_intent_entity_added_removed.3
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -316,7 +316,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_entity_exposed
|
# name: test_intent_entity_exposed
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -346,7 +346,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_entity_fail_if_unexposed
|
# name: test_intent_entity_fail_if_unexposed
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -366,7 +366,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_entity_remove_custom_name
|
# name: test_intent_entity_remove_custom_name
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -386,7 +386,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_entity_remove_custom_name.1
|
# name: test_intent_entity_remove_custom_name.1
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -416,7 +416,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_entity_remove_custom_name.2
|
# name: test_intent_entity_remove_custom_name.2
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -436,7 +436,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_entity_renamed
|
# name: test_intent_entity_renamed
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -466,7 +466,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_entity_renamed.1
|
# name: test_intent_entity_renamed.1
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
|
@ -201,7 +201,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_http_api_handle_failure
|
# name: test_http_api_handle_failure
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -221,7 +221,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_http_api_no_match
|
# name: test_http_api_no_match
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -241,7 +241,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_http_api_unexpected_failure
|
# name: test_http_api_unexpected_failure
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -261,7 +261,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_http_processing_intent[None]
|
# name: test_http_processing_intent[None]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -291,7 +291,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_http_processing_intent[conversation.home_assistant]
|
# name: test_http_processing_intent[conversation.home_assistant]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -321,7 +321,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_http_processing_intent[homeassistant]
|
# name: test_http_processing_intent[homeassistant]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -351,7 +351,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_ws_api[payload0]
|
# name: test_ws_api[payload0]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -371,7 +371,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_ws_api[payload1]
|
# name: test_ws_api[payload1]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -391,7 +391,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_ws_api[payload2]
|
# name: test_ws_api[payload2]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -411,7 +411,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_ws_api[payload3]
|
# name: test_ws_api[payload3]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -431,7 +431,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_ws_api[payload4]
|
# name: test_ws_api[payload4]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -451,7 +451,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_ws_api[payload5]
|
# name: test_ws_api[payload5]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_custom_agent
|
# name: test_custom_agent
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': 'test-conv-id',
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -44,7 +44,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[None-turn kitchen on-None]
|
# name: test_turn_on_intent[None-turn kitchen on-None]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -74,7 +74,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant]
|
# name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -104,7 +104,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[None-turn kitchen on-homeassistant]
|
# name: test_turn_on_intent[None-turn kitchen on-homeassistant]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -134,7 +134,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[None-turn on kitchen-None]
|
# name: test_turn_on_intent[None-turn on kitchen-None]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -164,7 +164,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant]
|
# name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -194,7 +194,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[None-turn on kitchen-homeassistant]
|
# name: test_turn_on_intent[None-turn on kitchen-homeassistant]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -224,7 +224,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-None]
|
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-None]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -254,7 +254,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant]
|
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -284,7 +284,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant]
|
# name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -314,7 +314,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-None]
|
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-None]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -344,7 +344,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant]
|
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
@ -374,7 +374,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant]
|
# name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant]
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': <ANY>,
|
||||||
'response': dict({
|
'response': dict({
|
||||||
'card': dict({
|
'card': dict({
|
||||||
}),
|
}),
|
||||||
|
171
tests/components/conversation/test_session.py
Normal file
171
tests/components/conversation/test_session.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
"""Test the conversation session."""
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
from datetime import timedelta
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.conversation import ConversationInput, session
|
||||||
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
from tests.common import async_fire_time_changed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
|
||||||
|
"""Return a conversation input instance."""
|
||||||
|
return ConversationInput(
|
||||||
|
text="Hello",
|
||||||
|
context=Context(),
|
||||||
|
conversation_id=None,
|
||||||
|
agent_id="mock-agent-id",
|
||||||
|
device_id=None,
|
||||||
|
language="en",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ulid() -> Generator[Mock]:
|
||||||
|
"""Mock the ulid library."""
|
||||||
|
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
|
||||||
|
mock_ulid_now.return_value = "mock-ulid"
|
||||||
|
yield mock_ulid_now
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("start_id", "given_id"),
|
||||||
|
[
|
||||||
|
(None, "mock-ulid"),
|
||||||
|
# This ULID is not known as a session
|
||||||
|
("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"),
|
||||||
|
("not-a-ulid", "not-a-ulid"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_conversation_id(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_conversation_input: ConversationInput,
|
||||||
|
mock_ulid: Mock,
|
||||||
|
start_id: str | None,
|
||||||
|
given_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test conversation ID generation."""
|
||||||
|
mock_conversation_input.conversation_id = start_id
|
||||||
|
async with session.async_get_chat_session(
|
||||||
|
hass, mock_conversation_input
|
||||||
|
) as chat_session:
|
||||||
|
assert chat_session.conversation_id == given_id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cleanup(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_conversation_input: ConversationInput,
|
||||||
|
) -> None:
|
||||||
|
"""Mock cleanup of the conversation session."""
|
||||||
|
async with session.async_get_chat_session(
|
||||||
|
hass, mock_conversation_input
|
||||||
|
) as chat_session:
|
||||||
|
assert len(chat_session.messages) == 2
|
||||||
|
conversation_id = chat_session.conversation_id
|
||||||
|
|
||||||
|
# Generate session entry.
|
||||||
|
async with session.async_get_chat_session(
|
||||||
|
hass, mock_conversation_input
|
||||||
|
) as chat_session:
|
||||||
|
# Because we didn't add a message to the session in the last block,
|
||||||
|
# the conversation was not be persisted and we get a new ID
|
||||||
|
assert chat_session.conversation_id != conversation_id
|
||||||
|
conversation_id = chat_session.conversation_id
|
||||||
|
chat_session.async_add_message(
|
||||||
|
session.ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
agent_id="mock-agent-id",
|
||||||
|
content="Hey!",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert len(chat_session.messages) == 3
|
||||||
|
|
||||||
|
# Reuse conversation ID to ensure we can chat with same session
|
||||||
|
mock_conversation_input.conversation_id = conversation_id
|
||||||
|
async with session.async_get_chat_session(
|
||||||
|
hass, mock_conversation_input
|
||||||
|
) as chat_session:
|
||||||
|
assert len(chat_session.messages) == 4
|
||||||
|
assert chat_session.conversation_id == conversation_id
|
||||||
|
|
||||||
|
async_fire_time_changed(
|
||||||
|
hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT + timedelta(seconds=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# It should be cleaned up now and we start a new conversation
|
||||||
|
async with session.async_get_chat_session(
|
||||||
|
hass, mock_conversation_input
|
||||||
|
) as chat_session:
|
||||||
|
assert chat_session.conversation_id != conversation_id
|
||||||
|
assert len(chat_session.messages) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_message_filtering(
|
||||||
|
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||||
|
) -> None:
|
||||||
|
"""Test filtering of messages."""
|
||||||
|
async with session.async_get_chat_session(
|
||||||
|
hass, mock_conversation_input
|
||||||
|
) as chat_session:
|
||||||
|
messages = chat_session.async_get_messages(agent_id=None)
|
||||||
|
assert len(messages) == 2
|
||||||
|
assert messages[0] == session.ChatMessage(
|
||||||
|
role="system",
|
||||||
|
agent_id=None,
|
||||||
|
content="",
|
||||||
|
)
|
||||||
|
assert messages[1] == session.ChatMessage(
|
||||||
|
role="user",
|
||||||
|
agent_id=mock_conversation_input.agent_id,
|
||||||
|
content=mock_conversation_input.text,
|
||||||
|
)
|
||||||
|
# Cannot add a second user message in a row
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chat_session.async_add_message(
|
||||||
|
session.ChatMessage(
|
||||||
|
role="user",
|
||||||
|
agent_id="mock-agent-id",
|
||||||
|
content="Hey!",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_session.async_add_message(
|
||||||
|
session.ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
agent_id="mock-agent-id",
|
||||||
|
content="Hey!",
|
||||||
|
native="assistant-reply-native",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Different agent, will be filtered out.
|
||||||
|
chat_session.async_add_message(
|
||||||
|
session.ChatMessage(
|
||||||
|
role="native", agent_id="another-mock-agent-id", content="", native=1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
chat_session.async_add_message(
|
||||||
|
session.ChatMessage(
|
||||||
|
role="native", agent_id="mock-agent-id", content="", native=1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chat_session.messages) == 5
|
||||||
|
|
||||||
|
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
|
||||||
|
assert len(messages) == 4
|
||||||
|
|
||||||
|
assert messages[2] == session.ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
agent_id="mock-agent-id",
|
||||||
|
content="Hey!",
|
||||||
|
native="assistant-reply-native",
|
||||||
|
)
|
||||||
|
assert messages[3] == session.ChatMessage(
|
||||||
|
role="native", agent_id="mock-agent-id", content="", native=1
|
||||||
|
)
|
@ -1,7 +1,7 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_unknown_hass_api
|
# name: test_unknown_hass_api
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'my-conversation-id',
|
||||||
'response': IntentResponse(
|
'response': IntentResponse(
|
||||||
card=dict({
|
card=dict({
|
||||||
}),
|
}),
|
||||||
|
@ -625,7 +625,11 @@ async def test_unknown_hass_api(
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
result = await conversation.async_converse(
|
result = await conversation.async_converse(
|
||||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
hass,
|
||||||
|
"hello",
|
||||||
|
"my-conversation-id",
|
||||||
|
Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result == snapshot
|
assert result == snapshot
|
||||||
|
@ -109,6 +109,8 @@ class HomeAssistantSnapshotSerializer(AmberDataSerializer):
|
|||||||
serializable_data = cls._serializable_issue_registry_entry(data)
|
serializable_data = cls._serializable_issue_registry_entry(data)
|
||||||
elif isinstance(data, dict) and "flow_id" in data and "handler" in data:
|
elif isinstance(data, dict) and "flow_id" in data and "handler" in data:
|
||||||
serializable_data = cls._serializable_flow_result(data)
|
serializable_data = cls._serializable_flow_result(data)
|
||||||
|
elif isinstance(data, dict) and set(data) == {"conversation_id", "response"}:
|
||||||
|
serializable_data = cls._serializable_conversation_result(data)
|
||||||
elif isinstance(data, vol.Schema):
|
elif isinstance(data, vol.Schema):
|
||||||
serializable_data = voluptuous_serialize.convert(data)
|
serializable_data = voluptuous_serialize.convert(data)
|
||||||
elif isinstance(data, ConfigEntry):
|
elif isinstance(data, ConfigEntry):
|
||||||
@ -200,6 +202,11 @@ class HomeAssistantSnapshotSerializer(AmberDataSerializer):
|
|||||||
"""Prepare a Home Assistant flow result for serialization."""
|
"""Prepare a Home Assistant flow result for serialization."""
|
||||||
return FlowResultSnapshot(data | {"flow_id": ANY})
|
return FlowResultSnapshot(data | {"flow_id": ANY})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _serializable_conversation_result(cls, data: dict) -> SerializableData:
|
||||||
|
"""Prepare a Home Assistant conversation result for serialization."""
|
||||||
|
return data | {"conversation_id": ANY}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _serializable_issue_registry_entry(
|
def _serializable_issue_registry_entry(
|
||||||
cls, data: ir.IssueEntry
|
cls, data: ir.IssueEntry
|
||||||
|
Loading…
x
Reference in New Issue
Block a user