mirror of
https://github.com/home-assistant/core.git
synced 2025-07-18 18:57:06 +00:00
Fix user input not added to chat log from contextvar (#138173)
This commit is contained in:
parent
c2bb376c43
commit
cabb406270
@ -38,21 +38,23 @@ def async_get_chat_log(
|
||||
user_input: ConversationInput | None = None,
|
||||
) -> Generator[ChatLog]:
|
||||
"""Return chat log for a specific chat session."""
|
||||
if chat_log := current_chat_log.get():
|
||||
# If a chat log is already active and it's the requested conversation ID,
|
||||
# return that. We won't update the last updated time in this case.
|
||||
if chat_log.conversation_id == session.conversation_id:
|
||||
yield chat_log
|
||||
return
|
||||
# If a chat log is already active and it's the requested conversation ID,
|
||||
# return that. We won't update the last updated time in this case.
|
||||
if (
|
||||
chat_log := current_chat_log.get()
|
||||
) and chat_log.conversation_id == session.conversation_id:
|
||||
if user_input is not None:
|
||||
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
||||
|
||||
yield chat_log
|
||||
return
|
||||
|
||||
all_chat_logs = hass.data.get(DATA_CHAT_LOGS)
|
||||
if all_chat_logs is None:
|
||||
all_chat_logs = {}
|
||||
hass.data[DATA_CHAT_LOGS] = all_chat_logs
|
||||
|
||||
chat_log = all_chat_logs.get(session.conversation_id)
|
||||
|
||||
if chat_log:
|
||||
if chat_log := all_chat_logs.get(session.conversation_id):
|
||||
chat_log = replace(chat_log, content=chat_log.content.copy())
|
||||
else:
|
||||
chat_log = ChatLog(hass, session.conversation_id)
|
||||
@ -395,12 +397,10 @@ class ChatLog:
|
||||
if llm_api:
|
||||
prompt_parts.append(llm_api.api_prompt)
|
||||
|
||||
extra_system_prompt = (
|
||||
if extra_system_prompt := (
|
||||
# Take new system prompt if one was given
|
||||
user_input.extra_system_prompt or self.extra_system_prompt
|
||||
)
|
||||
|
||||
if extra_system_prompt:
|
||||
):
|
||||
prompt_parts.append(extra_system_prompt)
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
@ -602,3 +602,26 @@ async def test_add_delta_content_stream_errors(
|
||||
stream([{"role": role}]),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
async def test_chat_log_reuse(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
) -> None:
|
||||
"""Test that we can reuse a chat log."""
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session) as chat_log,
|
||||
):
|
||||
assert chat_log.conversation_id == session.conversation_id
|
||||
assert len(chat_log.content) == 1
|
||||
|
||||
with async_get_chat_log(hass, session) as chat_log2:
|
||||
assert chat_log2 is chat_log
|
||||
assert len(chat_log.content) == 1
|
||||
|
||||
with async_get_chat_log(hass, session, mock_conversation_input) as chat_log2:
|
||||
assert chat_log2 is chat_log
|
||||
assert len(chat_log.content) == 2
|
||||
assert chat_log.content[1].role == "user"
|
||||
assert chat_log.content[1].content == mock_conversation_input.text
|
||||
|
Loading…
x
Reference in New Issue
Block a user