Make get_chat_session a callback context manager (#137146)

This commit is contained in:
Paulus Schoutsen 2025-02-01 23:37:24 -05:00 committed by GitHub
parent 2ce585463c
commit dd9bd8ef73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 69 additions and 72 deletions

View File

@ -1094,7 +1094,7 @@ class PipelineRun:
# It was already handled, create response and add to chat history # It was already handled, create response and add to chat history
if intent_response is not None: if intent_response is not None:
async with ( with (
chat_session.async_get_chat_session( chat_session.async_get_chat_session(
self.hass, user_input.conversation_id self.hass, user_input.conversation_id
) as session, ) as session,

View File

@ -349,7 +349,7 @@ class DefaultAgent(ConversationEntity):
async def async_process(self, user_input: ConversationInput) -> ConversationResult: async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence.""" """Process a sentence."""
response: intent.IntentResponse | None = None response: intent.IntentResponse | None = None
async with ( with (
chat_session.async_get_chat_session( chat_session.async_get_chat_session(
self.hass, user_input.conversation_id self.hass, user_input.conversation_id
) as session, ) as session,

View File

@ -2,8 +2,8 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator from collections.abc import Generator
from contextlib import asynccontextmanager from contextlib import contextmanager
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from datetime import datetime from datetime import datetime
import logging import logging
@ -27,12 +27,12 @@ DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log"
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
@asynccontextmanager @contextmanager
async def async_get_chat_log( def async_get_chat_log(
hass: HomeAssistant, hass: HomeAssistant,
session: chat_session.ChatSession, session: chat_session.ChatSession,
user_input: ConversationInput, user_input: ConversationInput,
) -> AsyncGenerator[ChatLog]: ) -> Generator[ChatLog]:
"""Return chat log for a specific chat session.""" """Return chat log for a specific chat session."""
all_history = hass.data.get(DATA_CHAT_HISTORY) all_history = hass.data.get(DATA_CHAT_HISTORY)
if all_history is None: if all_history is None:

View File

@ -209,7 +209,7 @@ class GoogleGenerativeAIConversationEntity(
self, user_input: conversation.ConversationInput self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
async with ( with (
chat_session.async_get_chat_session( chat_session.async_get_chat_session(
self.hass, user_input.conversation_id self.hass, user_input.conversation_id
) as session, ) as session,

View File

@ -155,7 +155,7 @@ class OpenAIConversationEntity(
self, user_input: conversation.ConversationInput self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
async with ( with (
chat_session.async_get_chat_session( chat_session.async_get_chat_session(
self.hass, user_input.conversation_id self.hass, user_input.conversation_id
) as session, ) as session,

View File

@ -2,8 +2,8 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator from collections.abc import Generator
from contextlib import asynccontextmanager from contextlib import contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -17,8 +17,9 @@ from homeassistant.core import (
HomeAssistant, HomeAssistant,
callback, callback,
) )
from homeassistant.util import dt as dt_util, ulid as ulid_util from homeassistant.util import dt as dt_util
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
from homeassistant.util.ulid import ulid_now, ulid_to_bytes
from .event import async_call_later from .event import async_call_later
@ -107,11 +108,11 @@ class SessionCleanup:
self.schedule() self.schedule()
@asynccontextmanager @contextmanager
async def async_get_chat_session( def async_get_chat_session(
hass: HomeAssistant, hass: HomeAssistant,
conversation_id: str | None = None, conversation_id: str | None = None,
) -> AsyncGenerator[ChatSession]: ) -> Generator[ChatSession]:
"""Return a chat session.""" """Return a chat session."""
if session := current_session.get(): if session := current_session.get():
# If a session is already active and it's the requested conversation ID, # If a session is already active and it's the requested conversation ID,
@ -132,7 +133,7 @@ async def async_get_chat_session(
hass.data[DATA_CHAT_SESSION_CLEANUP] = SessionCleanup(hass) hass.data[DATA_CHAT_SESSION_CLEANUP] = SessionCleanup(hass)
if conversation_id is None: if conversation_id is None:
conversation_id = ulid_util.ulid_now() conversation_id = ulid_now()
elif conversation_id in all_sessions: elif conversation_id in all_sessions:
session = all_sessions[conversation_id] session = all_sessions[conversation_id]
@ -143,8 +144,8 @@ async def async_get_chat_session(
# a new conversation was started. If the user picks their own, they # a new conversation was started. If the user picks their own, they
# want to track a conversation and we respect it. # want to track a conversation and we respect it.
try: try:
ulid_util.ulid_to_bytes(conversation_id) ulid_to_bytes(conversation_id)
conversation_id = ulid_util.ulid_now() conversation_id = ulid_now()
except ValueError: except ValueError:
pass pass

View File

@ -50,7 +50,7 @@ async def test_cleanup(
mock_conversation_input: ConversationInput, mock_conversation_input: ConversationInput,
) -> None: ) -> None:
"""Test cleanup of the chat log.""" """Test cleanup of the chat log."""
async with ( with (
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
@ -83,7 +83,7 @@ async def test_add_message(
hass: HomeAssistant, mock_conversation_input: ConversationInput hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None: ) -> None:
"""Test filtering of messages.""" """Test filtering of messages."""
async with ( with (
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
@ -115,7 +115,7 @@ async def test_message_filtering(
hass: HomeAssistant, mock_conversation_input: ConversationInput hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None: ) -> None:
"""Test filtering of messages.""" """Test filtering of messages."""
async with ( with (
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
@ -183,7 +183,7 @@ async def test_llm_api(
mock_conversation_input: ConversationInput, mock_conversation_input: ConversationInput,
) -> None: ) -> None:
"""Test when we reference an LLM API.""" """Test when we reference an LLM API."""
async with ( with (
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
@ -204,17 +204,17 @@ async def test_unknown_llm_api(
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test when we reference an LLM API that does not exists.""" """Test when we reference an LLM API that does not exists."""
async with ( with (
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
pytest.raises(ConverseError) as exc_info,
): ):
with pytest.raises(ConverseError) as exc_info: await chat_log.async_update_llm_data(
await chat_log.async_update_llm_data( conversing_domain="test",
conversing_domain="test", user_input=mock_conversation_input,
user_input=mock_conversation_input, user_llm_hass_api="unknown-api",
user_llm_hass_api="unknown-api", user_llm_prompt=None,
user_llm_prompt=None, )
)
assert str(exc_info.value) == "Error getting LLM API unknown-api" assert str(exc_info.value) == "Error getting LLM API unknown-api"
assert exc_info.value.as_conversation_result().as_dict() == snapshot assert exc_info.value.as_conversation_result().as_dict() == snapshot
@ -226,17 +226,17 @@ async def test_template_error(
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test that template error handling works.""" """Test that template error handling works."""
async with ( with (
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
pytest.raises(ConverseError) as exc_info,
): ):
with pytest.raises(ConverseError) as exc_info: await chat_log.async_update_llm_data(
await chat_log.async_update_llm_data( conversing_domain="test",
conversing_domain="test", user_input=mock_conversation_input,
user_input=mock_conversation_input, user_llm_hass_api=None,
user_llm_hass_api=None, user_llm_prompt="{{ invalid_syntax",
user_llm_prompt="{{ invalid_syntax", )
)
assert str(exc_info.value) == "Error rendering prompt" assert str(exc_info.value) == "Error rendering prompt"
assert exc_info.value.as_conversation_result().as_dict() == snapshot assert exc_info.value.as_conversation_result().as_dict() == snapshot
@ -251,24 +251,22 @@ async def test_template_variables(
mock_user.name = "Test User" mock_user.name = "Test User"
mock_conversation_input.context = Context(user_id=mock_user.id) mock_conversation_input.context = Context(user_id=mock_user.id)
async with ( with (
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
): ):
with patch( await chat_log.async_update_llm_data(
"homeassistant.auth.AuthManager.async_get_user", return_value=mock_user conversing_domain="test",
): user_input=mock_conversation_input,
await chat_log.async_update_llm_data( user_llm_hass_api=None,
conversing_domain="test", user_llm_prompt=(
user_input=mock_conversation_input, "The instance name is {{ ha_name }}. "
user_llm_hass_api=None, "The user name is {{ user_name }}. "
user_llm_prompt=( "The user id is {{ llm_context.context.user_id }}."
"The instance name is {{ ha_name }}. " "The calling platform is {{ llm_context.platform }}."
"The user name is {{ user_name }}. " ),
"The user id is {{ llm_context.context.user_id }}." )
"The calling platform is {{ llm_context.platform }}."
),
)
assert chat_log.user_name == "Test User" assert chat_log.user_name == "Test User"
@ -288,7 +286,7 @@ async def test_extra_systen_prompt(
) )
mock_conversation_input.extra_system_prompt = extra_system_prompt mock_conversation_input.extra_system_prompt = extra_system_prompt
async with ( with (
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
@ -313,7 +311,7 @@ async def test_extra_systen_prompt(
conversation_id = chat_log.conversation_id conversation_id = chat_log.conversation_id
mock_conversation_input.extra_system_prompt = None mock_conversation_input.extra_system_prompt = None
async with ( with (
chat_session.async_get_chat_session(hass, conversation_id) as session, chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
@ -330,7 +328,7 @@ async def test_extra_systen_prompt(
# Verify that we take new system prompts # Verify that we take new system prompts
mock_conversation_input.extra_system_prompt = extra_system_prompt2 mock_conversation_input.extra_system_prompt = extra_system_prompt2
async with ( with (
chat_session.async_get_chat_session(hass, conversation_id) as session, chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
@ -355,7 +353,7 @@ async def test_extra_systen_prompt(
# Verify that follow-up conversations with no system prompt take previous one # Verify that follow-up conversations with no system prompt take previous one
mock_conversation_input.extra_system_prompt = None mock_conversation_input.extra_system_prompt = None
async with ( with (
chat_session.async_get_chat_session(hass, conversation_id) as session, chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
@ -390,7 +388,7 @@ async def test_tool_call(
) as mock_get_tools: ) as mock_get_tools:
mock_get_tools.return_value = [mock_tool] mock_get_tools.return_value = [mock_tool]
async with ( with (
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
@ -430,7 +428,7 @@ async def test_tool_call_exception(
) as mock_get_tools: ) as mock_get_tools:
mock_get_tools.return_value = [mock_tool] mock_get_tools.return_value = [mock_tool]
async with ( with (
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):

View File

@ -16,7 +16,7 @@ from tests.common import async_fire_time_changed
@pytest.fixture @pytest.fixture
def mock_ulid() -> Generator[Mock]: def mock_ulid() -> Generator[Mock]:
"""Mock the ulid library.""" """Mock the ulid library."""
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now: with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid" mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now yield mock_ulid_now
@ -37,27 +37,25 @@ async def test_conversation_id(
mock_ulid: Mock, mock_ulid: Mock,
) -> None: ) -> None:
"""Test conversation ID generation.""" """Test conversation ID generation."""
async with chat_session.async_get_chat_session(hass, start_id) as session: with chat_session.async_get_chat_session(hass, start_id) as session:
assert session.conversation_id == given_id assert session.conversation_id == given_id
async def test_context_var(hass: HomeAssistant) -> None: async def test_context_var(hass: HomeAssistant) -> None:
"""Test context var.""" """Test context var."""
async with chat_session.async_get_chat_session(hass) as session: with chat_session.async_get_chat_session(hass) as session:
async with chat_session.async_get_chat_session( with chat_session.async_get_chat_session(
hass, session.conversation_id hass, session.conversation_id
) as session2: ) as session2:
assert session is session2 assert session is session2
async with chat_session.async_get_chat_session(hass, None) as session2: with chat_session.async_get_chat_session(hass, None) as session2:
assert session.conversation_id != session2.conversation_id assert session.conversation_id != session2.conversation_id
async with chat_session.async_get_chat_session( with chat_session.async_get_chat_session(hass, "something else") as session2:
hass, "something else"
) as session2:
assert session.conversation_id != session2.conversation_id assert session.conversation_id != session2.conversation_id
async with chat_session.async_get_chat_session( with chat_session.async_get_chat_session(
hass, ulid_util.ulid_now() hass, ulid_util.ulid_now()
) as session2: ) as session2:
assert session.conversation_id != session2.conversation_id assert session.conversation_id != session2.conversation_id
@ -67,11 +65,11 @@ async def test_cleanup(
hass: HomeAssistant, hass: HomeAssistant,
) -> None: ) -> None:
"""Test cleanup of the chat session.""" """Test cleanup of the chat session."""
async with chat_session.async_get_chat_session(hass) as session: with chat_session.async_get_chat_session(hass) as session:
conversation_id = session.conversation_id conversation_id = session.conversation_id
# Reuse conversation ID to ensure we can chat with same session # Reuse conversation ID to ensure we can chat with same session
async with chat_session.async_get_chat_session(hass, conversation_id) as session: with chat_session.async_get_chat_session(hass, conversation_id) as session:
assert session.conversation_id == conversation_id assert session.conversation_id == conversation_id
# Set the last updated to be older than the timeout # Set the last updated to be older than the timeout
@ -85,7 +83,7 @@ async def test_cleanup(
) )
# Should not be cleaned up, but it should have scheduled another cleanup # Should not be cleaned up, but it should have scheduled another cleanup
async with chat_session.async_get_chat_session(hass, conversation_id) as session: with chat_session.async_get_chat_session(hass, conversation_id) as session:
assert session.conversation_id == conversation_id assert session.conversation_id == conversation_id
async_fire_time_changed( async_fire_time_changed(
@ -94,5 +92,5 @@ async def test_cleanup(
) )
# It should be cleaned up now and we start a new conversation # It should be cleaned up now and we start a new conversation
async with chat_session.async_get_chat_session(hass, conversation_id) as session: with chat_session.async_get_chat_session(hass, conversation_id) as session:
assert session.conversation_id != conversation_id assert session.conversation_id != conversation_id