ChatSession: Split native content out of message class (#136668)

Split native content out of message class
This commit is contained in:
Paulus Schoutsen 2025-01-28 00:12:42 -05:00 committed by GitHub
parent 48a91540e1
commit 5690516852
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 59 additions and 73 deletions

View File

@ -1101,11 +1101,10 @@ class PipelineRun:
"speech", "" "speech", ""
) )
chat_session.async_add_message( chat_session.async_add_message(
conversation.ChatMessage( conversation.Content(
role="assistant", role="assistant",
agent_id=agent_id, agent_id=agent_id,
content=speech, content=speech,
native=intent_response,
) )
) )
conversation_result = conversation.ConversationResult( conversation_result = conversation.ConversationResult(

View File

@ -48,21 +48,28 @@ from .default_agent import DefaultAgent, async_setup_default_agent
from .entity import ConversationEntity from .entity import ConversationEntity
from .http import async_setup as async_setup_conversation_http from .http import async_setup as async_setup_conversation_http
from .models import AbstractConversationAgent, ConversationInput, ConversationResult from .models import AbstractConversationAgent, ConversationInput, ConversationResult
from .session import ChatMessage, ChatSession, ConverseError, async_get_chat_session from .session import (
ChatSession,
Content,
ConverseError,
NativeContent,
async_get_chat_session,
)
from .trace import ConversationTraceEventType, async_conversation_trace_append from .trace import ConversationTraceEventType, async_conversation_trace_append
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
"HOME_ASSISTANT_AGENT", "HOME_ASSISTANT_AGENT",
"OLD_HOME_ASSISTANT_AGENT", "OLD_HOME_ASSISTANT_AGENT",
"ChatMessage",
"ChatSession", "ChatSession",
"Content",
"ConversationEntity", "ConversationEntity",
"ConversationEntityFeature", "ConversationEntityFeature",
"ConversationInput", "ConversationInput",
"ConversationResult", "ConversationResult",
"ConversationTraceEventType", "ConversationTraceEventType",
"ConverseError", "ConverseError",
"NativeContent",
"async_conversation_trace_append", "async_conversation_trace_append",
"async_converse", "async_converse",
"async_get_agent_info", "async_get_agent_info",

View File

@ -62,7 +62,7 @@ from .const import (
) )
from .entity import ConversationEntity from .entity import ConversationEntity
from .models import ConversationInput, ConversationResult from .models import ConversationInput, ConversationResult
from .session import ChatMessage, async_get_chat_session from .session import Content, async_get_chat_session
from .trace import ConversationTraceEventType, async_conversation_trace_append from .trace import ConversationTraceEventType, async_conversation_trace_append
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -374,11 +374,10 @@ class DefaultAgent(ConversationEntity):
speech: str = response.speech.get("plain", {}).get("speech", "") speech: str = response.speech.get("plain", {}).get("speech", "")
chat_session.async_add_message( chat_session.async_add_message(
ChatMessage( Content(
role="assistant", role="assistant",
agent_id=user_input.agent_id, agent_id=user_input.agent_id,
content=speech, content=speech,
native=response,
) )
) )

View File

@ -126,7 +126,7 @@ async def async_get_chat_session(
else: else:
history = ChatSession(hass, conversation_id, user_input.agent_id) history = ChatSession(hass, conversation_id, user_input.agent_id)
message: ChatMessage = ChatMessage( message: Content = Content(
role="user", role="user",
agent_id=user_input.agent_id, agent_id=user_input.agent_id,
content=user_input.text, content=user_input.text,
@ -169,23 +169,21 @@ class ConverseError(HomeAssistantError):
@dataclass @dataclass
class ChatMessage[_NativeT]: class Content:
"""Base class for chat messages. """Base class for chat messages."""
When role is native, the content is to be ignored and message role: Literal["system", "assistant", "user"]
is only meant for storing the native object.
"""
role: Literal["system", "assistant", "user", "native"]
agent_id: str | None agent_id: str | None
content: str content: str
native: _NativeT | None = field(default=None)
# Validate in post-init that if role is native, there is no content and a native object exists
def __post_init__(self) -> None: @dataclass(frozen=True)
"""Validate native message.""" class NativeContent[_NativeT]:
if self.role == "native" and self.native is None: """Native content."""
raise ValueError("Native message must have a native object")
role: str = field(init=False, default="native")
agent_id: str
content: _NativeT
@dataclass @dataclass
@ -196,15 +194,15 @@ class ChatSession[_NativeT]:
conversation_id: str conversation_id: str
agent_id: str | None agent_id: str | None
user_name: str | None = None user_name: str | None = None
messages: list[ChatMessage[_NativeT]] = field( messages: list[Content | NativeContent[_NativeT]] = field(
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")] default_factory=lambda: [Content(role="system", agent_id=None, content="")]
) )
extra_system_prompt: str | None = None extra_system_prompt: str | None = None
llm_api: llm.APIInstance | None = None llm_api: llm.APIInstance | None = None
last_updated: datetime = field(default_factory=dt_util.utcnow) last_updated: datetime = field(default_factory=dt_util.utcnow)
@callback @callback
def async_add_message(self, message: ChatMessage[_NativeT]) -> None: def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None:
"""Process intent.""" """Process intent."""
if message.role == "system": if message.role == "system":
raise ValueError("Cannot add system messages to history") raise ValueError("Cannot add system messages to history")
@ -216,7 +214,7 @@ class ChatSession[_NativeT]:
@callback @callback
def async_get_messages( def async_get_messages(
self, agent_id: str | None = None self, agent_id: str | None = None
) -> list[ChatMessage[_NativeT]]: ) -> list[Content | NativeContent[_NativeT]]:
"""Get messages for a specific agent ID. """Get messages for a specific agent ID.
This will filter out any native message tied to other agent IDs. This will filter out any native message tied to other agent IDs.
@ -328,7 +326,7 @@ class ChatSession[_NativeT]:
self.llm_api = llm_api self.llm_api = llm_api
self.user_name = user_name self.user_name = user_name
self.extra_system_prompt = extra_system_prompt self.extra_system_prompt = extra_system_prompt
self.messages[0] = ChatMessage( self.messages[0] = Content(
role="system", role="system",
agent_id=user_input.agent_id, agent_id=user_input.agent_id,
content=prompt, content=prompt,

View File

@ -93,12 +93,13 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
def _chat_message_convert( def _chat_message_convert(
message: conversation.ChatMessage[ChatCompletionMessageParam], message: conversation.Content
agent_id: str | None, | conversation.NativeContent[ChatCompletionMessageParam],
) -> ChatCompletionMessageParam: ) -> ChatCompletionMessageParam:
"""Convert any native chat message for this agent to the native format.""" """Convert any native chat message for this agent to the native format."""
if message.native is not None and message.agent_id == agent_id: if message.role == "native":
return message.native # mypy doesn't understand that checking role ensures content type
return message.content # type: ignore[return-value]
return cast( return cast(
ChatCompletionMessageParam, ChatCompletionMessageParam,
{"role": message.role, "content": message.content}, {"role": message.role, "content": message.content},
@ -157,14 +158,15 @@ class OpenAIConversationEntity(
async with conversation.async_get_chat_session( async with conversation.async_get_chat_session(
self.hass, user_input self.hass, user_input
) as session: ) as session:
return await self._async_call_api(user_input, session) return await self._async_handle_message(user_input, session)
async def _async_call_api( async def _async_handle_message(
self, self,
user_input: conversation.ConversationInput, user_input: conversation.ConversationInput,
session: conversation.ChatSession[ChatCompletionMessageParam], session: conversation.ChatSession[ChatCompletionMessageParam],
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Call the API.""" """Call the API."""
assert user_input.agent_id
options = self.entry.options options = self.entry.options
try: try:
@ -185,8 +187,7 @@ class OpenAIConversationEntity(
] ]
messages = [ messages = [
_chat_message_convert(message, user_input.agent_id) _chat_message_convert(message) for message in session.async_get_messages()
for message in session.async_get_messages()
] ]
client = self.entry.runtime_data client = self.entry.runtime_data
@ -212,11 +213,10 @@ class OpenAIConversationEntity(
messages.append(_message_convert(response)) messages.append(_message_convert(response))
session.async_add_message( session.async_add_message(
conversation.ChatMessage( conversation.Content(
role=response.role, role=response.role,
agent_id=user_input.agent_id, agent_id=user_input.agent_id,
content=response.content or "", content=response.content or "",
native=messages[-1],
), ),
) )
@ -237,11 +237,9 @@ class OpenAIConversationEntity(
) )
) )
session.async_add_message( session.async_add_message(
conversation.ChatMessage( conversation.NativeContent(
role="native",
agent_id=user_input.agent_id, agent_id=user_input.agent_id,
content="", content=messages[-1],
native=messages[-1],
) )
) )

View File

@ -82,7 +82,7 @@ async def test_cleanup(
assert chat_session.conversation_id != conversation_id assert chat_session.conversation_id != conversation_id
conversation_id = chat_session.conversation_id conversation_id = chat_session.conversation_id
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage( session.Content(
role="assistant", role="assistant",
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
@ -127,12 +127,6 @@ async def test_cleanup(
assert len(chat_session.messages) == 2 assert len(chat_session.messages) == 2
def test_chat_message() -> None:
"""Test chat message."""
with pytest.raises(ValueError):
session.ChatMessage(role="native", agent_id=None, content="", native=None)
async def test_add_message( async def test_add_message(
hass: HomeAssistant, mock_conversation_input: ConversationInput hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None: ) -> None:
@ -144,7 +138,7 @@ async def test_add_message(
with pytest.raises(ValueError): with pytest.raises(ValueError):
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage(role="system", agent_id=None, content="") session.Content(role="system", agent_id=None, content="")
) )
# No 2 user messages in a row # No 2 user messages in a row
@ -152,19 +146,19 @@ async def test_add_message(
with pytest.raises(ValueError): with pytest.raises(ValueError):
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage(role="user", agent_id=None, content="") session.Content(role="user", agent_id=None, content="")
) )
# No 2 assistant messages in a row # No 2 assistant messages in a row
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage(role="assistant", agent_id=None, content="") session.Content(role="assistant", agent_id=None, content="")
) )
assert len(chat_session.messages) == 3 assert len(chat_session.messages) == 3
assert chat_session.messages[-1].role == "assistant" assert chat_session.messages[-1].role == "assistant"
with pytest.raises(ValueError): with pytest.raises(ValueError):
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage(role="assistant", agent_id=None, content="") session.Content(role="assistant", agent_id=None, content="")
) )
@ -177,12 +171,12 @@ async def test_message_filtering(
) as chat_session: ) as chat_session:
messages = chat_session.async_get_messages(agent_id=None) messages = chat_session.async_get_messages(agent_id=None)
assert len(messages) == 2 assert len(messages) == 2
assert messages[0] == session.ChatMessage( assert messages[0] == session.Content(
role="system", role="system",
agent_id=None, agent_id=None,
content="", content="",
) )
assert messages[1] == session.ChatMessage( assert messages[1] == session.Content(
role="user", role="user",
agent_id="mock-agent-id", agent_id="mock-agent-id",
content=mock_conversation_input.text, content=mock_conversation_input.text,
@ -190,7 +184,7 @@ async def test_message_filtering(
# Cannot add a second user message in a row # Cannot add a second user message in a row
with pytest.raises(ValueError): with pytest.raises(ValueError):
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage( session.Content(
role="user", role="user",
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
@ -198,31 +192,25 @@ async def test_message_filtering(
) )
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage( session.Content(
role="assistant", role="assistant",
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
native="assistant-reply-native",
) )
) )
# Different agent, native messages will be filtered out. # Different agent, native messages will be filtered out.
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage( session.NativeContent(agent_id="another-mock-agent-id", content=1)
role="native", agent_id="another-mock-agent-id", content="", native=1
)
) )
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage( session.NativeContent(agent_id="mock-agent-id", content=1)
role="native", agent_id="mock-agent-id", content="", native=1
)
) )
# A non-native message from another agent is not filtered out. # A non-native message from another agent is not filtered out.
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage( session.Content(
role="assistant", role="assistant",
agent_id="another-mock-agent-id", agent_id="another-mock-agent-id",
content="Hi!", content="Hi!",
native=1,
) )
) )
@ -231,17 +219,14 @@ async def test_message_filtering(
messages = chat_session.async_get_messages(agent_id="mock-agent-id") messages = chat_session.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 5 assert len(messages) == 5
assert messages[2] == session.ChatMessage( assert messages[2] == session.Content(
role="assistant", role="assistant",
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
native="assistant-reply-native",
) )
assert messages[3] == session.ChatMessage( assert messages[3] == session.NativeContent(agent_id="mock-agent-id", content=1)
role="native", agent_id="mock-agent-id", content="", native=1 assert messages[4] == session.Content(
) role="assistant", agent_id="another-mock-agent-id", content="Hi!"
assert messages[4] == session.ChatMessage(
role="assistant", agent_id="another-mock-agent-id", content="Hi!", native=1
) )
@ -361,7 +346,7 @@ async def test_extra_systen_prompt(
user_llm_prompt=None, user_llm_prompt=None,
) )
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage( session.Content(
role="assistant", role="assistant",
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",
@ -401,7 +386,7 @@ async def test_extra_systen_prompt(
user_llm_prompt=None, user_llm_prompt=None,
) )
chat_session.async_add_message( chat_session.async_add_message(
session.ChatMessage( session.Content(
role="assistant", role="assistant",
agent_id="mock-agent-id", agent_id="mock-agent-id",
content="Hey!", content="Hey!",