mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 20:27:08 +00:00
Simplify ChatLog dependencies (#146351)
This commit is contained in:
parent
5f5869ffc6
commit
fa21269f0d
@ -366,11 +366,11 @@ class AnthropicConversationEntity(
|
|||||||
options = self.entry.options
|
options = self.entry.options
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
DOMAIN,
|
user_input.as_llm_context(DOMAIN),
|
||||||
user_input,
|
|
||||||
options.get(CONF_LLM_HASS_API),
|
options.get(CONF_LLM_HASS_API),
|
||||||
options.get(CONF_PROMPT),
|
options.get(CONF_PROMPT),
|
||||||
|
user_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
except conversation.ConverseError as err:
|
except conversation.ConverseError as err:
|
||||||
return err.as_conversation_result()
|
return err.as_conversation_result()
|
||||||
|
@ -14,12 +14,11 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||||
from homeassistant.helpers import chat_session, intent, llm, template
|
from homeassistant.helpers import chat_session, frame, intent, llm, template
|
||||||
from homeassistant.util.hass_dict import HassKey
|
from homeassistant.util.hass_dict import HassKey
|
||||||
from homeassistant.util.json import JsonObjectType
|
from homeassistant.util.json import JsonObjectType
|
||||||
|
|
||||||
from . import trace
|
from . import trace
|
||||||
from .const import DOMAIN
|
|
||||||
from .models import ConversationInput, ConversationResult
|
from .models import ConversationInput, ConversationResult
|
||||||
|
|
||||||
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
|
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
|
||||||
@ -359,7 +358,7 @@ class ChatLog:
|
|||||||
self,
|
self,
|
||||||
llm_context: llm.LLMContext,
|
llm_context: llm.LLMContext,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
language: str,
|
language: str | None,
|
||||||
user_name: str | None = None,
|
user_name: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
try:
|
try:
|
||||||
@ -373,7 +372,7 @@ class ChatLog:
|
|||||||
)
|
)
|
||||||
except TemplateError as err:
|
except TemplateError as err:
|
||||||
LOGGER.error("Error rendering prompt: %s", err)
|
LOGGER.error("Error rendering prompt: %s", err)
|
||||||
intent_response = intent.IntentResponse(language=language)
|
intent_response = intent.IntentResponse(language=language or "")
|
||||||
intent_response.async_set_error(
|
intent_response.async_set_error(
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
"Sorry, I had a problem with my template",
|
"Sorry, I had a problem with my template",
|
||||||
@ -392,14 +391,25 @@ class ChatLog:
|
|||||||
user_llm_prompt: str | None = None,
|
user_llm_prompt: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set the LLM system prompt."""
|
"""Set the LLM system prompt."""
|
||||||
llm_context = llm.LLMContext(
|
frame.report_usage(
|
||||||
platform=conversing_domain,
|
"ChatLog.async_update_llm_data",
|
||||||
context=user_input.context,
|
breaks_in_ha_version="2026.1",
|
||||||
language=user_input.language,
|
)
|
||||||
assistant=DOMAIN,
|
return await self.async_provide_llm_data(
|
||||||
device_id=user_input.device_id,
|
llm_context=user_input.as_llm_context(conversing_domain),
|
||||||
|
user_llm_hass_api=user_llm_hass_api,
|
||||||
|
user_llm_prompt=user_llm_prompt,
|
||||||
|
user_extra_system_prompt=user_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def async_provide_llm_data(
|
||||||
|
self,
|
||||||
|
llm_context: llm.LLMContext,
|
||||||
|
user_llm_hass_api: str | list[str] | None = None,
|
||||||
|
user_llm_prompt: str | None = None,
|
||||||
|
user_extra_system_prompt: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set the LLM system prompt."""
|
||||||
llm_api: llm.APIInstance | None = None
|
llm_api: llm.APIInstance | None = None
|
||||||
|
|
||||||
if user_llm_hass_api:
|
if user_llm_hass_api:
|
||||||
@ -413,10 +423,12 @@ class ChatLog:
|
|||||||
LOGGER.error(
|
LOGGER.error(
|
||||||
"Error getting LLM API %s for %s: %s",
|
"Error getting LLM API %s for %s: %s",
|
||||||
user_llm_hass_api,
|
user_llm_hass_api,
|
||||||
conversing_domain,
|
llm_context.platform,
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
intent_response = intent.IntentResponse(
|
||||||
|
language=llm_context.language or ""
|
||||||
|
)
|
||||||
intent_response.async_set_error(
|
intent_response.async_set_error(
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
"Error preparing LLM API",
|
"Error preparing LLM API",
|
||||||
@ -430,10 +442,10 @@ class ChatLog:
|
|||||||
user_name: str | None = None
|
user_name: str | None = None
|
||||||
|
|
||||||
if (
|
if (
|
||||||
user_input.context
|
llm_context.context
|
||||||
and user_input.context.user_id
|
and llm_context.context.user_id
|
||||||
and (
|
and (
|
||||||
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
user := await self.hass.auth.async_get_user(llm_context.context.user_id)
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
user_name = user.name
|
user_name = user.name
|
||||||
@ -443,7 +455,7 @@ class ChatLog:
|
|||||||
await self._async_expand_prompt_template(
|
await self._async_expand_prompt_template(
|
||||||
llm_context,
|
llm_context,
|
||||||
(user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
(user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
||||||
user_input.language,
|
llm_context.language,
|
||||||
user_name,
|
user_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -455,14 +467,14 @@ class ChatLog:
|
|||||||
await self._async_expand_prompt_template(
|
await self._async_expand_prompt_template(
|
||||||
llm_context,
|
llm_context,
|
||||||
llm.BASE_PROMPT,
|
llm.BASE_PROMPT,
|
||||||
user_input.language,
|
llm_context.language,
|
||||||
user_name,
|
user_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if extra_system_prompt := (
|
if extra_system_prompt := (
|
||||||
# Take new system prompt if one was given
|
# Take new system prompt if one was given
|
||||||
user_input.extra_system_prompt or self.extra_system_prompt
|
user_extra_system_prompt or self.extra_system_prompt
|
||||||
):
|
):
|
||||||
prompt_parts.append(extra_system_prompt)
|
prompt_parts.append(extra_system_prompt)
|
||||||
|
|
||||||
|
@ -7,7 +7,9 @@ from dataclasses import dataclass
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from homeassistant.core import Context
|
from homeassistant.core import Context
|
||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import intent, llm
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -56,6 +58,16 @@ class ConversationInput:
|
|||||||
"extra_system_prompt": self.extra_system_prompt,
|
"extra_system_prompt": self.extra_system_prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def as_llm_context(self, conversing_domain: str) -> llm.LLMContext:
|
||||||
|
"""Return input as an LLM context."""
|
||||||
|
return llm.LLMContext(
|
||||||
|
platform=conversing_domain,
|
||||||
|
context=self.context,
|
||||||
|
language=self.language,
|
||||||
|
assistant=DOMAIN,
|
||||||
|
device_id=self.device_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class ConversationResult:
|
class ConversationResult:
|
||||||
|
@ -73,11 +73,11 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
options = self.entry.options
|
options = self.entry.options
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
DOMAIN,
|
user_input.as_llm_context(DOMAIN),
|
||||||
user_input,
|
|
||||||
options.get(CONF_LLM_HASS_API),
|
options.get(CONF_LLM_HASS_API),
|
||||||
options.get(CONF_PROMPT),
|
options.get(CONF_PROMPT),
|
||||||
|
user_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
except conversation.ConverseError as err:
|
except conversation.ConverseError as err:
|
||||||
return err.as_conversation_result()
|
return err.as_conversation_result()
|
||||||
|
@ -219,11 +219,11 @@ class OllamaConversationEntity(
|
|||||||
settings = {**self.entry.data, **self.entry.options}
|
settings = {**self.entry.data, **self.entry.options}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
DOMAIN,
|
user_input.as_llm_context(DOMAIN),
|
||||||
user_input,
|
|
||||||
settings.get(CONF_LLM_HASS_API),
|
settings.get(CONF_LLM_HASS_API),
|
||||||
settings.get(CONF_PROMPT),
|
settings.get(CONF_PROMPT),
|
||||||
|
user_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
except conversation.ConverseError as err:
|
except conversation.ConverseError as err:
|
||||||
return err.as_conversation_result()
|
return err.as_conversation_result()
|
||||||
|
@ -279,11 +279,11 @@ class OpenAIConversationEntity(
|
|||||||
options = self.entry.options
|
options = self.entry.options
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
DOMAIN,
|
user_input.as_llm_context(DOMAIN),
|
||||||
user_input,
|
|
||||||
options.get(CONF_LLM_HASS_API),
|
options.get(CONF_LLM_HASS_API),
|
||||||
options.get(CONF_PROMPT),
|
options.get(CONF_PROMPT),
|
||||||
|
user_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
except conversation.ConverseError as err:
|
except conversation.ConverseError as err:
|
||||||
return err.as_conversation_result()
|
return err.as_conversation_result()
|
||||||
|
@ -1779,11 +1779,11 @@ async def test_chat_log_tts_streaming(
|
|||||||
conversation_input,
|
conversation_input,
|
||||||
) as chat_log,
|
) as chat_log,
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
conversation_input.as_llm_context("test"),
|
||||||
user_input=conversation_input,
|
|
||||||
user_llm_hass_api="assist",
|
user_llm_hass_api="assist",
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
|
user_extra_system_prompt=conversation_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
async for _content in chat_log.async_add_delta_content_stream(
|
async for _content in chat_log.async_add_delta_content_stream(
|
||||||
agent_id, stream_llm_response()
|
agent_id, stream_llm_response()
|
||||||
|
@ -106,9 +106,8 @@ async def test_llm_api(
|
|||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api="assist",
|
user_llm_hass_api="assist",
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
@ -128,9 +127,8 @@ async def test_unknown_llm_api(
|
|||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
pytest.raises(ConverseError) as exc_info,
|
pytest.raises(ConverseError) as exc_info,
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api="unknown-api",
|
user_llm_hass_api="unknown-api",
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
@ -170,9 +168,8 @@ async def test_multiple_llm_apis(
|
|||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api=["assist", "my-api"],
|
user_llm_hass_api=["assist", "my-api"],
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
@ -192,9 +189,8 @@ async def test_template_error(
|
|||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
pytest.raises(ConverseError) as exc_info,
|
pytest.raises(ConverseError) as exc_info,
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt="{{ invalid_syntax",
|
user_llm_prompt="{{ invalid_syntax",
|
||||||
)
|
)
|
||||||
@ -217,9 +213,8 @@ async def test_template_variables(
|
|||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
|
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt=(
|
user_llm_prompt=(
|
||||||
"The instance name is {{ ha_name }}. "
|
"The instance name is {{ ha_name }}. "
|
||||||
@ -249,11 +244,11 @@ async def test_extra_systen_prompt(
|
|||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
|
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
chat_log.async_add_assistant_content_without_tools(
|
chat_log.async_add_assistant_content_without_tools(
|
||||||
AssistantContent(
|
AssistantContent(
|
||||||
@ -273,11 +268,11 @@ async def test_extra_systen_prompt(
|
|||||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
|
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chat_log.extra_system_prompt == extra_system_prompt
|
assert chat_log.extra_system_prompt == extra_system_prompt
|
||||||
@ -290,11 +285,11 @@ async def test_extra_systen_prompt(
|
|||||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
|
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
chat_log.async_add_assistant_content_without_tools(
|
chat_log.async_add_assistant_content_without_tools(
|
||||||
AssistantContent(
|
AssistantContent(
|
||||||
@ -314,11 +309,11 @@ async def test_extra_systen_prompt(
|
|||||||
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
chat_session.async_get_chat_session(hass, conversation_id) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api=None,
|
user_llm_hass_api=None,
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
|
user_extra_system_prompt=mock_conversation_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chat_log.extra_system_prompt == extra_system_prompt2
|
assert chat_log.extra_system_prompt == extra_system_prompt2
|
||||||
@ -357,9 +352,8 @@ async def test_tool_call(
|
|||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api="assist",
|
user_llm_hass_api="assist",
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
@ -434,9 +428,8 @@ async def test_tool_call_exception(
|
|||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
):
|
):
|
||||||
mock_get_tools.return_value = [mock_tool]
|
mock_get_tools.return_value = [mock_tool]
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api="assist",
|
user_llm_hass_api="assist",
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
@ -595,9 +588,8 @@ async def test_add_delta_content_stream(
|
|||||||
) as chat_log,
|
) as chat_log,
|
||||||
):
|
):
|
||||||
mock_get_tools.return_value = [mock_tool]
|
mock_get_tools.return_value = [mock_tool]
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
conversing_domain="test",
|
mock_conversation_input.as_llm_context("test"),
|
||||||
user_input=mock_conversation_input,
|
|
||||||
user_llm_hass_api="assist",
|
user_llm_hass_api="assist",
|
||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user