mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 14:17:45 +00:00
ChatSession: Split native content out of message class (#136668)
Split native content out of message class
This commit is contained in:
parent
48a91540e1
commit
5690516852
@ -1101,11 +1101,10 @@ class PipelineRun:
|
|||||||
"speech", ""
|
"speech", ""
|
||||||
)
|
)
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
conversation.ChatMessage(
|
conversation.Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
content=speech,
|
content=speech,
|
||||||
native=intent_response,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conversation_result = conversation.ConversationResult(
|
conversation_result = conversation.ConversationResult(
|
||||||
|
@ -48,21 +48,28 @@ from .default_agent import DefaultAgent, async_setup_default_agent
|
|||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .http import async_setup as async_setup_conversation_http
|
from .http import async_setup as async_setup_conversation_http
|
||||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||||
from .session import ChatMessage, ChatSession, ConverseError, async_get_chat_session
|
from .session import (
|
||||||
|
ChatSession,
|
||||||
|
Content,
|
||||||
|
ConverseError,
|
||||||
|
NativeContent,
|
||||||
|
async_get_chat_session,
|
||||||
|
)
|
||||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
"HOME_ASSISTANT_AGENT",
|
"HOME_ASSISTANT_AGENT",
|
||||||
"OLD_HOME_ASSISTANT_AGENT",
|
"OLD_HOME_ASSISTANT_AGENT",
|
||||||
"ChatMessage",
|
|
||||||
"ChatSession",
|
"ChatSession",
|
||||||
|
"Content",
|
||||||
"ConversationEntity",
|
"ConversationEntity",
|
||||||
"ConversationEntityFeature",
|
"ConversationEntityFeature",
|
||||||
"ConversationInput",
|
"ConversationInput",
|
||||||
"ConversationResult",
|
"ConversationResult",
|
||||||
"ConversationTraceEventType",
|
"ConversationTraceEventType",
|
||||||
"ConverseError",
|
"ConverseError",
|
||||||
|
"NativeContent",
|
||||||
"async_conversation_trace_append",
|
"async_conversation_trace_append",
|
||||||
"async_converse",
|
"async_converse",
|
||||||
"async_get_agent_info",
|
"async_get_agent_info",
|
||||||
|
@ -62,7 +62,7 @@ from .const import (
|
|||||||
)
|
)
|
||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .models import ConversationInput, ConversationResult
|
from .models import ConversationInput, ConversationResult
|
||||||
from .session import ChatMessage, async_get_chat_session
|
from .session import Content, async_get_chat_session
|
||||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@ -374,11 +374,10 @@ class DefaultAgent(ConversationEntity):
|
|||||||
|
|
||||||
speech: str = response.speech.get("plain", {}).get("speech", "")
|
speech: str = response.speech.get("plain", {}).get("speech", "")
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
ChatMessage(
|
Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id=user_input.agent_id,
|
agent_id=user_input.agent_id,
|
||||||
content=speech,
|
content=speech,
|
||||||
native=response,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ async def async_get_chat_session(
|
|||||||
else:
|
else:
|
||||||
history = ChatSession(hass, conversation_id, user_input.agent_id)
|
history = ChatSession(hass, conversation_id, user_input.agent_id)
|
||||||
|
|
||||||
message: ChatMessage = ChatMessage(
|
message: Content = Content(
|
||||||
role="user",
|
role="user",
|
||||||
agent_id=user_input.agent_id,
|
agent_id=user_input.agent_id,
|
||||||
content=user_input.text,
|
content=user_input.text,
|
||||||
@ -169,23 +169,21 @@ class ConverseError(HomeAssistantError):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatMessage[_NativeT]:
|
class Content:
|
||||||
"""Base class for chat messages.
|
"""Base class for chat messages."""
|
||||||
|
|
||||||
When role is native, the content is to be ignored and message
|
role: Literal["system", "assistant", "user"]
|
||||||
is only meant for storing the native object.
|
|
||||||
"""
|
|
||||||
|
|
||||||
role: Literal["system", "assistant", "user", "native"]
|
|
||||||
agent_id: str | None
|
agent_id: str | None
|
||||||
content: str
|
content: str
|
||||||
native: _NativeT | None = field(default=None)
|
|
||||||
|
|
||||||
# Validate in post-init that if role is native, there is no content and a native object exists
|
|
||||||
def __post_init__(self) -> None:
|
@dataclass(frozen=True)
|
||||||
"""Validate native message."""
|
class NativeContent[_NativeT]:
|
||||||
if self.role == "native" and self.native is None:
|
"""Native content."""
|
||||||
raise ValueError("Native message must have a native object")
|
|
||||||
|
role: str = field(init=False, default="native")
|
||||||
|
agent_id: str
|
||||||
|
content: _NativeT
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -196,15 +194,15 @@ class ChatSession[_NativeT]:
|
|||||||
conversation_id: str
|
conversation_id: str
|
||||||
agent_id: str | None
|
agent_id: str | None
|
||||||
user_name: str | None = None
|
user_name: str | None = None
|
||||||
messages: list[ChatMessage[_NativeT]] = field(
|
messages: list[Content | NativeContent[_NativeT]] = field(
|
||||||
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")]
|
default_factory=lambda: [Content(role="system", agent_id=None, content="")]
|
||||||
)
|
)
|
||||||
extra_system_prompt: str | None = None
|
extra_system_prompt: str | None = None
|
||||||
llm_api: llm.APIInstance | None = None
|
llm_api: llm.APIInstance | None = None
|
||||||
last_updated: datetime = field(default_factory=dt_util.utcnow)
|
last_updated: datetime = field(default_factory=dt_util.utcnow)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_add_message(self, message: ChatMessage[_NativeT]) -> None:
|
def async_add_message(self, message: Content | NativeContent[_NativeT]) -> None:
|
||||||
"""Process intent."""
|
"""Process intent."""
|
||||||
if message.role == "system":
|
if message.role == "system":
|
||||||
raise ValueError("Cannot add system messages to history")
|
raise ValueError("Cannot add system messages to history")
|
||||||
@ -216,7 +214,7 @@ class ChatSession[_NativeT]:
|
|||||||
@callback
|
@callback
|
||||||
def async_get_messages(
|
def async_get_messages(
|
||||||
self, agent_id: str | None = None
|
self, agent_id: str | None = None
|
||||||
) -> list[ChatMessage[_NativeT]]:
|
) -> list[Content | NativeContent[_NativeT]]:
|
||||||
"""Get messages for a specific agent ID.
|
"""Get messages for a specific agent ID.
|
||||||
|
|
||||||
This will filter out any native message tied to other agent IDs.
|
This will filter out any native message tied to other agent IDs.
|
||||||
@ -328,7 +326,7 @@ class ChatSession[_NativeT]:
|
|||||||
self.llm_api = llm_api
|
self.llm_api = llm_api
|
||||||
self.user_name = user_name
|
self.user_name = user_name
|
||||||
self.extra_system_prompt = extra_system_prompt
|
self.extra_system_prompt = extra_system_prompt
|
||||||
self.messages[0] = ChatMessage(
|
self.messages[0] = Content(
|
||||||
role="system",
|
role="system",
|
||||||
agent_id=user_input.agent_id,
|
agent_id=user_input.agent_id,
|
||||||
content=prompt,
|
content=prompt,
|
||||||
|
@ -93,12 +93,13 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
|
|||||||
|
|
||||||
|
|
||||||
def _chat_message_convert(
|
def _chat_message_convert(
|
||||||
message: conversation.ChatMessage[ChatCompletionMessageParam],
|
message: conversation.Content
|
||||||
agent_id: str | None,
|
| conversation.NativeContent[ChatCompletionMessageParam],
|
||||||
) -> ChatCompletionMessageParam:
|
) -> ChatCompletionMessageParam:
|
||||||
"""Convert any native chat message for this agent to the native format."""
|
"""Convert any native chat message for this agent to the native format."""
|
||||||
if message.native is not None and message.agent_id == agent_id:
|
if message.role == "native":
|
||||||
return message.native
|
# mypy doesn't understand that checking role ensures content type
|
||||||
|
return message.content # type: ignore[return-value]
|
||||||
return cast(
|
return cast(
|
||||||
ChatCompletionMessageParam,
|
ChatCompletionMessageParam,
|
||||||
{"role": message.role, "content": message.content},
|
{"role": message.role, "content": message.content},
|
||||||
@ -157,14 +158,15 @@ class OpenAIConversationEntity(
|
|||||||
async with conversation.async_get_chat_session(
|
async with conversation.async_get_chat_session(
|
||||||
self.hass, user_input
|
self.hass, user_input
|
||||||
) as session:
|
) as session:
|
||||||
return await self._async_call_api(user_input, session)
|
return await self._async_handle_message(user_input, session)
|
||||||
|
|
||||||
async def _async_call_api(
|
async def _async_handle_message(
|
||||||
self,
|
self,
|
||||||
user_input: conversation.ConversationInput,
|
user_input: conversation.ConversationInput,
|
||||||
session: conversation.ChatSession[ChatCompletionMessageParam],
|
session: conversation.ChatSession[ChatCompletionMessageParam],
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Call the API."""
|
"""Call the API."""
|
||||||
|
assert user_input.agent_id
|
||||||
options = self.entry.options
|
options = self.entry.options
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -185,8 +187,7 @@ class OpenAIConversationEntity(
|
|||||||
]
|
]
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
_chat_message_convert(message, user_input.agent_id)
|
_chat_message_convert(message) for message in session.async_get_messages()
|
||||||
for message in session.async_get_messages()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
client = self.entry.runtime_data
|
client = self.entry.runtime_data
|
||||||
@ -212,11 +213,10 @@ class OpenAIConversationEntity(
|
|||||||
messages.append(_message_convert(response))
|
messages.append(_message_convert(response))
|
||||||
|
|
||||||
session.async_add_message(
|
session.async_add_message(
|
||||||
conversation.ChatMessage(
|
conversation.Content(
|
||||||
role=response.role,
|
role=response.role,
|
||||||
agent_id=user_input.agent_id,
|
agent_id=user_input.agent_id,
|
||||||
content=response.content or "",
|
content=response.content or "",
|
||||||
native=messages[-1],
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -237,11 +237,9 @@ class OpenAIConversationEntity(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
session.async_add_message(
|
session.async_add_message(
|
||||||
conversation.ChatMessage(
|
conversation.NativeContent(
|
||||||
role="native",
|
|
||||||
agent_id=user_input.agent_id,
|
agent_id=user_input.agent_id,
|
||||||
content="",
|
content=messages[-1],
|
||||||
native=messages[-1],
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ async def test_cleanup(
|
|||||||
assert chat_session.conversation_id != conversation_id
|
assert chat_session.conversation_id != conversation_id
|
||||||
conversation_id = chat_session.conversation_id
|
conversation_id = chat_session.conversation_id
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(
|
session.Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
@ -127,12 +127,6 @@ async def test_cleanup(
|
|||||||
assert len(chat_session.messages) == 2
|
assert len(chat_session.messages) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_chat_message() -> None:
|
|
||||||
"""Test chat message."""
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
session.ChatMessage(role="native", agent_id=None, content="", native=None)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_add_message(
|
async def test_add_message(
|
||||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -144,7 +138,7 @@ async def test_add_message(
|
|||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(role="system", agent_id=None, content="")
|
session.Content(role="system", agent_id=None, content="")
|
||||||
)
|
)
|
||||||
|
|
||||||
# No 2 user messages in a row
|
# No 2 user messages in a row
|
||||||
@ -152,19 +146,19 @@ async def test_add_message(
|
|||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(role="user", agent_id=None, content="")
|
session.Content(role="user", agent_id=None, content="")
|
||||||
)
|
)
|
||||||
|
|
||||||
# No 2 assistant messages in a row
|
# No 2 assistant messages in a row
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(role="assistant", agent_id=None, content="")
|
session.Content(role="assistant", agent_id=None, content="")
|
||||||
)
|
)
|
||||||
assert len(chat_session.messages) == 3
|
assert len(chat_session.messages) == 3
|
||||||
assert chat_session.messages[-1].role == "assistant"
|
assert chat_session.messages[-1].role == "assistant"
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(role="assistant", agent_id=None, content="")
|
session.Content(role="assistant", agent_id=None, content="")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -177,12 +171,12 @@ async def test_message_filtering(
|
|||||||
) as chat_session:
|
) as chat_session:
|
||||||
messages = chat_session.async_get_messages(agent_id=None)
|
messages = chat_session.async_get_messages(agent_id=None)
|
||||||
assert len(messages) == 2
|
assert len(messages) == 2
|
||||||
assert messages[0] == session.ChatMessage(
|
assert messages[0] == session.Content(
|
||||||
role="system",
|
role="system",
|
||||||
agent_id=None,
|
agent_id=None,
|
||||||
content="",
|
content="",
|
||||||
)
|
)
|
||||||
assert messages[1] == session.ChatMessage(
|
assert messages[1] == session.Content(
|
||||||
role="user",
|
role="user",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content=mock_conversation_input.text,
|
content=mock_conversation_input.text,
|
||||||
@ -190,7 +184,7 @@ async def test_message_filtering(
|
|||||||
# Cannot add a second user message in a row
|
# Cannot add a second user message in a row
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(
|
session.Content(
|
||||||
role="user",
|
role="user",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
@ -198,31 +192,25 @@ async def test_message_filtering(
|
|||||||
)
|
)
|
||||||
|
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(
|
session.Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
native="assistant-reply-native",
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Different agent, native messages will be filtered out.
|
# Different agent, native messages will be filtered out.
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(
|
session.NativeContent(agent_id="another-mock-agent-id", content=1)
|
||||||
role="native", agent_id="another-mock-agent-id", content="", native=1
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(
|
session.NativeContent(agent_id="mock-agent-id", content=1)
|
||||||
role="native", agent_id="mock-agent-id", content="", native=1
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# A non-native message from another agent is not filtered out.
|
# A non-native message from another agent is not filtered out.
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(
|
session.Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="another-mock-agent-id",
|
agent_id="another-mock-agent-id",
|
||||||
content="Hi!",
|
content="Hi!",
|
||||||
native=1,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -231,17 +219,14 @@ async def test_message_filtering(
|
|||||||
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
|
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
|
||||||
assert len(messages) == 5
|
assert len(messages) == 5
|
||||||
|
|
||||||
assert messages[2] == session.ChatMessage(
|
assert messages[2] == session.Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
native="assistant-reply-native",
|
|
||||||
)
|
)
|
||||||
assert messages[3] == session.ChatMessage(
|
assert messages[3] == session.NativeContent(agent_id="mock-agent-id", content=1)
|
||||||
role="native", agent_id="mock-agent-id", content="", native=1
|
assert messages[4] == session.Content(
|
||||||
)
|
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
|
||||||
assert messages[4] == session.ChatMessage(
|
|
||||||
role="assistant", agent_id="another-mock-agent-id", content="Hi!", native=1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -361,7 +346,7 @@ async def test_extra_systen_prompt(
|
|||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(
|
session.Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
@ -401,7 +386,7 @@ async def test_extra_systen_prompt(
|
|||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
chat_session.async_add_message(
|
chat_session.async_add_message(
|
||||||
session.ChatMessage(
|
session.Content(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
agent_id="mock-agent-id",
|
agent_id="mock-agent-id",
|
||||||
content="Hey!",
|
content="Hey!",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user