"""Conversation support for OpenAI."""

from typing import Literal

from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback

from . import OpenAIConfigEntry
from .const import CONF_PROMPT, DOMAIN
from .entity import OpenAIBaseLLMEntity


async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: OpenAIConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up conversation entities.

    One ``OpenAIConversationEntity`` is created for every subentry of
    ``config_entry`` whose ``subentry_type`` is ``"conversation"``; all other
    subentry types are skipped. Each entity is added separately, tagged with
    the id of the subentry it was built from.
    """
    for subentry in config_entry.subentries.values():
        if subentry.subentry_type != "conversation":
            continue

        async_add_entities(
            [OpenAIConversationEntity(config_entry, subentry)],
            config_subentry_id=subentry.subentry_id,
        )


class OpenAIConversationEntity(
    conversation.ConversationEntity,
    conversation.AbstractConversationAgent,
    OpenAIBaseLLMEntity,
):
    """OpenAI conversation agent."""

    # Class-level flag read by the entity base classes; presumably advertises
    # that this agent can stream its responses — confirm in ConversationEntity.
    _attr_supports_streaming = True

    def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the agent.

        The CONTROL feature is only advertised when the subentry's data has a
        truthy ``CONF_LLM_HASS_API`` value; otherwise the supported features
        are left at whatever the base classes define.
        """
        super().__init__(entry, subentry)
        if self.subentry.data.get(CONF_LLM_HASS_API):
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL
            )

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages.

        Always ``MATCH_ALL``: the agent declares no language restriction.
        """
        return MATCH_ALL

    async def async_added_to_hass(self) -> None:
        """When entity is added to Home Assistant.

        Runs the assist pipeline engine migration for this config entry and
        then registers this entity as the conversation agent for the entry.
        """
        await super().async_added_to_hass()
        # NOTE(review): presumably repoints assist pipelines that referenced
        # this integration by config entry id to this entity id — verify in
        # assist_pipeline.async_migrate_engine.
        assist_pipeline.async_migrate_engine(
            self.hass, "conversation", self.entry.entry_id, self.entity_id
        )
        conversation.async_set_agent(self.hass, self.entry, self)

    async def async_will_remove_from_hass(self) -> None:
        """When entity will be removed from Home Assistant.

        Unregisters the conversation agent for the config entry before
        delegating to the base-class teardown.
        """
        conversation.async_unset_agent(self.hass, self.entry)
        await super().async_will_remove_from_hass()

    async def _async_handle_message(
        self,
        user_input: conversation.ConversationInput,
        chat_log: conversation.ChatLog,
    ) -> conversation.ConversationResult:
        """Process the user input and call the API.

        The chat log is first primed with the LLM context derived from the
        input, the subentry's configured LLM API and prompt, and the input's
        extra system prompt. If that raises ``ConverseError``, the error's own
        conversation result is returned and no API call is made.

        Otherwise the chat log is handed to the shared handler, and the text
        of the final (assistant) entry becomes the spoken response.
        """
        options = self.subentry.data

        try:
            await chat_log.async_provide_llm_data(
                user_input.as_llm_context(DOMAIN),
                options.get(CONF_LLM_HASS_API),
                options.get(CONF_PROMPT),
                user_input.extra_system_prompt,
            )
        except conversation.ConverseError as err:
            return err.as_conversation_result()

        # Defined on OpenAIBaseLLMEntity (not shown here); presumably performs
        # the OpenAI request(s) and appends the replies to chat_log.
        await self._async_handle_chat_log(chat_log)

        intent_response = intent.IntentResponse(language=user_input.language)
        last_content = chat_log.content[-1]
        # The handler must leave an AssistantContent as the last entry; the
        # exact-type assert also narrows the type for the type checker.
        assert type(last_content) is conversation.AssistantContent
        intent_response.async_set_speech(last_content.content or "")
        return conversation.ConversationResult(
            response=intent_response,
            conversation_id=chat_log.conversation_id,
            continue_conversation=chat_log.continue_conversation,
        )