Simplify ChatLog dependencies (#146351)

This commit is contained in:
Paulus Schoutsen 2025-06-15 17:41:15 -04:00 committed by GitHub
parent 5f5869ffc6
commit fa21269f0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 86 additions and 70 deletions

View File

@@ -366,11 +366,11 @@ class AnthropicConversationEntity(
options = self.entry.options options = self.entry.options
try: try:
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
DOMAIN, user_input.as_llm_context(DOMAIN),
user_input,
options.get(CONF_LLM_HASS_API), options.get(CONF_LLM_HASS_API),
options.get(CONF_PROMPT), options.get(CONF_PROMPT),
user_input.extra_system_prompt,
) )
except conversation.ConverseError as err: except conversation.ConverseError as err:
return err.as_conversation_result() return err.as_conversation_result()

View File

@@ -14,12 +14,11 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import chat_session, intent, llm, template from homeassistant.helpers import chat_session, frame, intent, llm, template
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType from homeassistant.util.json import JsonObjectType
from . import trace from . import trace
from .const import DOMAIN
from .models import ConversationInput, ConversationResult from .models import ConversationInput, ConversationResult
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs") DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
@@ -359,7 +358,7 @@ class ChatLog:
self, self,
llm_context: llm.LLMContext, llm_context: llm.LLMContext,
prompt: str, prompt: str,
language: str, language: str | None,
user_name: str | None = None, user_name: str | None = None,
) -> str: ) -> str:
try: try:
@@ -373,7 +372,7 @@ class ChatLog:
) )
except TemplateError as err: except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err) LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=language) intent_response = intent.IntentResponse(language=language or "")
intent_response.async_set_error( intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN, intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem with my template", "Sorry, I had a problem with my template",
@@ -392,14 +391,25 @@ class ChatLog:
user_llm_prompt: str | None = None, user_llm_prompt: str | None = None,
) -> None: ) -> None:
"""Set the LLM system prompt.""" """Set the LLM system prompt."""
llm_context = llm.LLMContext( frame.report_usage(
platform=conversing_domain, "ChatLog.async_update_llm_data",
context=user_input.context, breaks_in_ha_version="2026.1",
language=user_input.language, )
assistant=DOMAIN, return await self.async_provide_llm_data(
device_id=user_input.device_id, llm_context=user_input.as_llm_context(conversing_domain),
user_llm_hass_api=user_llm_hass_api,
user_llm_prompt=user_llm_prompt,
user_extra_system_prompt=user_input.extra_system_prompt,
) )
async def async_provide_llm_data(
self,
llm_context: llm.LLMContext,
user_llm_hass_api: str | list[str] | None = None,
user_llm_prompt: str | None = None,
user_extra_system_prompt: str | None = None,
) -> None:
"""Set the LLM system prompt."""
llm_api: llm.APIInstance | None = None llm_api: llm.APIInstance | None = None
if user_llm_hass_api: if user_llm_hass_api:
@@ -413,10 +423,12 @@ class ChatLog:
LOGGER.error( LOGGER.error(
"Error getting LLM API %s for %s: %s", "Error getting LLM API %s for %s: %s",
user_llm_hass_api, user_llm_hass_api,
conversing_domain, llm_context.platform,
err, err,
) )
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(
language=llm_context.language or ""
)
intent_response.async_set_error( intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN, intent.IntentResponseErrorCode.UNKNOWN,
"Error preparing LLM API", "Error preparing LLM API",
@@ -430,10 +442,10 @@ class ChatLog:
user_name: str | None = None user_name: str | None = None
if ( if (
user_input.context llm_context.context
and user_input.context.user_id and llm_context.context.user_id
and ( and (
user := await self.hass.auth.async_get_user(user_input.context.user_id) user := await self.hass.auth.async_get_user(llm_context.context.user_id)
) )
): ):
user_name = user.name user_name = user.name
@@ -443,7 +455,7 @@ class ChatLog:
await self._async_expand_prompt_template( await self._async_expand_prompt_template(
llm_context, llm_context,
(user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT), (user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT),
user_input.language, llm_context.language,
user_name, user_name,
) )
) )
@@ -455,14 +467,14 @@ class ChatLog:
await self._async_expand_prompt_template( await self._async_expand_prompt_template(
llm_context, llm_context,
llm.BASE_PROMPT, llm.BASE_PROMPT,
user_input.language, llm_context.language,
user_name, user_name,
) )
) )
if extra_system_prompt := ( if extra_system_prompt := (
# Take new system prompt if one was given # Take new system prompt if one was given
user_input.extra_system_prompt or self.extra_system_prompt user_extra_system_prompt or self.extra_system_prompt
): ):
prompt_parts.append(extra_system_prompt) prompt_parts.append(extra_system_prompt)

View File

@@ -7,7 +7,9 @@ from dataclasses import dataclass
from typing import Any, Literal from typing import Any, Literal
from homeassistant.core import Context from homeassistant.core import Context
from homeassistant.helpers import intent from homeassistant.helpers import intent, llm
from .const import DOMAIN
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -56,6 +58,16 @@ class ConversationInput:
"extra_system_prompt": self.extra_system_prompt, "extra_system_prompt": self.extra_system_prompt,
} }
def as_llm_context(self, conversing_domain: str) -> llm.LLMContext:
"""Return input as an LLM context."""
return llm.LLMContext(
platform=conversing_domain,
context=self.context,
language=self.language,
assistant=DOMAIN,
device_id=self.device_id,
)
@dataclass(slots=True) @dataclass(slots=True)
class ConversationResult: class ConversationResult:

View File

@@ -73,11 +73,11 @@ class GoogleGenerativeAIConversationEntity(
options = self.entry.options options = self.entry.options
try: try:
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
DOMAIN, user_input.as_llm_context(DOMAIN),
user_input,
options.get(CONF_LLM_HASS_API), options.get(CONF_LLM_HASS_API),
options.get(CONF_PROMPT), options.get(CONF_PROMPT),
user_input.extra_system_prompt,
) )
except conversation.ConverseError as err: except conversation.ConverseError as err:
return err.as_conversation_result() return err.as_conversation_result()

View File

@@ -219,11 +219,11 @@ class OllamaConversationEntity(
settings = {**self.entry.data, **self.entry.options} settings = {**self.entry.data, **self.entry.options}
try: try:
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
DOMAIN, user_input.as_llm_context(DOMAIN),
user_input,
settings.get(CONF_LLM_HASS_API), settings.get(CONF_LLM_HASS_API),
settings.get(CONF_PROMPT), settings.get(CONF_PROMPT),
user_input.extra_system_prompt,
) )
except conversation.ConverseError as err: except conversation.ConverseError as err:
return err.as_conversation_result() return err.as_conversation_result()

View File

@@ -279,11 +279,11 @@ class OpenAIConversationEntity(
options = self.entry.options options = self.entry.options
try: try:
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
DOMAIN, user_input.as_llm_context(DOMAIN),
user_input,
options.get(CONF_LLM_HASS_API), options.get(CONF_LLM_HASS_API),
options.get(CONF_PROMPT), options.get(CONF_PROMPT),
user_input.extra_system_prompt,
) )
except conversation.ConverseError as err: except conversation.ConverseError as err:
return err.as_conversation_result() return err.as_conversation_result()

View File

@@ -1779,11 +1779,11 @@ async def test_chat_log_tts_streaming(
conversation_input, conversation_input,
) as chat_log, ) as chat_log,
): ):
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", conversation_input.as_llm_context("test"),
user_input=conversation_input,
user_llm_hass_api="assist", user_llm_hass_api="assist",
user_llm_prompt=None, user_llm_prompt=None,
user_extra_system_prompt=conversation_input.extra_system_prompt,
) )
async for _content in chat_log.async_add_delta_content_stream( async for _content in chat_log.async_add_delta_content_stream(
agent_id, stream_llm_response() agent_id, stream_llm_response()

View File

@@ -106,9 +106,8 @@ async def test_llm_api(
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api="assist", user_llm_hass_api="assist",
user_llm_prompt=None, user_llm_prompt=None,
) )
@@ -128,9 +127,8 @@ async def test_unknown_llm_api(
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
pytest.raises(ConverseError) as exc_info, pytest.raises(ConverseError) as exc_info,
): ):
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api="unknown-api", user_llm_hass_api="unknown-api",
user_llm_prompt=None, user_llm_prompt=None,
) )
@@ -170,9 +168,8 @@ async def test_multiple_llm_apis(
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api=["assist", "my-api"], user_llm_hass_api=["assist", "my-api"],
user_llm_prompt=None, user_llm_prompt=None,
) )
@@ -192,9 +189,8 @@ async def test_template_error(
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
pytest.raises(ConverseError) as exc_info, pytest.raises(ConverseError) as exc_info,
): ):
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt="{{ invalid_syntax", user_llm_prompt="{{ invalid_syntax",
) )
@@ -217,9 +213,8 @@ async def test_template_variables(
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user), patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
): ):
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt=( user_llm_prompt=(
"The instance name is {{ ha_name }}. " "The instance name is {{ ha_name }}. "
@@ -249,11 +244,11 @@ async def test_extra_systen_prompt(
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt=None, user_llm_prompt=None,
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
) )
chat_log.async_add_assistant_content_without_tools( chat_log.async_add_assistant_content_without_tools(
AssistantContent( AssistantContent(
@@ -273,11 +268,11 @@ async def test_extra_systen_prompt(
chat_session.async_get_chat_session(hass, conversation_id) as session, chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt=None, user_llm_prompt=None,
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
) )
assert chat_log.extra_system_prompt == extra_system_prompt assert chat_log.extra_system_prompt == extra_system_prompt
@@ -290,11 +285,11 @@ async def test_extra_systen_prompt(
chat_session.async_get_chat_session(hass, conversation_id) as session, chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt=None, user_llm_prompt=None,
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
) )
chat_log.async_add_assistant_content_without_tools( chat_log.async_add_assistant_content_without_tools(
AssistantContent( AssistantContent(
@@ -314,11 +309,11 @@ async def test_extra_systen_prompt(
chat_session.async_get_chat_session(hass, conversation_id) as session, chat_session.async_get_chat_session(hass, conversation_id) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api=None, user_llm_hass_api=None,
user_llm_prompt=None, user_llm_prompt=None,
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
) )
assert chat_log.extra_system_prompt == extra_system_prompt2 assert chat_log.extra_system_prompt == extra_system_prompt2
@@ -357,9 +352,8 @@ async def test_tool_call(
chat_session.async_get_chat_session(hass) as session, chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api="assist", user_llm_hass_api="assist",
user_llm_prompt=None, user_llm_prompt=None,
) )
@@ -434,9 +428,8 @@ async def test_tool_call_exception(
async_get_chat_log(hass, session, mock_conversation_input) as chat_log, async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
): ):
mock_get_tools.return_value = [mock_tool] mock_get_tools.return_value = [mock_tool]
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api="assist", user_llm_hass_api="assist",
user_llm_prompt=None, user_llm_prompt=None,
) )
@@ -595,9 +588,8 @@ async def test_add_delta_content_stream(
) as chat_log, ) as chat_log,
): ):
mock_get_tools.return_value = [mock_tool] mock_get_tools.return_value = [mock_tool]
await chat_log.async_update_llm_data( await chat_log.async_provide_llm_data(
conversing_domain="test", mock_conversation_input.as_llm_context("test"),
user_input=mock_conversation_input,
user_llm_hass_api="assist", user_llm_hass_api="assist",
user_llm_prompt=None, user_llm_prompt=None,
) )