diff --git a/homeassistant/components/conversation/chat_log.py b/homeassistant/components/conversation/chat_log.py index 5dbd19ba275..a060a769907 100644 --- a/homeassistant/components/conversation/chat_log.py +++ b/homeassistant/components/conversation/chat_log.py @@ -38,21 +38,23 @@ def async_get_chat_log( user_input: ConversationInput | None = None, ) -> Generator[ChatLog]: """Return chat log for a specific chat session.""" - if chat_log := current_chat_log.get(): - # If a chat log is already active and it's the requested conversation ID, - # return that. We won't update the last updated time in this case. - if chat_log.conversation_id == session.conversation_id: - yield chat_log - return + # If a chat log is already active and it's the requested conversation ID, + # return that. We won't update the last updated time in this case. + if ( + chat_log := current_chat_log.get() + ) and chat_log.conversation_id == session.conversation_id: + if user_input is not None: + chat_log.async_add_user_content(UserContent(content=user_input.text)) + + yield chat_log + return all_chat_logs = hass.data.get(DATA_CHAT_LOGS) if all_chat_logs is None: all_chat_logs = {} hass.data[DATA_CHAT_LOGS] = all_chat_logs - chat_log = all_chat_logs.get(session.conversation_id) - - if chat_log: + if chat_log := all_chat_logs.get(session.conversation_id): chat_log = replace(chat_log, content=chat_log.content.copy()) else: chat_log = ChatLog(hass, session.conversation_id) @@ -395,12 +397,10 @@ class ChatLog: if llm_api: prompt_parts.append(llm_api.api_prompt) - extra_system_prompt = ( + if extra_system_prompt := ( # Take new system prompt if one was given user_input.extra_system_prompt or self.extra_system_prompt - ) - - if extra_system_prompt: + ): prompt_parts.append(extra_system_prompt) prompt = "\n".join(prompt_parts) diff --git a/tests/components/conversation/test_chat_log.py b/tests/components/conversation/test_chat_log.py index 090904c7063..0c11d19aab2 100644 --- a/tests/components/conversation/test_chat_log.py +++ b/tests/components/conversation/test_chat_log.py @@ -602,3 +602,26 @@ async def test_add_delta_content_stream_errors( stream([{"role": role}]), ): pass + + +async def test_chat_log_reuse( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, +) -> None: + """Test that we can reuse a chat log.""" + with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session) as chat_log, + ): + assert chat_log.conversation_id == session.conversation_id + assert len(chat_log.content) == 1 + + with async_get_chat_log(hass, session) as chat_log2: + assert chat_log2 is chat_log + assert len(chat_log.content) == 1 + + with async_get_chat_log(hass, session, mock_conversation_input) as chat_log2: + assert chat_log2 is chat_log + assert len(chat_log.content) == 2 + assert chat_log.content[1].role == "user" + assert chat_log.content[1].content == mock_conversation_input.text