diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index f2b2f1c1ea4..cfc7261410a 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -1094,7 +1094,7 @@ class PipelineRun: # It was already handled, create response and add to chat history if intent_response is not None: - async with ( + with ( chat_session.async_get_chat_session( self.hass, user_input.conversation_id ) as session, diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 99a1a09e52b..c4a8f7ea7eb 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -349,7 +349,7 @@ class DefaultAgent(ConversationEntity): async def async_process(self, user_input: ConversationInput) -> ConversationResult: """Process a sentence.""" response: intent.IntentResponse | None = None - async with ( + with ( chat_session.async_get_chat_session( self.hass, user_input.conversation_id ) as session, diff --git a/homeassistant/components/conversation/session.py b/homeassistant/components/conversation/session.py index c23d6575d6c..c32d61333a0 100644 --- a/homeassistant/components/conversation/session.py +++ b/homeassistant/components/conversation/session.py @@ -2,8 +2,8 @@ from __future__ import annotations -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager +from collections.abc import Generator +from contextlib import contextmanager from dataclasses import dataclass, field, replace from datetime import datetime import logging @@ -27,12 +27,12 @@ DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log" LOGGER = logging.getLogger(__name__) -@asynccontextmanager -async def async_get_chat_log( +@contextmanager +def async_get_chat_log( hass: HomeAssistant, session: chat_session.ChatSession, user_input: ConversationInput, -) -> AsyncGenerator[ChatLog]: +) -> Generator[ChatLog]: """Return chat log for a specific chat session.""" all_history = hass.data.get(DATA_CHAT_HISTORY) if all_history is None: diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index d4982797e22..53ee4e1f880 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -209,7 +209,7 @@ class GoogleGenerativeAIConversationEntity( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" - async with ( + with ( chat_session.async_get_chat_session( self.hass, user_input.conversation_id ) as session, diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 330a6fe9a34..e19ad9becaf 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -155,7 +155,7 @@ class OpenAIConversationEntity( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" - async with ( + with ( chat_session.async_get_chat_session( self.hass, user_input.conversation_id ) as session, diff --git a/homeassistant/helpers/chat_session.py b/homeassistant/helpers/chat_session.py index 4cfa91bc555..686272dd834 100644 --- a/homeassistant/helpers/chat_session.py +++ b/homeassistant/helpers/chat_session.py @@ -2,8 +2,8 @@ from __future__ import annotations -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager +from collections.abc import Generator +from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass, field from datetime import datetime, timedelta @@ -17,8 +17,9 @@ from homeassistant.core import ( HomeAssistant, callback, ) -from homeassistant.util import dt as dt_util, ulid as ulid_util +from homeassistant.util import dt as dt_util from homeassistant.util.hass_dict import HassKey +from homeassistant.util.ulid import ulid_now, ulid_to_bytes from .event import async_call_later @@ -107,11 +108,11 @@ class SessionCleanup: self.schedule() -@asynccontextmanager -async def async_get_chat_session( +@contextmanager +def async_get_chat_session( hass: HomeAssistant, conversation_id: str | None = None, -) -> AsyncGenerator[ChatSession]: +) -> Generator[ChatSession]: """Return a chat session.""" if session := current_session.get(): # If a session is already active and it's the requested conversation ID, @@ -132,7 +133,7 @@ async def async_get_chat_session( hass.data[DATA_CHAT_SESSION_CLEANUP] = SessionCleanup(hass) if conversation_id is None: - conversation_id = ulid_util.ulid_now() + conversation_id = ulid_now() elif conversation_id in all_sessions: session = all_sessions[conversation_id] @@ -143,8 +144,8 @@ async def async_get_chat_session( # a new conversation was started. If the user picks their own, they # want to track a conversation and we respect it. try: - ulid_util.ulid_to_bytes(conversation_id) - conversation_id = ulid_util.ulid_now() + ulid_to_bytes(conversation_id) + conversation_id = ulid_now() except ValueError: pass diff --git a/tests/components/conversation/test_session.py b/tests/components/conversation/test_session.py index 4a9c662fffa..3943f41a62b 100644 --- a/tests/components/conversation/test_session.py +++ b/tests/components/conversation/test_session.py @@ -50,7 +50,7 @@ async def test_cleanup( mock_conversation_input: ConversationInput, ) -> None: """Test cleanup of the chat log.""" - async with ( + with ( chat_session.async_get_chat_session(hass) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, ): @@ -83,7 +83,7 @@ async def test_add_message( hass: HomeAssistant, mock_conversation_input: ConversationInput ) -> None: """Test filtering of messages.""" - async with ( + with ( chat_session.async_get_chat_session(hass) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, ): @@ -115,7 +115,7 @@ async def test_message_filtering( hass: HomeAssistant, mock_conversation_input: ConversationInput ) -> None: """Test filtering of messages.""" - async with ( + with ( chat_session.async_get_chat_session(hass) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, ): @@ -183,7 +183,7 @@ async def test_llm_api( mock_conversation_input: ConversationInput, ) -> None: """Test when we reference an LLM API.""" - async with ( + with ( chat_session.async_get_chat_session(hass) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, ): @@ -204,17 +204,17 @@ async def test_unknown_llm_api( snapshot: SnapshotAssertion, ) -> None: """Test when we reference an LLM API that does not exists.""" - async with ( + with ( chat_session.async_get_chat_session(hass) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + pytest.raises(ConverseError) as exc_info, ): - with pytest.raises(ConverseError) as exc_info: - await chat_log.async_update_llm_data( - conversing_domain="test", - user_input=mock_conversation_input, - user_llm_hass_api="unknown-api", - user_llm_prompt=None, - ) + await chat_log.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api="unknown-api", + user_llm_prompt=None, + ) assert str(exc_info.value) == "Error getting LLM API unknown-api" assert exc_info.value.as_conversation_result().as_dict() == snapshot @@ -226,17 +226,17 @@ async def test_template_error( snapshot: SnapshotAssertion, ) -> None: """Test that template error handling works.""" - async with ( + with ( chat_session.async_get_chat_session(hass) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + pytest.raises(ConverseError) as exc_info, ): - with pytest.raises(ConverseError) as exc_info: - await chat_log.async_update_llm_data( - conversing_domain="test", - user_input=mock_conversation_input, - user_llm_hass_api=None, - user_llm_prompt="{{ invalid_syntax", - ) + await chat_log.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api=None, + user_llm_prompt="{{ invalid_syntax", + ) assert str(exc_info.value) == "Error rendering prompt" assert exc_info.value.as_conversation_result().as_dict() == snapshot @@ -251,24 +251,22 @@ async def test_template_variables( mock_user.name = "Test User" mock_conversation_input.context = Context(user_id=mock_user.id) - async with ( + with ( chat_session.async_get_chat_session(hass) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user), ): - with patch( - "homeassistant.auth.AuthManager.async_get_user", return_value=mock_user - ): - await chat_log.async_update_llm_data( - conversing_domain="test", - user_input=mock_conversation_input, - user_llm_hass_api=None, - user_llm_prompt=( - "The instance name is {{ ha_name }}. " - "The user name is {{ user_name }}. " - "The user id is {{ llm_context.context.user_id }}." - "The calling platform is {{ llm_context.platform }}." - ), - ) + await chat_log.async_update_llm_data( + conversing_domain="test", + user_input=mock_conversation_input, + user_llm_hass_api=None, + user_llm_prompt=( + "The instance name is {{ ha_name }}. " + "The user name is {{ user_name }}. " + "The user id is {{ llm_context.context.user_id }}." + "The calling platform is {{ llm_context.platform }}." + ), + ) assert chat_log.user_name == "Test User" @@ -288,7 +286,7 @@ async def test_extra_systen_prompt( ) mock_conversation_input.extra_system_prompt = extra_system_prompt - async with ( + with ( chat_session.async_get_chat_session(hass) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, ): @@ -313,7 +311,7 @@ async def test_extra_systen_prompt( conversation_id = chat_log.conversation_id mock_conversation_input.extra_system_prompt = None - async with ( + with ( chat_session.async_get_chat_session(hass, conversation_id) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, ): @@ -330,7 +328,7 @@ async def test_extra_systen_prompt( # Verify that we take new system prompts mock_conversation_input.extra_system_prompt = extra_system_prompt2 - async with ( + with ( chat_session.async_get_chat_session(hass, conversation_id) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, ): @@ -355,7 +353,7 @@ async def test_extra_systen_prompt( # Verify that follow-up conversations with no system prompt take previous one mock_conversation_input.extra_system_prompt = None - async with ( + with ( chat_session.async_get_chat_session(hass, conversation_id) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, ): @@ -390,7 +388,7 @@ async def test_tool_call( ) as mock_get_tools: mock_get_tools.return_value = [mock_tool] - async with ( + with ( chat_session.async_get_chat_session(hass) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, ): @@ -430,7 +428,7 @@ async def test_tool_call_exception( ) as mock_get_tools: mock_get_tools.return_value = [mock_tool] - async with ( + with ( chat_session.async_get_chat_session(hass) as session, async_get_chat_log(hass, session, mock_conversation_input) as chat_log, ): diff --git a/tests/helpers/test_chat_session.py b/tests/helpers/test_chat_session.py index a11f4126886..f6c2fe5057d 100644 --- a/tests/helpers/test_chat_session.py +++ b/tests/helpers/test_chat_session.py @@ -16,7 +16,7 @@ from tests.common import async_fire_time_changed @pytest.fixture def mock_ulid() -> Generator[Mock]: """Mock the ulid library.""" - with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now: + with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now: mock_ulid_now.return_value = "mock-ulid" yield mock_ulid_now @@ -37,27 +37,25 @@ async def test_conversation_id( mock_ulid: Mock, ) -> None: """Test conversation ID generation.""" - async with chat_session.async_get_chat_session(hass, start_id) as session: + with chat_session.async_get_chat_session(hass, start_id) as session: assert session.conversation_id == given_id async def test_context_var(hass: HomeAssistant) -> None: """Test context var.""" - async with chat_session.async_get_chat_session(hass) as session: - async with chat_session.async_get_chat_session( + with chat_session.async_get_chat_session(hass) as session: + with chat_session.async_get_chat_session( hass, session.conversation_id ) as session2: assert session is session2 - async with chat_session.async_get_chat_session(hass, None) as session2: + with chat_session.async_get_chat_session(hass, None) as session2: assert session.conversation_id != session2.conversation_id - async with chat_session.async_get_chat_session( - hass, "something else" - ) as session2: + with chat_session.async_get_chat_session(hass, "something else") as session2: assert session.conversation_id != session2.conversation_id - async with chat_session.async_get_chat_session( + with chat_session.async_get_chat_session( hass, ulid_util.ulid_now() ) as session2: assert session.conversation_id != session2.conversation_id @@ -67,11 +65,11 @@ async def test_cleanup( hass: HomeAssistant, ) -> None: """Test cleanup of the chat session.""" - async with chat_session.async_get_chat_session(hass) as session: + with chat_session.async_get_chat_session(hass) as session: conversation_id = session.conversation_id # Reuse conversation ID to ensure we can chat with same session - async with chat_session.async_get_chat_session(hass, conversation_id) as session: + with chat_session.async_get_chat_session(hass, conversation_id) as session: assert session.conversation_id == conversation_id # Set the last updated to be older than the timeout @@ -85,7 +83,7 @@ async def test_cleanup( ) # Should not be cleaned up, but it should have scheduled another cleanup - async with chat_session.async_get_chat_session(hass, conversation_id) as session: + with chat_session.async_get_chat_session(hass, conversation_id) as session: assert session.conversation_id == conversation_id async_fire_time_changed( @@ -94,5 +92,5 @@ async def test_cleanup( ) # It should be cleaned up now and we start a new conversation - async with chat_session.async_get_chat_session(hass, conversation_id) as session: + with chat_session.async_get_chat_session(hass, conversation_id) as session: assert session.conversation_id != conversation_id