mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Extract conversation ID generation to helper (#137062)
* Extract conversation ID generation to helper * Allow nested get_chat_log calls
This commit is contained in:
parent
30314ca32b
commit
2f6640707b
@ -33,7 +33,7 @@ from homeassistant.components.tts import (
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.helpers import chat_session, intent
|
||||
from homeassistant.helpers.collection import (
|
||||
CHANGE_UPDATED,
|
||||
CollectionError,
|
||||
@ -1094,13 +1094,18 @@ class PipelineRun:
|
||||
|
||||
# It was already handled, create response and add to chat history
|
||||
if intent_response is not None:
|
||||
async with conversation.async_get_chat_session(
|
||||
self.hass, user_input
|
||||
) as chat_session:
|
||||
async with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, user_input.conversation_id
|
||||
) as session,
|
||||
conversation.async_get_chat_log(
|
||||
self.hass, session, user_input
|
||||
) as chat_log,
|
||||
):
|
||||
speech: str = intent_response.speech.get("plain", {}).get(
|
||||
"speech", ""
|
||||
)
|
||||
chat_session.async_add_message(
|
||||
chat_log.async_add_message(
|
||||
conversation.Content(
|
||||
role="assistant",
|
||||
agent_id=agent_id,
|
||||
@ -1109,7 +1114,7 @@ class PipelineRun:
|
||||
)
|
||||
conversation_result = conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=chat_session.conversation_id,
|
||||
conversation_id=session.conversation_id,
|
||||
)
|
||||
|
||||
else:
|
||||
|
@ -48,20 +48,14 @@ from .default_agent import DefaultAgent, async_setup_default_agent
|
||||
from .entity import ConversationEntity
|
||||
from .http import async_setup as async_setup_conversation_http
|
||||
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
|
||||
from .session import (
|
||||
ChatSession,
|
||||
Content,
|
||||
ConverseError,
|
||||
NativeContent,
|
||||
async_get_chat_session,
|
||||
)
|
||||
from .session import ChatLog, Content, ConverseError, NativeContent, async_get_chat_log
|
||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"HOME_ASSISTANT_AGENT",
|
||||
"OLD_HOME_ASSISTANT_AGENT",
|
||||
"ChatSession",
|
||||
"ChatLog",
|
||||
"Content",
|
||||
"ConversationEntity",
|
||||
"ConversationEntityFeature",
|
||||
@ -73,7 +67,7 @@ __all__ = [
|
||||
"async_conversation_trace_append",
|
||||
"async_converse",
|
||||
"async_get_agent_info",
|
||||
"async_get_chat_session",
|
||||
"async_get_chat_log",
|
||||
"async_set_agent",
|
||||
"async_setup",
|
||||
"async_unset_agent",
|
||||
|
@ -42,6 +42,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
|
||||
from homeassistant.const import EVENT_STATE_CHANGED, MATCH_ALL
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
chat_session,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
floor_registry as fr,
|
||||
@ -62,7 +63,7 @@ from .const import (
|
||||
)
|
||||
from .entity import ConversationEntity
|
||||
from .models import ConversationInput, ConversationResult
|
||||
from .session import Content, async_get_chat_session
|
||||
from .session import Content, async_get_chat_log
|
||||
from .trace import ConversationTraceEventType, async_conversation_trace_append
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -348,7 +349,12 @@ class DefaultAgent(ConversationEntity):
|
||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||
"""Process a sentence."""
|
||||
response: intent.IntentResponse | None = None
|
||||
async with async_get_chat_session(self.hass, user_input) as chat_session:
|
||||
async with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, user_input.conversation_id
|
||||
) as session,
|
||||
async_get_chat_log(self.hass, session, user_input) as chat_log,
|
||||
):
|
||||
# Check if a trigger matched
|
||||
if trigger_result := await self.async_recognize_sentence_trigger(
|
||||
user_input
|
||||
@ -373,7 +379,7 @@ class DefaultAgent(ConversationEntity):
|
||||
)
|
||||
|
||||
speech: str = response.speech.get("plain", {}).get("speech", "")
|
||||
chat_session.async_add_message(
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
agent_id=user_input.agent_id,
|
||||
@ -382,7 +388,7 @@ class DefaultAgent(ConversationEntity):
|
||||
)
|
||||
|
||||
return ConversationResult(
|
||||
response=response, conversation_id=chat_session.conversation_id
|
||||
response=response, conversation_id=session.conversation_id
|
||||
)
|
||||
|
||||
async def _async_process_intent_result(
|
||||
|
@ -5,25 +5,16 @@ from __future__ import annotations
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field, replace
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
Event,
|
||||
HassJob,
|
||||
HassJobType,
|
||||
HomeAssistant,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.helpers import intent, llm, template
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||
from homeassistant.helpers import chat_session, intent, llm, template
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
@ -31,100 +22,36 @@ from . import trace
|
||||
from .const import DOMAIN
|
||||
from .models import ConversationInput, ConversationResult
|
||||
|
||||
DATA_CHAT_HISTORY: HassKey[dict[str, ChatSession]] = HassKey(
|
||||
"conversation_chat_session"
|
||||
)
|
||||
DATA_CHAT_HISTORY_CLEANUP: HassKey[SessionCleanup] = HassKey(
|
||||
"conversation_chat_session_cleanup"
|
||||
)
|
||||
DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log")
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
CONVERSATION_TIMEOUT = timedelta(minutes=5)
|
||||
|
||||
|
||||
class SessionCleanup:
|
||||
"""Helper to clean up the history."""
|
||||
|
||||
unsub: CALLBACK_TYPE | None = None
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the history cleanup."""
|
||||
self.hass = hass
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
|
||||
self.cleanup_job = HassJob(
|
||||
self._cleanup, "conversation_history_cleanup", job_type=HassJobType.Callback
|
||||
)
|
||||
|
||||
@callback
|
||||
def schedule(self) -> None:
|
||||
"""Schedule the cleanup."""
|
||||
if self.unsub:
|
||||
return
|
||||
self.unsub = async_call_later(
|
||||
self.hass,
|
||||
CONVERSATION_TIMEOUT.total_seconds() + 1,
|
||||
self.cleanup_job,
|
||||
)
|
||||
|
||||
@callback
|
||||
def _on_hass_stop(self, event: Event) -> None:
|
||||
"""Cancel the cleanup on shutdown."""
|
||||
if self.unsub:
|
||||
self.unsub()
|
||||
self.unsub = None
|
||||
|
||||
@callback
|
||||
def _cleanup(self, now: datetime) -> None:
|
||||
"""Clean up the history and schedule follow-up if necessary."""
|
||||
self.unsub = None
|
||||
all_history = self.hass.data[DATA_CHAT_HISTORY]
|
||||
|
||||
# We mutate original object because current commands could be
|
||||
# yielding history based on it.
|
||||
for conversation_id, history in list(all_history.items()):
|
||||
if history.last_updated + CONVERSATION_TIMEOUT < now:
|
||||
del all_history[conversation_id]
|
||||
|
||||
# Still conversations left, check again in timeout time.
|
||||
if all_history:
|
||||
self.schedule()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_get_chat_session(
|
||||
async def async_get_chat_log(
|
||||
hass: HomeAssistant,
|
||||
session: chat_session.ChatSession,
|
||||
user_input: ConversationInput,
|
||||
) -> AsyncGenerator[ChatSession]:
|
||||
"""Return chat session."""
|
||||
) -> AsyncGenerator[ChatLog]:
|
||||
"""Return chat log for a specific chat session."""
|
||||
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
||||
if all_history is None:
|
||||
all_history = {}
|
||||
hass.data[DATA_CHAT_HISTORY] = all_history
|
||||
hass.data[DATA_CHAT_HISTORY_CLEANUP] = SessionCleanup(hass)
|
||||
|
||||
history: ChatSession | None = None
|
||||
|
||||
if user_input.conversation_id is None:
|
||||
conversation_id = ulid_util.ulid_now()
|
||||
|
||||
elif history := all_history.get(user_input.conversation_id):
|
||||
conversation_id = user_input.conversation_id
|
||||
|
||||
else:
|
||||
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
||||
# If an old OLID is passed in, we will generate a new one to indicate
|
||||
# a new conversation was started. If the user picks their own, they
|
||||
# want to track a conversation and we respect it.
|
||||
try:
|
||||
ulid_util.ulid_to_bytes(user_input.conversation_id)
|
||||
conversation_id = ulid_util.ulid_now()
|
||||
except ValueError:
|
||||
conversation_id = user_input.conversation_id
|
||||
history = all_history.get(session.conversation_id)
|
||||
|
||||
if history:
|
||||
history = replace(history, messages=history.messages.copy())
|
||||
else:
|
||||
history = ChatSession(hass, conversation_id, user_input.agent_id)
|
||||
history = ChatLog(hass, session.conversation_id, user_input.agent_id)
|
||||
|
||||
@callback
|
||||
def do_cleanup() -> None:
|
||||
"""Handle cleanup."""
|
||||
all_history.pop(session.conversation_id)
|
||||
|
||||
session.async_on_cleanup(do_cleanup)
|
||||
|
||||
message: Content = Content(
|
||||
role="user",
|
||||
@ -142,8 +69,7 @@ async def async_get_chat_session(
|
||||
return
|
||||
|
||||
history.last_updated = dt_util.utcnow()
|
||||
all_history[conversation_id] = history
|
||||
hass.data[DATA_CHAT_HISTORY_CLEANUP].schedule()
|
||||
all_history[session.conversation_id] = history
|
||||
|
||||
|
||||
class ConverseError(HomeAssistantError):
|
||||
@ -187,8 +113,8 @@ class NativeContent[_NativeT]:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSession[_NativeT]:
|
||||
"""Class holding all information for a specific conversation."""
|
||||
class ChatLog[_NativeT]:
|
||||
"""Class holding the chat history of a specific conversation."""
|
||||
|
||||
hass: HomeAssistant
|
||||
conversation_id: str
|
||||
|
@ -18,7 +18,7 @@ from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr, intent, llm
|
||||
from homeassistant.helpers import chat_session, device_registry as dr, intent, llm
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import (
|
||||
@ -209,15 +209,18 @@ class GoogleGenerativeAIConversationEntity(
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
async with conversation.async_get_chat_session(
|
||||
self.hass, user_input
|
||||
) as session:
|
||||
return await self._async_handle_message(user_input, session)
|
||||
async with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, user_input.conversation_id
|
||||
) as session,
|
||||
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
|
||||
):
|
||||
return await self._async_handle_message(user_input, chat_log)
|
||||
|
||||
async def _async_handle_message(
|
||||
self,
|
||||
user_input: conversation.ConversationInput,
|
||||
session: conversation.ChatSession[genai_types.ContentDict],
|
||||
session: conversation.ChatLog[genai_types.ContentDict],
|
||||
) -> conversation.ConversationResult:
|
||||
"""Call the API."""
|
||||
|
||||
|
@ -23,7 +23,7 @@ from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr, intent, llm
|
||||
from homeassistant.helpers import chat_session, device_registry as dr, intent, llm
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from . import OpenAIConfigEntry
|
||||
@ -155,15 +155,18 @@ class OpenAIConversationEntity(
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
async with conversation.async_get_chat_session(
|
||||
self.hass, user_input
|
||||
) as session:
|
||||
return await self._async_handle_message(user_input, session)
|
||||
async with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, user_input.conversation_id
|
||||
) as session,
|
||||
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
|
||||
):
|
||||
return await self._async_handle_message(user_input, chat_log)
|
||||
|
||||
async def _async_handle_message(
|
||||
self,
|
||||
user_input: conversation.ConversationInput,
|
||||
session: conversation.ChatSession[ChatCompletionMessageParam],
|
||||
session: conversation.ChatLog[ChatCompletionMessageParam],
|
||||
) -> conversation.ConversationResult:
|
||||
"""Call the API."""
|
||||
assert user_input.agent_id
|
||||
|
160
homeassistant/helpers/chat_session.py
Normal file
160
homeassistant/helpers/chat_session.py
Normal file
@ -0,0 +1,160 @@
|
||||
"""Helper to organize chat sessions between integrations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
Event,
|
||||
HassJob,
|
||||
HassJobType,
|
||||
HomeAssistant,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from .event import async_call_later
|
||||
|
||||
DATA_CHAT_SESSION: HassKey[dict[str, ChatSession]] = HassKey("chat_session")
|
||||
DATA_CHAT_SESSION_CLEANUP: HassKey[SessionCleanup] = HassKey("chat_session_cleanup")
|
||||
|
||||
CONVERSATION_TIMEOUT = timedelta(minutes=5)
|
||||
|
||||
current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSession:
|
||||
"""Represent a chat session."""
|
||||
|
||||
conversation_id: str
|
||||
last_updated: datetime = field(default_factory=dt_util.utcnow)
|
||||
_cleanup_callbacks: list[CALLBACK_TYPE] = field(default_factory=list)
|
||||
|
||||
@callback
|
||||
def async_updated(self) -> None:
|
||||
"""Update the last updated time."""
|
||||
self.last_updated = dt_util.utcnow()
|
||||
|
||||
@callback
|
||||
def async_on_cleanup(self, cb: CALLBACK_TYPE) -> None:
|
||||
"""Register a callback to clean up the session."""
|
||||
self._cleanup_callbacks.append(cb)
|
||||
|
||||
@callback
|
||||
def async_cleanup(self) -> None:
|
||||
"""Call all clean up callbacks."""
|
||||
for cb in self._cleanup_callbacks:
|
||||
cb()
|
||||
self._cleanup_callbacks.clear()
|
||||
|
||||
|
||||
class SessionCleanup:
|
||||
"""Helper to clean up the stale sessions."""
|
||||
|
||||
unsub: CALLBACK_TYPE | None = None
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the session cleanup."""
|
||||
self.hass = hass
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
|
||||
self.cleanup_job = HassJob(
|
||||
self._cleanup, "chat_session_cleanup", job_type=HassJobType.Callback
|
||||
)
|
||||
|
||||
@callback
|
||||
def schedule(self) -> None:
|
||||
"""Schedule the cleanup."""
|
||||
if self.unsub:
|
||||
return
|
||||
self.unsub = async_call_later(
|
||||
self.hass,
|
||||
CONVERSATION_TIMEOUT.total_seconds() + 1,
|
||||
self.cleanup_job,
|
||||
)
|
||||
|
||||
@callback
|
||||
def _on_hass_stop(self, event: Event) -> None:
|
||||
"""Cancel the cleanup on shutdown."""
|
||||
if self.unsub:
|
||||
self.unsub()
|
||||
self.unsub = None
|
||||
|
||||
@callback
|
||||
def _cleanup(self, now: datetime) -> None:
|
||||
"""Clean up the history and schedule follow-up if necessary."""
|
||||
self.unsub = None
|
||||
all_sessions = self.hass.data[DATA_CHAT_SESSION]
|
||||
|
||||
# We mutate original object because current commands could be
|
||||
# yielding session based on it.
|
||||
for conversation_id, session in list(all_sessions.items()):
|
||||
if session.last_updated + CONVERSATION_TIMEOUT < now:
|
||||
del all_sessions[conversation_id]
|
||||
session.async_cleanup()
|
||||
|
||||
# Still conversations left, check again in timeout time.
|
||||
if all_sessions:
|
||||
self.schedule()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_get_chat_session(
|
||||
hass: HomeAssistant,
|
||||
conversation_id: str | None = None,
|
||||
) -> AsyncGenerator[ChatSession]:
|
||||
"""Return a chat session."""
|
||||
if session := current_session.get():
|
||||
# If a session is already active and it's the requested conversation ID,
|
||||
# return that. We won't update the last updated time in this case.
|
||||
if session.conversation_id == conversation_id:
|
||||
yield session
|
||||
return
|
||||
|
||||
# If it's not the same conversation ID, we will create a new session
|
||||
# because it might be a conversation agent calling a tool that is talking
|
||||
# to another LLM.
|
||||
session = None
|
||||
|
||||
all_sessions = hass.data.get(DATA_CHAT_SESSION)
|
||||
if all_sessions is None:
|
||||
all_sessions = {}
|
||||
hass.data[DATA_CHAT_SESSION] = all_sessions
|
||||
hass.data[DATA_CHAT_SESSION_CLEANUP] = SessionCleanup(hass)
|
||||
|
||||
if conversation_id is None:
|
||||
conversation_id = ulid_util.ulid_now()
|
||||
|
||||
elif conversation_id in all_sessions:
|
||||
session = all_sessions[conversation_id]
|
||||
|
||||
else:
|
||||
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
||||
# If an old ULID is passed in, we will generate a new one to indicate
|
||||
# a new conversation was started. If the user picks their own, they
|
||||
# want to track a conversation and we respect it.
|
||||
try:
|
||||
ulid_util.ulid_to_bytes(conversation_id)
|
||||
conversation_id = ulid_util.ulid_now()
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if session is None:
|
||||
session = ChatSession(conversation_id)
|
||||
|
||||
current_session.set(session)
|
||||
yield session
|
||||
current_session.set(None)
|
||||
|
||||
session.last_updated = dt_util.utcnow()
|
||||
all_sessions[conversation_id] = session
|
||||
hass.data[DATA_CHAT_SESSION_CLEANUP].schedule()
|
@ -8,10 +8,17 @@ import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.conversation import ConversationInput, session
|
||||
from homeassistant.components.conversation import (
|
||||
Content,
|
||||
ConversationInput,
|
||||
ConverseError,
|
||||
NativeContent,
|
||||
async_get_chat_log,
|
||||
)
|
||||
from homeassistant.components.conversation.session import DATA_CHAT_HISTORY
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers import chat_session, llm
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
@ -38,127 +45,69 @@ def mock_ulid() -> Generator[Mock]:
|
||||
yield mock_ulid_now
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("start_id", "given_id"),
|
||||
[
|
||||
(None, "mock-ulid"),
|
||||
# This ULID is not known as a session
|
||||
("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"),
|
||||
("not-a-ulid", "not-a-ulid"),
|
||||
],
|
||||
)
|
||||
async def test_conversation_id(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
mock_ulid: Mock,
|
||||
start_id: str | None,
|
||||
given_id: str,
|
||||
) -> None:
|
||||
"""Test conversation ID generation."""
|
||||
mock_conversation_input.conversation_id = start_id
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
assert chat_session.conversation_id == given_id
|
||||
|
||||
|
||||
async def test_cleanup(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
) -> None:
|
||||
"""Mock cleanup of the conversation session."""
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
assert len(chat_session.messages) == 2
|
||||
conversation_id = chat_session.conversation_id
|
||||
|
||||
# Generate session entry.
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
# Because we didn't add a message to the session in the last block,
|
||||
# the conversation was not be persisted and we get a new ID
|
||||
assert chat_session.conversation_id != conversation_id
|
||||
conversation_id = chat_session.conversation_id
|
||||
chat_session.async_add_message(
|
||||
session.Content(
|
||||
"""Test cleanup of the chat log."""
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
conversation_id = session.conversation_id
|
||||
# Add message so it persists
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
content="",
|
||||
)
|
||||
)
|
||||
assert len(chat_session.messages) == 3
|
||||
|
||||
# Reuse conversation ID to ensure we can chat with same session
|
||||
mock_conversation_input.conversation_id = conversation_id
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
assert len(chat_session.messages) == 4
|
||||
assert chat_session.conversation_id == conversation_id
|
||||
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
|
||||
|
||||
# Set the last updated to be older than the timeout
|
||||
hass.data[session.DATA_CHAT_HISTORY][conversation_id].last_updated = (
|
||||
dt_util.utcnow() + session.CONVERSATION_TIMEOUT
|
||||
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
|
||||
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT
|
||||
)
|
||||
|
||||
async_fire_time_changed(
|
||||
hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT + timedelta(seconds=1)
|
||||
hass,
|
||||
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
|
||||
)
|
||||
|
||||
# Should not be cleaned up, but it should have scheduled another cleanup
|
||||
mock_conversation_input.conversation_id = conversation_id
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
assert len(chat_session.messages) == 4
|
||||
assert chat_session.conversation_id == conversation_id
|
||||
|
||||
async_fire_time_changed(
|
||||
hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1)
|
||||
)
|
||||
|
||||
# It should be cleaned up now and we start a new conversation
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
assert chat_session.conversation_id != conversation_id
|
||||
assert len(chat_session.messages) == 2
|
||||
assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
|
||||
|
||||
|
||||
async def test_add_message(
|
||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||
) -> None:
|
||||
"""Test filtering of messages."""
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
assert len(chat_session.messages) == 2
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
assert len(chat_log.messages) == 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chat_session.async_add_message(
|
||||
session.Content(role="system", agent_id=None, content="")
|
||||
chat_log.async_add_message(
|
||||
Content(role="system", agent_id=None, content="")
|
||||
)
|
||||
|
||||
# No 2 user messages in a row
|
||||
assert chat_session.messages[1].role == "user"
|
||||
assert chat_log.messages[1].role == "user"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chat_session.async_add_message(
|
||||
session.Content(role="user", agent_id=None, content="")
|
||||
)
|
||||
chat_log.async_add_message(Content(role="user", agent_id=None, content=""))
|
||||
|
||||
# No 2 assistant messages in a row
|
||||
chat_session.async_add_message(
|
||||
session.Content(role="assistant", agent_id=None, content="")
|
||||
)
|
||||
assert len(chat_session.messages) == 3
|
||||
assert chat_session.messages[-1].role == "assistant"
|
||||
chat_log.async_add_message(Content(role="assistant", agent_id=None, content=""))
|
||||
assert len(chat_log.messages) == 3
|
||||
assert chat_log.messages[-1].role == "assistant"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chat_session.async_add_message(
|
||||
session.Content(role="assistant", agent_id=None, content="")
|
||||
chat_log.async_add_message(
|
||||
Content(role="assistant", agent_id=None, content="")
|
||||
)
|
||||
|
||||
|
||||
@ -166,66 +115,65 @@ async def test_message_filtering(
|
||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||
) -> None:
|
||||
"""Test filtering of messages."""
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
messages = chat_session.async_get_messages(agent_id=None)
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
messages = chat_log.async_get_messages(agent_id=None)
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == session.Content(
|
||||
assert messages[0] == Content(
|
||||
role="system",
|
||||
agent_id=None,
|
||||
content="",
|
||||
)
|
||||
assert messages[1] == session.Content(
|
||||
assert messages[1] == Content(
|
||||
role="user",
|
||||
agent_id="mock-agent-id",
|
||||
content=mock_conversation_input.text,
|
||||
)
|
||||
# Cannot add a second user message in a row
|
||||
with pytest.raises(ValueError):
|
||||
chat_session.async_add_message(
|
||||
session.Content(
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="user",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
)
|
||||
|
||||
chat_session.async_add_message(
|
||||
session.Content(
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
)
|
||||
# Different agent, native messages will be filtered out.
|
||||
chat_session.async_add_message(
|
||||
session.NativeContent(agent_id="another-mock-agent-id", content=1)
|
||||
)
|
||||
chat_session.async_add_message(
|
||||
session.NativeContent(agent_id="mock-agent-id", content=1)
|
||||
chat_log.async_add_message(
|
||||
NativeContent(agent_id="another-mock-agent-id", content=1)
|
||||
)
|
||||
chat_log.async_add_message(NativeContent(agent_id="mock-agent-id", content=1))
|
||||
# A non-native message from another agent is not filtered out.
|
||||
chat_session.async_add_message(
|
||||
session.Content(
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
agent_id="another-mock-agent-id",
|
||||
content="Hi!",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(chat_session.messages) == 6
|
||||
assert len(chat_log.messages) == 6
|
||||
|
||||
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
|
||||
messages = chat_log.async_get_messages(agent_id="mock-agent-id")
|
||||
assert len(messages) == 5
|
||||
|
||||
assert messages[2] == session.Content(
|
||||
assert messages[2] == Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
assert messages[3] == session.NativeContent(agent_id="mock-agent-id", content=1)
|
||||
assert messages[4] == session.Content(
|
||||
assert messages[3] == NativeContent(agent_id="mock-agent-id", content=1)
|
||||
assert messages[4] == Content(
|
||||
role="assistant", agent_id="another-mock-agent-id", content="Hi!"
|
||||
)
|
||||
|
||||
@ -235,18 +183,19 @@ async def test_llm_api(
|
||||
mock_conversation_input: ConversationInput,
|
||||
) -> None:
|
||||
"""Test when we reference an LLM API."""
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
await chat_session.async_update_llm_data(
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
|
||||
assert isinstance(chat_session.llm_api, llm.APIInstance)
|
||||
assert chat_session.llm_api.api.id == "assist"
|
||||
assert isinstance(chat_log.llm_api, llm.APIInstance)
|
||||
assert chat_log.llm_api.api.id == "assist"
|
||||
|
||||
|
||||
async def test_unknown_llm_api(
|
||||
@ -255,11 +204,12 @@ async def test_unknown_llm_api(
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test when we reference an LLM API that does not exists."""
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
with pytest.raises(session.ConverseError) as exc_info:
|
||||
await chat_session.async_update_llm_data(
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
with pytest.raises(ConverseError) as exc_info:
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api="unknown-api",
|
||||
@ -276,11 +226,12 @@ async def test_template_error(
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that template error handling works."""
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
with pytest.raises(session.ConverseError) as exc_info:
|
||||
await chat_session.async_update_llm_data(
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
with pytest.raises(ConverseError) as exc_info:
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api=None,
|
||||
@ -300,13 +251,14 @@ async def test_template_variables(
|
||||
mock_user.name = "Test User"
|
||||
mock_conversation_input.context = Context(user_id=mock_user.id)
|
||||
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
with patch(
|
||||
"homeassistant.auth.AuthManager.async_get_user", return_value=mock_user
|
||||
):
|
||||
await chat_session.async_update_llm_data(
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api=None,
|
||||
@ -318,12 +270,12 @@ async def test_template_variables(
|
||||
),
|
||||
)
|
||||
|
||||
assert chat_session.user_name == "Test User"
|
||||
assert chat_log.user_name == "Test User"
|
||||
|
||||
assert "The instance name is test home." in chat_session.messages[0].content
|
||||
assert "The user name is Test User." in chat_session.messages[0].content
|
||||
assert "The user id is 12345." in chat_session.messages[0].content
|
||||
assert "The calling platform is test." in chat_session.messages[0].content
|
||||
assert "The instance name is test home." in chat_log.messages[0].content
|
||||
assert "The user name is Test User." in chat_log.messages[0].content
|
||||
assert "The user id is 12345." in chat_log.messages[0].content
|
||||
assert "The calling platform is test." in chat_log.messages[0].content
|
||||
|
||||
|
||||
async def test_extra_systen_prompt(
|
||||
@ -336,82 +288,86 @@ async def test_extra_systen_prompt(
|
||||
)
|
||||
mock_conversation_input.extra_system_prompt = extra_system_prompt
|
||||
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
await chat_session.async_update_llm_data(
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
chat_session.async_add_message(
|
||||
session.Content(
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
)
|
||||
|
||||
assert chat_session.extra_system_prompt == extra_system_prompt
|
||||
assert chat_session.messages[0].content.endswith(extra_system_prompt)
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt
|
||||
assert chat_log.messages[0].content.endswith(extra_system_prompt)
|
||||
|
||||
# Verify that follow-up conversations with no system prompt take previous one
|
||||
mock_conversation_input.conversation_id = chat_session.conversation_id
|
||||
conversation_id = chat_log.conversation_id
|
||||
mock_conversation_input.extra_system_prompt = None
|
||||
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
await chat_session.async_update_llm_data(
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
|
||||
assert chat_session.extra_system_prompt == extra_system_prompt
|
||||
assert chat_session.messages[0].content.endswith(extra_system_prompt)
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt
|
||||
assert chat_log.messages[0].content.endswith(extra_system_prompt)
|
||||
|
||||
# Verify that we take new system prompts
|
||||
mock_conversation_input.extra_system_prompt = extra_system_prompt2
|
||||
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
await chat_session.async_update_llm_data(
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
chat_session.async_add_message(
|
||||
session.Content(
|
||||
chat_log.async_add_message(
|
||||
Content(
|
||||
role="assistant",
|
||||
agent_id="mock-agent-id",
|
||||
content="Hey!",
|
||||
)
|
||||
)
|
||||
|
||||
assert chat_session.extra_system_prompt == extra_system_prompt2
|
||||
assert chat_session.messages[0].content.endswith(extra_system_prompt2)
|
||||
assert extra_system_prompt not in chat_session.messages[0].content
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt2
|
||||
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
|
||||
assert extra_system_prompt not in chat_log.messages[0].content
|
||||
|
||||
# Verify that follow-up conversations with no system prompt take previous one
|
||||
mock_conversation_input.extra_system_prompt = None
|
||||
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
await chat_session.async_update_llm_data(
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
|
||||
assert chat_session.extra_system_prompt == extra_system_prompt2
|
||||
assert chat_session.messages[0].content.endswith(extra_system_prompt2)
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt2
|
||||
assert chat_log.messages[0].content.endswith(extra_system_prompt2)
|
||||
|
||||
|
||||
async def test_tool_call(
|
||||
@ -434,16 +390,17 @@ async def test_tool_call(
|
||||
) as mock_get_tools:
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
await chat_session.async_update_llm_data(
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
result = await chat_session.async_call_tool(
|
||||
result = await chat_log.async_call_tool(
|
||||
llm.ToolInput(
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
@ -473,16 +430,17 @@ async def test_tool_call_exception(
|
||||
) as mock_get_tools:
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
await chat_session.async_update_llm_data(
|
||||
async with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
result = await chat_session.async_call_tool(
|
||||
result = await chat_log.async_call_tool(
|
||||
llm.ToolInput(
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
|
98
tests/helpers/test_chat_session.py
Normal file
98
tests/helpers/test_chat_session.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""Test the chat session helper."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import chat_session
|
||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ulid() -> Generator[Mock]:
|
||||
"""Mock the ulid library."""
|
||||
with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now:
|
||||
mock_ulid_now.return_value = "mock-ulid"
|
||||
yield mock_ulid_now
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("start_id", "given_id"),
|
||||
[
|
||||
(None, "mock-ulid"),
|
||||
# This ULID is not known as a session
|
||||
("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"),
|
||||
("not-a-ulid", "not-a-ulid"),
|
||||
],
|
||||
)
|
||||
async def test_conversation_id(
|
||||
hass: HomeAssistant,
|
||||
start_id: str | None,
|
||||
given_id: str,
|
||||
mock_ulid: Mock,
|
||||
) -> None:
|
||||
"""Test conversation ID generation."""
|
||||
async with chat_session.async_get_chat_session(hass, start_id) as session:
|
||||
assert session.conversation_id == given_id
|
||||
|
||||
|
||||
async def test_context_var(hass: HomeAssistant) -> None:
|
||||
"""Test context var."""
|
||||
async with chat_session.async_get_chat_session(hass) as session:
|
||||
async with chat_session.async_get_chat_session(
|
||||
hass, session.conversation_id
|
||||
) as session2:
|
||||
assert session is session2
|
||||
|
||||
async with chat_session.async_get_chat_session(hass, None) as session2:
|
||||
assert session.conversation_id != session2.conversation_id
|
||||
|
||||
async with chat_session.async_get_chat_session(
|
||||
hass, "something else"
|
||||
) as session2:
|
||||
assert session.conversation_id != session2.conversation_id
|
||||
|
||||
async with chat_session.async_get_chat_session(
|
||||
hass, ulid_util.ulid_now()
|
||||
) as session2:
|
||||
assert session.conversation_id != session2.conversation_id
|
||||
|
||||
|
||||
async def test_cleanup(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test cleanup of the chat session."""
|
||||
async with chat_session.async_get_chat_session(hass) as session:
|
||||
conversation_id = session.conversation_id
|
||||
|
||||
# Reuse conversation ID to ensure we can chat with same session
|
||||
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||
assert session.conversation_id == conversation_id
|
||||
|
||||
# Set the last updated to be older than the timeout
|
||||
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
|
||||
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT
|
||||
)
|
||||
|
||||
async_fire_time_changed(
|
||||
hass,
|
||||
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1),
|
||||
)
|
||||
|
||||
# Should not be cleaned up, but it should have scheduled another cleanup
|
||||
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||
assert session.conversation_id == conversation_id
|
||||
|
||||
async_fire_time_changed(
|
||||
hass,
|
||||
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
|
||||
)
|
||||
|
||||
# It should be cleaned up now and we start a new conversation
|
||||
async with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||
assert session.conversation_id != conversation_id
|
Loading…
x
Reference in New Issue
Block a user