Standardize conversation.async_process method (#140125)

This commit is contained in:
Paulus Schoutsen 2025-03-10 15:15:10 -04:00 committed by GitHub
parent 1665d9474f
commit 49a62d5294
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 55 additions and 93 deletions

View File

@ -30,7 +30,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, device_registry as dr, intent, llm
from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import AnthropicConfigEntry
@ -226,18 +226,6 @@ class AnthropicConversationEntity(
self.entry.add_update_listener(self._async_entry_update_listener)
)
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
):
return await self._async_handle_message(user_input, chat_log)
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,

View File

@ -42,7 +42,6 @@ from homeassistant.components.homeassistant.exposed_entities import (
from homeassistant.const import EVENT_STATE_CHANGED, MATCH_ALL
from homeassistant.helpers import (
area_registry as ar,
chat_session,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
@ -56,7 +55,7 @@ from homeassistant.helpers.event import async_track_state_added_domain
from homeassistant.util import language as language_util
from homeassistant.util.json import JsonObjectType, json_loads_object
from .chat_log import AssistantContent, async_get_chat_log
from .chat_log import AssistantContent, ChatLog
from .const import (
DATA_DEFAULT_ENTITY,
DEFAULT_EXPOSED_ATTRIBUTES,
@ -332,49 +331,46 @@ class DefaultAgent(ConversationEntity):
return result
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence."""
async def _async_handle_message(
self,
user_input: ConversationInput,
chat_log: ChatLog,
) -> ConversationResult:
"""Handle a message."""
response: intent.IntentResponse | None = None
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
async_get_chat_log(self.hass, session, user_input) as chat_log,
):
# Check if a trigger matched
if trigger_result := await self.async_recognize_sentence_trigger(
user_input
):
# Process callbacks and get response
response_text = await self._handle_trigger_result(
trigger_result, user_input
)
# Convert to conversation result
response = intent.IntentResponse(
language=user_input.language or self.hass.config.language
)
response.response_type = intent.IntentResponseType.ACTION_DONE
response.async_set_speech(response_text)
if response is None:
# Match intents
intent_result = await self.async_recognize_intent(user_input)
response = await self._async_process_intent_result(
intent_result, user_input
)
speech: str = response.speech.get("plain", {}).get("speech", "")
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id=user_input.agent_id,
content=speech,
)
# Check if a trigger matched
if trigger_result := await self.async_recognize_sentence_trigger(user_input):
# Process callbacks and get response
response_text = await self._handle_trigger_result(
trigger_result, user_input
)
return ConversationResult(
response=response, conversation_id=session.conversation_id
# Convert to conversation result
response = intent.IntentResponse(
language=user_input.language or self.hass.config.language
)
response.response_type = intent.IntentResponseType.ACTION_DONE
response.async_set_speech(response_text)
if response is None:
# Match intents
intent_result = await self.async_recognize_intent(user_input)
response = await self._async_process_intent_result(
intent_result, user_input
)
speech: str = response.speech.get("plain", {}).get("speech", "")
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id=user_input.agent_id,
content=speech,
)
)
return ConversationResult(
response=response, conversation_id=chat_log.conversation_id
)
async def _async_process_intent_result(
self,

View File

@ -4,9 +4,11 @@ from abc import abstractmethod
from typing import Literal, final
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.helpers.chat_session import async_get_chat_session
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util
from .chat_log import ChatLog, async_get_chat_log
from .const import ConversationEntityFeature
from .models import ConversationInput, ConversationResult
@ -51,9 +53,21 @@ class ConversationEntity(RestoreEntity):
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
@abstractmethod
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence."""
with (
async_get_chat_session(self.hass, user_input.conversation_id) as session,
async_get_chat_log(self.hass, session, user_input) as chat_log,
):
return await self._async_handle_message(user_input, chat_log)
async def _async_handle_message(
self,
user_input: ConversationInput,
chat_log: ChatLog,
) -> ConversationResult:
"""Call the API."""
raise NotImplementedError
async def async_prepare(self, language: str | None = None) -> None:
"""Load intents for a language."""

View File

@ -25,7 +25,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, device_registry as dr, intent, llm
from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import (
@ -264,18 +264,6 @@ class GoogleGenerativeAIConversationEntity(
conversation.async_unset_agent(self.hass, self.entry)
await super().async_will_remove_from_hass()
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
):
return await self._async_handle_message(user_input, chat_log)
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,

View File

@ -15,7 +15,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, intent, llm
from homeassistant.helpers import intent, llm
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import (
@ -206,18 +206,6 @@ class OllamaConversationEntity(
"""Return a list of supported languages."""
return MATCH_ALL
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
):
return await self._async_handle_message(user_input, chat_log)
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,

View File

@ -24,7 +24,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, device_registry as dr, intent, llm
from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OpenAIConfigEntry
@ -223,18 +223,6 @@ class OpenAIConversationEntity(
conversation.async_unset_agent(self.hass, self.entry)
await super().async_will_remove_from_hass()
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
):
return await self._async_handle_message(user_input, chat_log)
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,