mirror of
https://github.com/home-assistant/core.git
synced 2025-07-09 22:37:11 +00:00
Make get_chat_session a callback context manager (#137146)
This commit is contained in:
parent
2ce585463c
commit
dd9bd8ef73
@ -1094,7 +1094,7 @@ class PipelineRun:
|
|||||||
|
|
||||||
# It was already handled, create response and add to chat history
|
# It was already handled, create response and add to chat history
|
||||||
if intent_response is not None:
|
if intent_response is not None:
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(
|
chat_session.async_get_chat_session(
|
||||||
self.hass, user_input.conversation_id
|
self.hass, user_input.conversation_id
|
||||||
) as session,
|
) as session,
|
||||||
|
@ -349,7 +349,7 @@ class DefaultAgent(ConversationEntity):
|
|||||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
response: intent.IntentResponse | None = None
|
response: intent.IntentResponse | None = None
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(
|
chat_session.async_get_chat_session(
|
||||||
self.hass, user_input.conversation_id
|
self.hass, user_input.conversation_id
|
||||||
) as session,
|
) as session,
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import Generator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
@ -27,12 +27,12 @@ DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log"
|
|||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@contextmanager
|
||||||
async def async_get_chat_log(
|
def async_get_chat_log(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
session: chat_session.ChatSession,
|
session: chat_session.ChatSession,
|
||||||
user_input: ConversationInput,
|
user_input: ConversationInput,
|
||||||
) -> AsyncGenerator[ChatLog]:
|
) -> Generator[ChatLog]:
|
||||||
"""Return chat log for a specific chat session."""
|
"""Return chat log for a specific chat session."""
|
||||||
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
||||||
if all_history is None:
|
if all_history is None:
|
||||||
|
@ -209,7 +209,7 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
self, user_input: conversation.ConversationInput
|
self, user_input: conversation.ConversationInput
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(
|
chat_session.async_get_chat_session(
|
||||||
self.hass, user_input.conversation_id
|
self.hass, user_input.conversation_id
|
||||||
) as session,
|
) as session,
|
||||||
|
@ -155,7 +155,7 @@ class OpenAIConversationEntity(
|
|||||||
self, user_input: conversation.ConversationInput
|
self, user_input: conversation.ConversationInput
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(
|
chat_session.async_get_chat_session(
|
||||||
self.hass, user_input.conversation_id
|
self.hass, user_input.conversation_id
|
||||||
) as session,
|
) as session,
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import Generator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
@ -17,8 +17,9 @@ from homeassistant.core import (
|
|||||||
HomeAssistant,
|
HomeAssistant,
|
||||||
callback,
|
callback,
|
||||||
)
|
)
|
||||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
from homeassistant.util import dt as dt_util
|
||||||
from homeassistant.util.hass_dict import HassKey
|
from homeassistant.util.hass_dict import HassKey
|
||||||
|
from homeassistant.util.ulid import ulid_now, ulid_to_bytes
|
||||||
|
|
||||||
from .event import async_call_later
|
from .event import async_call_later
|
||||||
|
|
||||||
@ -107,11 +108,11 @@ class SessionCleanup:
|
|||||||
self.schedule()
|
self.schedule()
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@contextmanager
|
||||||
async def async_get_chat_session(
|
def async_get_chat_session(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
) -> AsyncGenerator[ChatSession]:
|
) -> Generator[ChatSession]:
|
||||||
"""Return a chat session."""
|
"""Return a chat session."""
|
||||||
if session := current_session.get():
|
if session := current_session.get():
|
||||||
# If a session is already active and it's the requested conversation ID,
|
# If a session is already active and it's the requested conversation ID,
|
||||||
@ -132,7 +133,7 @@ async def async_get_chat_session(
|
|||||||
hass.data[DATA_CHAT_SESSION_CLEANUP] = SessionCleanup(hass)
|
hass.data[DATA_CHAT_SESSION_CLEANUP] = SessionCleanup(hass)
|
||||||
|
|
||||||
if conversation_id is None:
|
if conversation_id is None:
|
||||||
conversation_id = ulid_util.ulid_now()
|
conversation_id = ulid_now()
|
||||||
|
|
||||||
elif conversation_id in all_sessions:
|
elif conversation_id in all_sessions:
|
||||||
session = all_sessions[conversation_id]
|
session = all_sessions[conversation_id]
|
||||||
@ -143,8 +144,8 @@ async def async_get_chat_session(
|
|||||||
# a new conversation was started. If the user picks their own, they
|
# a new conversation was started. If the user picks their own, they
|
||||||
# want to track a conversation and we respect it.
|
# want to track a conversation and we respect it.
|
||||||
try:
|
try:
|
||||||
ulid_util.ulid_to_bytes(conversation_id)
|
ulid_to_bytes(conversation_id)
|
||||||
conversation_id = ulid_util.ulid_now()
|
conversation_id = ulid_now()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ async def test_cleanup(
|
|||||||
mock_conversation_input: ConversationInput,
|
mock_conversation_input: ConversationInput,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test cleanup of the chat log."""
|
"""Test cleanup of the chat log."""
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
@ -83,7 +83,7 @@ async def test_add_message(
|
|||||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test filtering of messages."""
|
"""Test filtering of messages."""
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
@ -115,7 +115,7 @@ async def test_message_filtering(
|
|||||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test filtering of messages."""
|
"""Test filtering of messages."""
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
@ -183,7 +183,7 @@ async def test_llm_api(
|
|||||||
mock_conversation_input: ConversationInput,
|
mock_conversation_input: ConversationInput,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test when we reference an LLM API."""
|
"""Test when we reference an LLM API."""
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
@ -204,11 +204,11 @@ async def test_unknown_llm_api(
|
|||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test when we reference an LLM API that does not exists."""
|
"""Test when we reference an LLM API that does not exists."""
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
|
pytest.raises(ConverseError) as exc_info,
|
||||||
):
|
):
|
||||||
with pytest.raises(ConverseError) as exc_info:
|
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
@ -226,11 +226,11 @@ async def test_template_error(
|
|||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that template error handling works."""
|
"""Test that template error handling works."""
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
|
pytest.raises(ConverseError) as exc_info,
|
||||||
):
|
):
|
||||||
with pytest.raises(ConverseError) as exc_info:
|
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
user_input=mock_conversation_input,
|
user_input=mock_conversation_input,
|
||||||
@ -251,12 +251,10 @@ async def test_template_variables(
|
|||||||
mock_user.name = "Test User"
|
mock_user.name = "Test User"
|
||||||
mock_conversation_input.context = Context(user_id=mock_user.id)
|
mock_conversation_input.context = Context(user_id=mock_user.id)
|
||||||
|
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
|
||||||
with patch(
|
|
||||||
"homeassistant.auth.AuthManager.async_get_user", return_value=mock_user
|
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_update_llm_data(
|
||||||
conversing_domain="test",
|
conversing_domain="test",
|
||||||
@ -288,7 +286,7 @@ async def test_extra_systen_prompt(
|
|||||||
)
|
)
|
||||||
mock_conversation_input.extra_system_prompt = extra_system_prompt
|
mock_conversation_input.extra_system_prompt = extra_system_prompt
|
||||||
|
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
@ -313,7 +311,7 @@ async def test_extra_systen_prompt(
|
|||||||
conversation_id = chat_log.conversation_id
|
conversation_id = chat_log.conversation_id
|
||||||
mock_conversation_input.extra_system_prompt = None
|
mock_conversation_input.extra_system_prompt = None
|
||||||
|
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
@ -330,7 +328,7 @@ async def test_extra_systen_prompt(
|
|||||||
# Verify that we take new system prompts
|
# Verify that we take new system prompts
|
||||||
mock_conversation_input.extra_system_prompt = extra_system_prompt2
|
mock_conversation_input.extra_system_prompt = extra_system_prompt2
|
||||||
|
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
@ -355,7 +353,7 @@ async def test_extra_systen_prompt(
|
|||||||
# Verify that follow-up conversations with no system prompt take previous one
|
# Verify that follow-up conversations with no system prompt take previous one
|
||||||
mock_conversation_input.extra_system_prompt = None
|
mock_conversation_input.extra_system_prompt = None
|
||||||
|
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
@ -390,7 +388,7 @@ async def test_tool_call(
|
|||||||
) as mock_get_tools:
|
) as mock_get_tools:
|
||||||
mock_get_tools.return_value = [mock_tool]
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
@ -430,7 +428,7 @@ async def test_tool_call_exception(
|
|||||||
) as mock_get_tools:
|
) as mock_get_tools:
|
||||||
mock_get_tools.return_value = [mock_tool]
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
|
||||||
async with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
|
@ -16,7 +16,7 @@ from tests.common import async_fire_time_changed
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_ulid() -> Generator[Mock]:
|
def mock_ulid() -> Generator[Mock]:
|
||||||
"""Mock the ulid library."""
|
"""Mock the ulid library."""
|
||||||
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
|
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||||
mock_ulid_now.return_value = "mock-ulid"
|
mock_ulid_now.return_value = "mock-ulid"
|
||||||
yield mock_ulid_now
|
yield mock_ulid_now
|
||||||
|
|
||||||
@ -37,27 +37,25 @@ async def test_conversation_id(
|
|||||||
mock_ulid: Mock,
|
mock_ulid: Mock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test conversation ID generation."""
|
"""Test conversation ID generation."""
|
||||||
async with chat_session.async_get_chat_session(hass, start_id) as session:
|
with chat_session.async_get_chat_session(hass, start_id) as session:
|
||||||
assert session.conversation_id == given_id
|
assert session.conversation_id == given_id
|
||||||
|
|
||||||
|
|
||||||
async def test_context_var(hass: HomeAssistant) -> None:
|
async def test_context_var(hass: HomeAssistant) -> None:
|
||||||
"""Test context var."""
|
"""Test context var."""
|
||||||
async with chat_session.async_get_chat_session(hass) as session:
|
with chat_session.async_get_chat_session(hass) as session:
|
||||||
async with chat_session.async_get_chat_session(
|
with chat_session.async_get_chat_session(
|
||||||
hass, session.conversation_id
|
hass, session.conversation_id
|
||||||
) as session2:
|
) as session2:
|
||||||
assert session is session2
|
assert session is session2
|
||||||
|
|
||||||
async with chat_session.async_get_chat_session(hass, None) as session2:
|
with chat_session.async_get_chat_session(hass, None) as session2:
|
||||||
assert session.conversation_id != session2.conversation_id
|
assert session.conversation_id != session2.conversation_id
|
||||||
|
|
||||||
async with chat_session.async_get_chat_session(
|
with chat_session.async_get_chat_session(hass, "something else") as session2:
|
||||||
hass, "something else"
|
|
||||||
) as session2:
|
|
||||||
assert session.conversation_id != session2.conversation_id
|
assert session.conversation_id != session2.conversation_id
|
||||||
|
|
||||||
async with chat_session.async_get_chat_session(
|
with chat_session.async_get_chat_session(
|
||||||
hass, ulid_util.ulid_now()
|
hass, ulid_util.ulid_now()
|
||||||
) as session2:
|
) as session2:
|
||||||
assert session.conversation_id != session2.conversation_id
|
assert session.conversation_id != session2.conversation_id
|
||||||
@ -67,11 +65,11 @@ async def test_cleanup(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test cleanup of the chat session."""
|
"""Test cleanup of the chat session."""
|
||||||
async with chat_session.async_get_chat_session(hass) as session:
|
with chat_session.async_get_chat_session(hass) as session:
|
||||||
conversation_id = session.conversation_id
|
conversation_id = session.conversation_id
|
||||||
|
|
||||||
# Reuse conversation ID to ensure we can chat with same session
|
# Reuse conversation ID to ensure we can chat with same session
|
||||||
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||||
assert session.conversation_id == conversation_id
|
assert session.conversation_id == conversation_id
|
||||||
|
|
||||||
# Set the last updated to be older than the timeout
|
# Set the last updated to be older than the timeout
|
||||||
@ -85,7 +83,7 @@ async def test_cleanup(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Should not be cleaned up, but it should have scheduled another cleanup
|
# Should not be cleaned up, but it should have scheduled another cleanup
|
||||||
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||||
assert session.conversation_id == conversation_id
|
assert session.conversation_id == conversation_id
|
||||||
|
|
||||||
async_fire_time_changed(
|
async_fire_time_changed(
|
||||||
@ -94,5 +92,5 @@ async def test_cleanup(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# It should be cleaned up now and we start a new conversation
|
# It should be cleaned up now and we start a new conversation
|
||||||
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||||
assert session.conversation_id != conversation_id
|
assert session.conversation_id != conversation_id
|
||||||
|
Loading…
x
Reference in New Issue
Block a user