mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 20:27:08 +00:00
Simplify ChatLog dependencies (#146351)
This commit is contained in:
parent
5f5869ffc6
commit
fa21269f0d
@ -366,11 +366,11 @@ class AnthropicConversationEntity(
|
||||
options = self.entry.options
|
||||
|
||||
try:
|
||||
await chat_log.async_update_llm_data(
|
||||
DOMAIN,
|
||||
user_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
user_input.as_llm_context(DOMAIN),
|
||||
options.get(CONF_LLM_HASS_API),
|
||||
options.get(CONF_PROMPT),
|
||||
user_input.extra_system_prompt,
|
||||
)
|
||||
except conversation.ConverseError as err:
|
||||
return err.as_conversation_result()
|
||||
|
@ -14,12 +14,11 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.helpers import chat_session, intent, llm, template
|
||||
from homeassistant.helpers import chat_session, frame, intent, llm, template
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
from . import trace
|
||||
from .const import DOMAIN
|
||||
from .models import ConversationInput, ConversationResult
|
||||
|
||||
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
|
||||
@ -359,7 +358,7 @@ class ChatLog:
|
||||
self,
|
||||
llm_context: llm.LLMContext,
|
||||
prompt: str,
|
||||
language: str,
|
||||
language: str | None,
|
||||
user_name: str | None = None,
|
||||
) -> str:
|
||||
try:
|
||||
@ -373,7 +372,7 @@ class ChatLog:
|
||||
)
|
||||
except TemplateError as err:
|
||||
LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=language)
|
||||
intent_response = intent.IntentResponse(language=language or "")
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
"Sorry, I had a problem with my template",
|
||||
@ -392,14 +391,25 @@ class ChatLog:
|
||||
user_llm_prompt: str | None = None,
|
||||
) -> None:
|
||||
"""Set the LLM system prompt."""
|
||||
llm_context = llm.LLMContext(
|
||||
platform=conversing_domain,
|
||||
context=user_input.context,
|
||||
language=user_input.language,
|
||||
assistant=DOMAIN,
|
||||
device_id=user_input.device_id,
|
||||
frame.report_usage(
|
||||
"ChatLog.async_update_llm_data",
|
||||
breaks_in_ha_version="2026.1",
|
||||
)
|
||||
return await self.async_provide_llm_data(
|
||||
llm_context=user_input.as_llm_context(conversing_domain),
|
||||
user_llm_hass_api=user_llm_hass_api,
|
||||
user_llm_prompt=user_llm_prompt,
|
||||
user_extra_system_prompt=user_input.extra_system_prompt,
|
||||
)
|
||||
|
||||
async def async_provide_llm_data(
|
||||
self,
|
||||
llm_context: llm.LLMContext,
|
||||
user_llm_hass_api: str | list[str] | None = None,
|
||||
user_llm_prompt: str | None = None,
|
||||
user_extra_system_prompt: str | None = None,
|
||||
) -> None:
|
||||
"""Set the LLM system prompt."""
|
||||
llm_api: llm.APIInstance | None = None
|
||||
|
||||
if user_llm_hass_api:
|
||||
@ -413,10 +423,12 @@ class ChatLog:
|
||||
LOGGER.error(
|
||||
"Error getting LLM API %s for %s: %s",
|
||||
user_llm_hass_api,
|
||||
conversing_domain,
|
||||
llm_context.platform,
|
||||
err,
|
||||
)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response = intent.IntentResponse(
|
||||
language=llm_context.language or ""
|
||||
)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
"Error preparing LLM API",
|
||||
@ -430,10 +442,10 @@ class ChatLog:
|
||||
user_name: str | None = None
|
||||
|
||||
if (
|
||||
user_input.context
|
||||
and user_input.context.user_id
|
||||
llm_context.context
|
||||
and llm_context.context.user_id
|
||||
and (
|
||||
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
||||
user := await self.hass.auth.async_get_user(llm_context.context.user_id)
|
||||
)
|
||||
):
|
||||
user_name = user.name
|
||||
@ -443,7 +455,7 @@ class ChatLog:
|
||||
await self._async_expand_prompt_template(
|
||||
llm_context,
|
||||
(user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
||||
user_input.language,
|
||||
llm_context.language,
|
||||
user_name,
|
||||
)
|
||||
)
|
||||
@ -455,14 +467,14 @@ class ChatLog:
|
||||
await self._async_expand_prompt_template(
|
||||
llm_context,
|
||||
llm.BASE_PROMPT,
|
||||
user_input.language,
|
||||
llm_context.language,
|
||||
user_name,
|
||||
)
|
||||
)
|
||||
|
||||
if extra_system_prompt := (
|
||||
# Take new system prompt if one was given
|
||||
user_input.extra_system_prompt or self.extra_system_prompt
|
||||
user_extra_system_prompt or self.extra_system_prompt
|
||||
):
|
||||
prompt_parts.append(extra_system_prompt)
|
||||
|
||||
|
@ -7,7 +7,9 @@ from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.helpers import intent, llm
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -56,6 +58,16 @@ class ConversationInput:
|
||||
"extra_system_prompt": self.extra_system_prompt,
|
||||
}
|
||||
|
||||
def as_llm_context(self, conversing_domain: str) -> llm.LLMContext:
|
||||
"""Return input as an LLM context."""
|
||||
return llm.LLMContext(
|
||||
platform=conversing_domain,
|
||||
context=self.context,
|
||||
language=self.language,
|
||||
assistant=DOMAIN,
|
||||
device_id=self.device_id,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConversationResult:
|
||||
|
@ -73,11 +73,11 @@ class GoogleGenerativeAIConversationEntity(
|
||||
options = self.entry.options
|
||||
|
||||
try:
|
||||
await chat_log.async_update_llm_data(
|
||||
DOMAIN,
|
||||
user_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
user_input.as_llm_context(DOMAIN),
|
||||
options.get(CONF_LLM_HASS_API),
|
||||
options.get(CONF_PROMPT),
|
||||
user_input.extra_system_prompt,
|
||||
)
|
||||
except conversation.ConverseError as err:
|
||||
return err.as_conversation_result()
|
||||
|
@ -219,11 +219,11 @@ class OllamaConversationEntity(
|
||||
settings = {**self.entry.data, **self.entry.options}
|
||||
|
||||
try:
|
||||
await chat_log.async_update_llm_data(
|
||||
DOMAIN,
|
||||
user_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
user_input.as_llm_context(DOMAIN),
|
||||
settings.get(CONF_LLM_HASS_API),
|
||||
settings.get(CONF_PROMPT),
|
||||
user_input.extra_system_prompt,
|
||||
)
|
||||
except conversation.ConverseError as err:
|
||||
return err.as_conversation_result()
|
||||
|
@ -279,11 +279,11 @@ class OpenAIConversationEntity(
|
||||
options = self.entry.options
|
||||
|
||||
try:
|
||||
await chat_log.async_update_llm_data(
|
||||
DOMAIN,
|
||||
user_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
user_input.as_llm_context(DOMAIN),
|
||||
options.get(CONF_LLM_HASS_API),
|
||||
options.get(CONF_PROMPT),
|
||||
user_input.extra_system_prompt,
|
||||
)
|
||||
except conversation.ConverseError as err:
|
||||
return err.as_conversation_result()
|
||||
|
@ -1779,11 +1779,11 @@ async def test_chat_log_tts_streaming(
|
||||
conversation_input,
|
||||
) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
user_extra_system_prompt=conversation_input.extra_system_prompt,
|
||||
)
|
||||
async for _content in chat_log.async_add_delta_content_stream(
|
||||
agent_id, stream_llm_response()
|
||||
|
@ -106,9 +106,8 @@ async def test_llm_api(
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
@ -128,9 +127,8 @@ async def test_unknown_llm_api(
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
pytest.raises(ConverseError) as exc_info,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api="unknown-api",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
@ -170,9 +168,8 @@ async def test_multiple_llm_apis(
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api=["assist", "my-api"],
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
@ -192,9 +189,8 @@ async def test_template_error(
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
pytest.raises(ConverseError) as exc_info,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt="{{ invalid_syntax",
|
||||
)
|
||||
@ -217,9 +213,8 @@ async def test_template_variables(
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=(
|
||||
"The instance name is {{ ha_name }}. "
|
||||
@ -249,11 +244,11 @@ async def test_extra_systen_prompt(
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
|
||||
)
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
@ -273,11 +268,11 @@ async def test_extra_systen_prompt(
|
||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
|
||||
)
|
||||
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt
|
||||
@ -290,11 +285,11 @@ async def test_extra_systen_prompt(
|
||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
|
||||
)
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
@ -314,11 +309,11 @@ async def test_extra_systen_prompt(
|
||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api=None,
|
||||
user_llm_prompt=None,
|
||||
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
|
||||
)
|
||||
|
||||
assert chat_log.extra_system_prompt == extra_system_prompt2
|
||||
@ -357,9 +352,8 @@ async def test_tool_call(
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
@ -434,9 +428,8 @@ async def test_tool_call_exception(
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
@ -595,9 +588,8 @@ async def test_add_delta_content_stream(
|
||||
) as chat_log,
|
||||
):
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
await chat_log.async_provide_llm_data(
|
||||
mock_conversation_input.as_llm_context("test"),
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user