Make get_chat_session a callback context manager (#137146)

This commit is contained in:
Paulus Schoutsen 2025-02-01 23:37:24 -05:00 committed by GitHub
parent 2ce585463c
commit dd9bd8ef73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 69 additions and 72 deletions

View File

@ -1094,7 +1094,7 @@ class PipelineRun:
# It was already handled, create response and add to chat history
if intent_response is not None:
async with (
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,

View File

@ -349,7 +349,7 @@ class DefaultAgent(ConversationEntity):
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence."""
response: intent.IntentResponse | None = None
async with (
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,

View File

@ -2,8 +2,8 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from datetime import datetime
import logging
@ -27,12 +27,12 @@ DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log"
LOGGER = logging.getLogger(__name__)
@asynccontextmanager
async def async_get_chat_log(
@contextmanager
def async_get_chat_log(
hass: HomeAssistant,
session: chat_session.ChatSession,
user_input: ConversationInput,
) -> AsyncGenerator[ChatLog]:
) -> Generator[ChatLog]:
"""Return chat log for a specific chat session."""
all_history = hass.data.get(DATA_CHAT_HISTORY)
if all_history is None:

View File

@ -209,7 +209,7 @@ class GoogleGenerativeAIConversationEntity(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
async with (
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,

View File

@ -155,7 +155,7 @@ class OpenAIConversationEntity(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
async with (
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,

View File

@ -2,8 +2,8 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import datetime, timedelta
@ -17,8 +17,9 @@ from homeassistant.core import (
HomeAssistant,
callback,
)
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util import dt as dt_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.ulid import ulid_now, ulid_to_bytes
from .event import async_call_later
@ -107,11 +108,11 @@ class SessionCleanup:
self.schedule()
@asynccontextmanager
async def async_get_chat_session(
@contextmanager
def async_get_chat_session(
hass: HomeAssistant,
conversation_id: str | None = None,
) -> AsyncGenerator[ChatSession]:
) -> Generator[ChatSession]:
"""Return a chat session."""
if session := current_session.get():
# If a session is already active and it's the requested conversation ID,
@ -132,7 +133,7 @@ async def async_get_chat_session(
hass.data[DATA_CHAT_SESSION_CLEANUP] = SessionCleanup(hass)
if conversation_id is None:
conversation_id = ulid_util.ulid_now()
conversation_id = ulid_now()
elif conversation_id in all_sessions:
session = all_sessions[conversation_id]
@ -143,8 +144,8 @@ async def async_get_chat_session(
# a new conversation was started. If the user picks their own, they
# want to track a conversation and we respect it.
try:
ulid_util.ulid_to_bytes(conversation_id)
conversation_id = ulid_util.ulid_now()
ulid_to_bytes(conversation_id)
conversation_id = ulid_now()
except ValueError:
pass

View File

@ -50,7 +50,7 @@ async def test_cleanup(
mock_conversation_input: ConversationInput,
) -> None:
"""Test cleanup of the chat log."""
async with (
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
@ -83,7 +83,7 @@ async def test_add_message(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
async with (
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
@ -115,7 +115,7 @@ async def test_message_filtering(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test filtering of messages."""
async with (
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
@ -183,7 +183,7 @@ async def test_llm_api(
mock_conversation_input: ConversationInput,
) -> None:
"""Test when we reference an LLM API."""
async with (
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
@ -204,17 +204,17 @@ async def test_unknown_llm_api(
snapshot: SnapshotAssertion,
) -> None:
"""Test when we reference an LLM API that does not exists."""
async with (
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
pytest.raises(ConverseError) as exc_info,
):
with pytest.raises(ConverseError) as exc_info:
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api="unknown-api",
user_llm_prompt=None,
)
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api="unknown-api",
user_llm_prompt=None,
)
assert str(exc_info.value) == "Error getting LLM API unknown-api"
assert exc_info.value.as_conversation_result().as_dict() == snapshot
@ -226,17 +226,17 @@ async def test_template_error(
snapshot: SnapshotAssertion,
) -> None:
"""Test that template error handling works."""
async with (
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
pytest.raises(ConverseError) as exc_info,
):
with pytest.raises(ConverseError) as exc_info:
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api=None,
user_llm_prompt="{{ invalid_syntax",
)
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api=None,
user_llm_prompt="{{ invalid_syntax",
)
assert str(exc_info.value) == "Error rendering prompt"
assert exc_info.value.as_conversation_result().as_dict() == snapshot
@ -251,24 +251,22 @@ async def test_template_variables(
mock_user.name = "Test User"
mock_conversation_input.context = Context(user_id=mock_user.id)
async with (
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
):
with patch(
"homeassistant.auth.AuthManager.async_get_user", return_value=mock_user
):
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api=None,
user_llm_prompt=(
"The instance name is {{ ha_name }}. "
"The user name is {{ user_name }}. "
"The user id is {{ llm_context.context.user_id }}."
"The calling platform is {{ llm_context.platform }}."
),
)
await chat_log.async_update_llm_data(
conversing_domain="test",
user_input=mock_conversation_input,
user_llm_hass_api=None,
user_llm_prompt=(
"The instance name is {{ ha_name }}. "
"The user name is {{ user_name }}. "
"The user id is {{ llm_context.context.user_id }}."
"The calling platform is {{ llm_context.platform }}."
),
)
assert chat_log.user_name == "Test User"
@ -288,7 +286,7 @@ async def test_extra_systen_prompt(
)
mock_conversation_input.extra_system_prompt = extra_system_prompt
async with (
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
@ -313,7 +311,7 @@ async def test_extra_systen_prompt(
conversation_id = chat_log.conversation_id
mock_conversation_input.extra_system_prompt = None
async with (
with (
chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
@ -330,7 +328,7 @@ async def test_extra_systen_prompt(
# Verify that we take new system prompts
mock_conversation_input.extra_system_prompt = extra_system_prompt2
async with (
with (
chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
@ -355,7 +353,7 @@ async def test_extra_systen_prompt(
# Verify that follow-up conversations with no system prompt take previous one
mock_conversation_input.extra_system_prompt = None
async with (
with (
chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
@ -390,7 +388,7 @@ async def test_tool_call(
) as mock_get_tools:
mock_get_tools.return_value = [mock_tool]
async with (
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
@ -430,7 +428,7 @@ async def test_tool_call_exception(
) as mock_get_tools:
mock_get_tools.return_value = [mock_tool]
async with (
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):

View File

@ -16,7 +16,7 @@ from tests.common import async_fire_time_changed
@pytest.fixture
def mock_ulid() -> Generator[Mock]:
"""Mock the ulid library."""
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
mock_ulid_now.return_value = "mock-ulid"
yield mock_ulid_now
@ -37,27 +37,25 @@ async def test_conversation_id(
mock_ulid: Mock,
) -> None:
"""Test conversation ID generation."""
async with chat_session.async_get_chat_session(hass, start_id) as session:
with chat_session.async_get_chat_session(hass, start_id) as session:
assert session.conversation_id == given_id
async def test_context_var(hass: HomeAssistant) -> None:
"""Test context var."""
async with chat_session.async_get_chat_session(hass) as session:
async with chat_session.async_get_chat_session(
with chat_session.async_get_chat_session(hass) as session:
with chat_session.async_get_chat_session(
hass, session.conversation_id
) as session2:
assert session is session2
async with chat_session.async_get_chat_session(hass, None) as session2:
with chat_session.async_get_chat_session(hass, None) as session2:
assert session.conversation_id != session2.conversation_id
async with chat_session.async_get_chat_session(
hass, "something else"
) as session2:
with chat_session.async_get_chat_session(hass, "something else") as session2:
assert session.conversation_id != session2.conversation_id
async with chat_session.async_get_chat_session(
with chat_session.async_get_chat_session(
hass, ulid_util.ulid_now()
) as session2:
assert session.conversation_id != session2.conversation_id
@ -67,11 +65,11 @@ async def test_cleanup(
hass: HomeAssistant,
) -> None:
"""Test cleanup of the chat session."""
async with chat_session.async_get_chat_session(hass) as session:
with chat_session.async_get_chat_session(hass) as session:
conversation_id = session.conversation_id
# Reuse conversation ID to ensure we can chat with same session
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
with chat_session.async_get_chat_session(hass, conversation_id) as session:
assert session.conversation_id == conversation_id
# Set the last updated to be older than the timeout
@ -85,7 +83,7 @@ async def test_cleanup(
)
# Should not be cleaned up, but it should have scheduled another cleanup
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
with chat_session.async_get_chat_session(hass, conversation_id) as session:
assert session.conversation_id == conversation_id
async_fire_time_changed(
@ -94,5 +92,5 @@ async def test_cleanup(
)
# It should be cleaned up now and we start a new conversation
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
with chat_session.async_get_chat_session(hass, conversation_id) as session:
assert session.conversation_id != conversation_id