mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 09:17:53 +00:00
Make get_chat_session a callback context manager (#137146)
This commit is contained in:
parent
2ce585463c
commit
dd9bd8ef73
@ -1094,7 +1094,7 @@ class PipelineRun:
|
||||
|
||||
# It was already handled, create response and add to chat history
|
||||
if intent_response is not None:
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, user_input.conversation_id
|
||||
) as session,
|
||||
|
@ -349,7 +349,7 @@ class DefaultAgent(ConversationEntity):
|
||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||
"""Process a sentence."""
|
||||
response: intent.IntentResponse | None = None
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, user_input.conversation_id
|
||||
) as session,
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field, replace
|
||||
from datetime import datetime
|
||||
import logging
|
||||
@ -27,12 +27,12 @@ DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log"
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_get_chat_log(
|
||||
@contextmanager
|
||||
def async_get_chat_log(
|
||||
hass: HomeAssistant,
|
||||
session: chat_session.ChatSession,
|
||||
user_input: ConversationInput,
|
||||
) -> AsyncGenerator[ChatLog]:
|
||||
) -> Generator[ChatLog]:
|
||||
"""Return chat log for a specific chat session."""
|
||||
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
||||
if all_history is None:
|
||||
|
@ -209,7 +209,7 @@ class GoogleGenerativeAIConversationEntity(
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, user_input.conversation_id
|
||||
) as session,
|
||||
|
@ -155,7 +155,7 @@ class OpenAIConversationEntity(
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, user_input.conversation_id
|
||||
) as session,
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
@ -17,8 +17,9 @@ from homeassistant.core import (
|
||||
HomeAssistant,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.ulid import ulid_now, ulid_to_bytes
|
||||
|
||||
from .event import async_call_later
|
||||
|
||||
@ -107,11 +108,11 @@ class SessionCleanup:
|
||||
self.schedule()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_get_chat_session(
|
||||
@contextmanager
|
||||
def async_get_chat_session(
|
||||
hass: HomeAssistant,
|
||||
conversation_id: str | None = None,
|
||||
) -> AsyncGenerator[ChatSession]:
|
||||
) -> Generator[ChatSession]:
|
||||
"""Return a chat session."""
|
||||
if session := current_session.get():
|
||||
# If a session is already active and it's the requested conversation ID,
|
||||
@ -132,7 +133,7 @@ async def async_get_chat_session(
|
||||
hass.data[DATA_CHAT_SESSION_CLEANUP] = SessionCleanup(hass)
|
||||
|
||||
if conversation_id is None:
|
||||
conversation_id = ulid_util.ulid_now()
|
||||
conversation_id = ulid_now()
|
||||
|
||||
elif conversation_id in all_sessions:
|
||||
session = all_sessions[conversation_id]
|
||||
@ -143,8 +144,8 @@ async def async_get_chat_session(
|
||||
# a new conversation was started. If the user picks their own, they
|
||||
# want to track a conversation and we respect it.
|
||||
try:
|
||||
ulid_util.ulid_to_bytes(conversation_id)
|
||||
conversation_id = ulid_util.ulid_now()
|
||||
ulid_to_bytes(conversation_id)
|
||||
conversation_id = ulid_now()
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
@ -50,7 +50,7 @@ async def test_cleanup(
|
||||
mock_conversation_input: ConversationInput,
|
||||
) -> None:
|
||||
"""Test cleanup of the chat log."""
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
@ -83,7 +83,7 @@ async def test_add_message(
|
||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||
) -> None:
|
||||
"""Test filtering of messages."""
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
@ -115,7 +115,7 @@ async def test_message_filtering(
|
||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||
) -> None:
|
||||
"""Test filtering of messages."""
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
@ -183,7 +183,7 @@ async def test_llm_api(
|
||||
mock_conversation_input: ConversationInput,
|
||||
) -> None:
|
||||
"""Test when we reference an LLM API."""
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
@ -204,17 +204,17 @@ async def test_unknown_llm_api(
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test when we reference an LLM API that does not exists."""
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
pytest.raises(ConverseError) as exc_info,
|
||||
):
|
||||
with pytest.raises(ConverseError) as exc_info:
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api="unknown-api",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api="unknown-api",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
|
||||
assert str(exc_info.value) == "Error getting LLM API unknown-api"
|
||||
assert exc_info.value.as_conversation_result().as_dict() == snapshot
|
||||
@ -226,17 +226,17 @@ async def test_template_error(
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that template error handling works."""
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
pytest.raises(ConverseError) as exc_info,
|
||||
):
|
||||
with pytest.raises(ConverseError) as exc_info:
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt="{{ invalid_syntax",
|
||||
)
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt="{{ invalid_syntax",
|
||||
)
|
||||
|
||||
assert str(exc_info.value) == "Error rendering prompt"
|
||||
assert exc_info.value.as_conversation_result().as_dict() == snapshot
|
||||
@ -251,24 +251,22 @@ async def test_template_variables(
|
||||
mock_user.name = "Test User"
|
||||
mock_conversation_input.context = Context(user_id=mock_user.id)
|
||||
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
|
||||
):
|
||||
with patch(
|
||||
"homeassistant.auth.AuthManager.async_get_user", return_value=mock_user
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=(
|
||||
"The instance name is {{ ha_name }}. "
|
||||
"The user name is {{ user_name }}. "
|
||||
"The user id is {{ llm_context.context.user_id }}."
|
||||
"The calling platform is {{ llm_context.platform }}."
|
||||
),
|
||||
)
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=(
|
||||
"The instance name is {{ ha_name }}. "
|
||||
"The user name is {{ user_name }}. "
|
||||
"The user id is {{ llm_context.context.user_id }}."
|
||||
"The calling platform is {{ llm_context.platform }}."
|
||||
),
|
||||
)
|
||||
|
||||
assert chat_log.user_name == "Test User"
|
||||
|
||||
@ -288,7 +286,7 @@ async def test_extra_systen_prompt(
|
||||
)
|
||||
mock_conversation_input.extra_system_prompt = extra_system_prompt
|
||||
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
@ -313,7 +311,7 @@ async def test_extra_systen_prompt(
|
||||
conversation_id = chat_log.conversation_id
|
||||
mock_conversation_input.extra_system_prompt = None
|
||||
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
@ -330,7 +328,7 @@ async def test_extra_systen_prompt(
|
||||
# Verify that we take new system prompts
|
||||
mock_conversation_input.extra_system_prompt = extra_system_prompt2
|
||||
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
@ -355,7 +353,7 @@ async def test_extra_systen_prompt(
|
||||
# Verify that follow-up conversations with no system prompt take previous one
|
||||
mock_conversation_input.extra_system_prompt = None
|
||||
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
@ -390,7 +388,7 @@ async def test_tool_call(
|
||||
) as mock_get_tools:
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
@ -430,7 +428,7 @@ async def test_tool_call_exception(
|
||||
) as mock_get_tools:
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
async with (
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
|
@ -16,7 +16,7 @@ from tests.common import async_fire_time_changed
|
||||
@pytest.fixture
|
||||
def mock_ulid() -> Generator[Mock]:
|
||||
"""Mock the ulid library."""
|
||||
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
|
||||
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||
mock_ulid_now.return_value = "mock-ulid"
|
||||
yield mock_ulid_now
|
||||
|
||||
@ -37,27 +37,25 @@ async def test_conversation_id(
|
||||
mock_ulid: Mock,
|
||||
) -> None:
|
||||
"""Test conversation ID generation."""
|
||||
async with chat_session.async_get_chat_session(hass, start_id) as session:
|
||||
with chat_session.async_get_chat_session(hass, start_id) as session:
|
||||
assert session.conversation_id == given_id
|
||||
|
||||
|
||||
async def test_context_var(hass: HomeAssistant) -> None:
|
||||
"""Test context var."""
|
||||
async with chat_session.async_get_chat_session(hass) as session:
|
||||
async with chat_session.async_get_chat_session(
|
||||
with chat_session.async_get_chat_session(hass) as session:
|
||||
with chat_session.async_get_chat_session(
|
||||
hass, session.conversation_id
|
||||
) as session2:
|
||||
assert session is session2
|
||||
|
||||
async with chat_session.async_get_chat_session(hass, None) as session2:
|
||||
with chat_session.async_get_chat_session(hass, None) as session2:
|
||||
assert session.conversation_id != session2.conversation_id
|
||||
|
||||
async with chat_session.async_get_chat_session(
|
||||
hass, "something else"
|
||||
) as session2:
|
||||
with chat_session.async_get_chat_session(hass, "something else") as session2:
|
||||
assert session.conversation_id != session2.conversation_id
|
||||
|
||||
async with chat_session.async_get_chat_session(
|
||||
with chat_session.async_get_chat_session(
|
||||
hass, ulid_util.ulid_now()
|
||||
) as session2:
|
||||
assert session.conversation_id != session2.conversation_id
|
||||
@ -67,11 +65,11 @@ async def test_cleanup(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test cleanup of the chat session."""
|
||||
async with chat_session.async_get_chat_session(hass) as session:
|
||||
with chat_session.async_get_chat_session(hass) as session:
|
||||
conversation_id = session.conversation_id
|
||||
|
||||
# Reuse conversation ID to ensure we can chat with same session
|
||||
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||
with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||
assert session.conversation_id == conversation_id
|
||||
|
||||
# Set the last updated to be older than the timeout
|
||||
@ -85,7 +83,7 @@ async def test_cleanup(
|
||||
)
|
||||
|
||||
# Should not be cleaned up, but it should have scheduled another cleanup
|
||||
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||
with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||
assert session.conversation_id == conversation_id
|
||||
|
||||
async_fire_time_changed(
|
||||
@ -94,5 +92,5 @@ async def test_cleanup(
|
||||
)
|
||||
|
||||
# It should be cleaned up now and we start a new conversation
|
||||
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||
with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||
assert session.conversation_id != conversation_id
|
||||
|
Loading…
x
Reference in New Issue
Block a user