Fix user input not added to chat log from contextvar (#138173)

This commit is contained in:
Paulus Schoutsen 2025-02-09 19:34:25 -05:00 committed by GitHub
parent c2bb376c43
commit cabb406270
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 13 deletions

View File

@ -38,21 +38,23 @@ def async_get_chat_log(
user_input: ConversationInput | None = None,
) -> Generator[ChatLog]:
"""Return chat log for a specific chat session."""
if chat_log := current_chat_log.get():
# If a chat log is already active and it's the requested conversation ID,
# return that. We won't update the last updated time in this case.
if chat_log.conversation_id == session.conversation_id:
yield chat_log
return
# If a chat log is already active and it's the requested conversation ID,
# return that. We won't update the last updated time in this case.
if (
chat_log := current_chat_log.get()
) and chat_log.conversation_id == session.conversation_id:
if user_input is not None:
chat_log.async_add_user_content(UserContent(content=user_input.text))
yield chat_log
return
all_chat_logs = hass.data.get(DATA_CHAT_LOGS)
if all_chat_logs is None:
all_chat_logs = {}
hass.data[DATA_CHAT_LOGS] = all_chat_logs
chat_log = all_chat_logs.get(session.conversation_id)
if chat_log:
if chat_log := all_chat_logs.get(session.conversation_id):
chat_log = replace(chat_log, content=chat_log.content.copy())
else:
chat_log = ChatLog(hass, session.conversation_id)
@ -395,12 +397,10 @@ class ChatLog:
if llm_api:
prompt_parts.append(llm_api.api_prompt)
extra_system_prompt = (
if extra_system_prompt := (
# Take new system prompt if one was given
user_input.extra_system_prompt or self.extra_system_prompt
)
if extra_system_prompt:
):
prompt_parts.append(extra_system_prompt)
prompt = "\n".join(prompt_parts)

View File

@ -602,3 +602,26 @@ async def test_add_delta_content_stream_errors(
stream([{"role": role}]),
):
pass
async def test_chat_log_reuse(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,
) -> None:
"""Test that we can reuse a chat log."""
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session) as chat_log,
):
assert chat_log.conversation_id == session.conversation_id
assert len(chat_log.content) == 1
with async_get_chat_log(hass, session) as chat_log2:
assert chat_log2 is chat_log
assert len(chat_log.content) == 1
with async_get_chat_log(hass, session, mock_conversation_input) as chat_log2:
assert chat_log2 is chat_log
assert len(chat_log.content) == 2
assert chat_log.content[1].role == "user"
assert chat_log.content[1].content == mock_conversation_input.text