Simplify ChatLog dependencies (#146351)

Paulus Schoutsen 2025-06-15 17:41:15 -04:00 committed by GitHub
parent 5f5869ffc6
commit fa21269f0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 86 additions and 70 deletions

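The commit replaces ChatLog.async_update_llm_data, which received the conversing domain plus the full ConversationInput, with ChatLog.async_provide_llm_data, which receives a prebuilt llm.LLMContext (via the new ConversationInput.as_llm_context helper) and an explicit extra system prompt. A minimal before/after sketch of the caller-side change, using the names from the integration diffs below:

# Before (deprecated; the shim is slated to break in Home Assistant 2026.1):
await chat_log.async_update_llm_data(
    DOMAIN,
    user_input,
    options.get(CONF_LLM_HASS_API),
    options.get(CONF_PROMPT),
)

# After:
await chat_log.async_provide_llm_data(
    user_input.as_llm_context(DOMAIN),
    options.get(CONF_LLM_HASS_API),
    options.get(CONF_PROMPT),
    user_input.extra_system_prompt,
)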
View File

@@ -366,11 +366,11 @@ class AnthropicConversationEntity(
         options = self.entry.options
 
         try:
-            await chat_log.async_update_llm_data(
-                DOMAIN,
-                user_input,
+            await chat_log.async_provide_llm_data(
+                user_input.as_llm_context(DOMAIN),
                 options.get(CONF_LLM_HASS_API),
                 options.get(CONF_PROMPT),
+                user_input.extra_system_prompt,
             )
         except conversation.ConverseError as err:
             return err.as_conversation_result()

View File

@@ -14,12 +14,11 @@ import voluptuous as vol
 
 from homeassistant.core import HomeAssistant, callback
 from homeassistant.exceptions import HomeAssistantError, TemplateError
-from homeassistant.helpers import chat_session, intent, llm, template
+from homeassistant.helpers import chat_session, frame, intent, llm, template
 from homeassistant.util.hass_dict import HassKey
 from homeassistant.util.json import JsonObjectType
 
 from . import trace
-from .const import DOMAIN
 from .models import ConversationInput, ConversationResult
 
 DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
@@ -359,7 +358,7 @@ class ChatLog:
         self,
         llm_context: llm.LLMContext,
         prompt: str,
-        language: str,
+        language: str | None,
         user_name: str | None = None,
     ) -> str:
         try:
@@ -373,7 +372,7 @@ class ChatLog:
             )
         except TemplateError as err:
             LOGGER.error("Error rendering prompt: %s", err)
-            intent_response = intent.IntentResponse(language=language)
+            intent_response = intent.IntentResponse(language=language or "")
             intent_response.async_set_error(
                 intent.IntentResponseErrorCode.UNKNOWN,
                 "Sorry, I had a problem with my template",
@@ -392,14 +391,25 @@ class ChatLog:
         user_llm_prompt: str | None = None,
     ) -> None:
         """Set the LLM system prompt."""
-        llm_context = llm.LLMContext(
-            platform=conversing_domain,
-            context=user_input.context,
-            language=user_input.language,
-            assistant=DOMAIN,
-            device_id=user_input.device_id,
+        frame.report_usage(
+            "ChatLog.async_update_llm_data",
+            breaks_in_ha_version="2026.1",
         )
+        return await self.async_provide_llm_data(
+            llm_context=user_input.as_llm_context(conversing_domain),
+            user_llm_hass_api=user_llm_hass_api,
+            user_llm_prompt=user_llm_prompt,
+            user_extra_system_prompt=user_input.extra_system_prompt,
+        )
 
+    async def async_provide_llm_data(
+        self,
+        llm_context: llm.LLMContext,
+        user_llm_hass_api: str | list[str] | None = None,
+        user_llm_prompt: str | None = None,
+        user_extra_system_prompt: str | None = None,
+    ) -> None:
+        """Set the LLM system prompt."""
         llm_api: llm.APIInstance | None = None
 
         if user_llm_hass_api:
@@ -413,10 +423,12 @@ class ChatLog:
                 LOGGER.error(
                     "Error getting LLM API %s for %s: %s",
                     user_llm_hass_api,
-                    conversing_domain,
+                    llm_context.platform,
                     err,
                 )
-                intent_response = intent.IntentResponse(language=user_input.language)
+                intent_response = intent.IntentResponse(
+                    language=llm_context.language or ""
+                )
                 intent_response.async_set_error(
                     intent.IntentResponseErrorCode.UNKNOWN,
                     "Error preparing LLM API",
@@ -430,10 +442,10 @@ class ChatLog:
 
         user_name: str | None = None
         if (
-            user_input.context
-            and user_input.context.user_id
+            llm_context.context
+            and llm_context.context.user_id
             and (
-                user := await self.hass.auth.async_get_user(user_input.context.user_id)
+                user := await self.hass.auth.async_get_user(llm_context.context.user_id)
             )
         ):
             user_name = user.name
@@ -443,7 +455,7 @@ class ChatLog:
             await self._async_expand_prompt_template(
                 llm_context,
                 (user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT),
-                user_input.language,
+                llm_context.language,
                 user_name,
             )
         )
@@ -455,14 +467,14 @@ class ChatLog:
                 await self._async_expand_prompt_template(
                     llm_context,
                     llm.BASE_PROMPT,
-                    user_input.language,
+                    llm_context.language,
                     user_name,
                 )
             )
 
         if extra_system_prompt := (
             # Take new system prompt if one was given
-            user_input.extra_system_prompt or self.extra_system_prompt
+            user_extra_system_prompt or self.extra_system_prompt
         ):
             prompt_parts.append(extra_system_prompt)
 

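The old entry point survives as a thin deprecation shim: frame.report_usage logs which integration still calls it and names the release where it breaks, then the call forwards to async_provide_llm_data. A minimal sketch of the same pattern applied to a hypothetical helper (ExampleLog and its method names are illustrative, not part of this commit):

from homeassistant.helpers import frame


class ExampleLog:
    """Illustrative class showing the deprecation-shim pattern."""

    async def async_new_method(self, value: str) -> None:
        """The replacement API."""

    async def async_old_method(self, value: str) -> None:
        """Deprecated wrapper kept until the stated release."""
        # Warn the calling integration that this entry point goes away.
        frame.report_usage(
            "ExampleLog.async_old_method",
            breaks_in_ha_version="2026.1",
        )
        return await self.async_new_method(value)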
View File

@@ -7,7 +7,9 @@ from dataclasses import dataclass
 from typing import Any, Literal
 
 from homeassistant.core import Context
-from homeassistant.helpers import intent
+from homeassistant.helpers import intent, llm
+
+from .const import DOMAIN
 
 
 @dataclass(frozen=True)
@@ -56,6 +58,16 @@ class ConversationInput:
             "extra_system_prompt": self.extra_system_prompt,
         }
 
+    def as_llm_context(self, conversing_domain: str) -> llm.LLMContext:
+        """Return input as an LLM context."""
+        return llm.LLMContext(
+            platform=conversing_domain,
+            context=self.context,
+            language=self.language,
+            assistant=DOMAIN,
+            device_id=self.device_id,
+        )
+
 
 @dataclass(slots=True)
 class ConversationResult:

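The as_llm_context helper moves LLMContext construction from chat_log.py onto ConversationInput itself, so callers only supply their own domain. Spelled out, a call such as user_input.as_llm_context("my_domain") is equivalent to the following (the "my_domain" platform value is illustrative):

llm_context = llm.LLMContext(
    platform="my_domain",           # domain of the conversing integration
    context=user_input.context,     # Home Assistant Context of the request
    language=user_input.language,   # may be None
    assistant=DOMAIN,               # the conversation integration's domain
    device_id=user_input.device_id,
)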
View File

@@ -73,11 +73,11 @@ class GoogleGenerativeAIConversationEntity(
         options = self.entry.options
 
         try:
-            await chat_log.async_update_llm_data(
-                DOMAIN,
-                user_input,
+            await chat_log.async_provide_llm_data(
+                user_input.as_llm_context(DOMAIN),
                 options.get(CONF_LLM_HASS_API),
                 options.get(CONF_PROMPT),
+                user_input.extra_system_prompt,
             )
         except conversation.ConverseError as err:
             return err.as_conversation_result()

View File

@@ -219,11 +219,11 @@ class OllamaConversationEntity(
         settings = {**self.entry.data, **self.entry.options}
 
         try:
-            await chat_log.async_update_llm_data(
-                DOMAIN,
-                user_input,
+            await chat_log.async_provide_llm_data(
+                user_input.as_llm_context(DOMAIN),
                 settings.get(CONF_LLM_HASS_API),
                 settings.get(CONF_PROMPT),
+                user_input.extra_system_prompt,
             )
         except conversation.ConverseError as err:
             return err.as_conversation_result()

View File

@@ -279,11 +279,11 @@ class OpenAIConversationEntity(
         options = self.entry.options
 
         try:
-            await chat_log.async_update_llm_data(
-                DOMAIN,
-                user_input,
+            await chat_log.async_provide_llm_data(
+                user_input.as_llm_context(DOMAIN),
                 options.get(CONF_LLM_HASS_API),
                 options.get(CONF_PROMPT),
+                user_input.extra_system_prompt,
             )
         except conversation.ConverseError as err:
             return err.as_conversation_result()

View File

@@ -1779,11 +1779,11 @@ async def test_chat_log_tts_streaming(
             conversation_input,
         ) as chat_log,
     ):
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=conversation_input,
+        await chat_log.async_provide_llm_data(
+            conversation_input.as_llm_context("test"),
             user_llm_hass_api="assist",
             user_llm_prompt=None,
+            user_extra_system_prompt=conversation_input.extra_system_prompt,
         )
         async for _content in chat_log.async_add_delta_content_stream(
             agent_id, stream_llm_response()

View File

@@ -106,9 +106,8 @@ async def test_llm_api(
         chat_session.async_get_chat_session(hass) as session,
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
     ):
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api="assist",
             user_llm_prompt=None,
         )
@@ -128,9 +127,8 @@ async def test_unknown_llm_api(
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
         pytest.raises(ConverseError) as exc_info,
     ):
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api="unknown-api",
             user_llm_prompt=None,
         )
@@ -170,9 +168,8 @@ async def test_multiple_llm_apis(
         chat_session.async_get_chat_session(hass) as session,
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
     ):
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api=["assist", "my-api"],
             user_llm_prompt=None,
         )
@@ -192,9 +189,8 @@ async def test_template_error(
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
         pytest.raises(ConverseError) as exc_info,
     ):
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api=None,
             user_llm_prompt="{{ invalid_syntax",
         )
@@ -217,9 +213,8 @@ async def test_template_variables(
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
         patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
     ):
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api=None,
             user_llm_prompt=(
                 "The instance name is {{ ha_name }}. "
@@ -249,11 +244,11 @@ async def test_extra_systen_prompt(
         chat_session.async_get_chat_session(hass) as session,
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
     ):
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api=None,
             user_llm_prompt=None,
+            user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
         )
         chat_log.async_add_assistant_content_without_tools(
             AssistantContent(
@@ -273,11 +268,11 @@ async def test_extra_systen_prompt(
         chat_session.async_get_chat_session(hass, conversation_id) as session,
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
     ):
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api=None,
             user_llm_prompt=None,
+            user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
         )
 
         assert chat_log.extra_system_prompt == extra_system_prompt
@@ -290,11 +285,11 @@ async def test_extra_systen_prompt(
         chat_session.async_get_chat_session(hass, conversation_id) as session,
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
     ):
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api=None,
             user_llm_prompt=None,
+            user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
         )
         chat_log.async_add_assistant_content_without_tools(
             AssistantContent(
@@ -314,11 +309,11 @@ async def test_extra_systen_prompt(
         chat_session.async_get_chat_session(hass, conversation_id) as session,
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
    ):
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api=None,
             user_llm_prompt=None,
+            user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
         )
 
         assert chat_log.extra_system_prompt == extra_system_prompt2
@@ -357,9 +352,8 @@ async def test_tool_call(
         chat_session.async_get_chat_session(hass) as session,
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
     ):
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api="assist",
             user_llm_prompt=None,
         )
@@ -434,9 +428,8 @@ async def test_tool_call_exception(
         async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
     ):
         mock_get_tools.return_value = [mock_tool]
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api="assist",
             user_llm_prompt=None,
         )
@@ -595,9 +588,8 @@ async def test_add_delta_content_stream(
         ) as chat_log,
     ):
         mock_get_tools.return_value = [mock_tool]
-        await chat_log.async_update_llm_data(
-            conversing_domain="test",
-            user_input=mock_conversation_input,
+        await chat_log.async_provide_llm_data(
+            mock_conversation_input.as_llm_context("test"),
             user_llm_hass_api="assist",
             user_llm_prompt=None,
         )