diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 81cc7ab8a73..db2df9cddd3 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -11,18 +11,15 @@ import google.generativeai as genai from google.generativeai import protos import google.generativeai.types as genai_types from google.protobuf.json_format import MessageToDict -import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation -from homeassistant.components.conversation import trace from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant -from homeassistant.exceptions import HomeAssistantError, TemplateError -from homeassistant.helpers import device_registry as dr, intent, llm, template +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import device_registry as dr, intent, llm from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.util import ulid as ulid_util from .const import ( CONF_CHAT_MODEL, @@ -152,6 +149,17 @@ def _escape_decode(value: Any) -> Any: return value +def _chat_message_convert( + message: conversation.Content | conversation.NativeContent[genai_types.ContentDict], +) -> genai_types.ContentDict: + """Convert any native chat message for this agent to the native format.""" + if message.role == "native": + return message.content + + role = "model" if message.role == "assistant" else message.role + return {"role": role, "parts": message.content} + + class GoogleGenerativeAIConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): @@ -163,7 +171,6 @@ class GoogleGenerativeAIConversationEntity( def __init__(self, entry: ConfigEntry) -> None: """Initialize the agent.""" self.entry = entry - self.history: dict[str, list[genai_types.ContentType]] = {} self._attr_unique_id = entry.entry_id self._attr_device_info = dr.DeviceInfo( identifiers={(DOMAIN, entry.entry_id)}, @@ -202,49 +209,37 @@ class GoogleGenerativeAIConversationEntity( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" - result = conversation.ConversationResult( - response=intent.IntentResponse(language=user_input.language), - conversation_id=user_input.conversation_id or ulid_util.ulid_now(), - ) - assert result.conversation_id + async with conversation.async_get_chat_session( + self.hass, user_input + ) as session: + return await self._async_handle_message(user_input, session) - llm_context = llm.LLMContext( - platform=DOMAIN, - context=user_input.context, - user_prompt=user_input.text, - language=user_input.language, - assistant=conversation.DOMAIN, - device_id=user_input.device_id, - ) - llm_api: llm.APIInstance | None = None - tools: list[dict[str, Any]] | None = None - if self.entry.options.get(CONF_LLM_HASS_API): - try: - llm_api = await llm.async_get_api( - self.hass, - self.entry.options[CONF_LLM_HASS_API], - llm_context, - ) - except HomeAssistantError as err: - LOGGER.error("Error getting LLM API: %s", err) - result.response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Error preparing LLM API: {err}", - ) - return result - tools = [ - _format_tool(tool, 
llm_api.custom_serializer) for tool in llm_api.tools - ] + async def _async_handle_message( + self, + user_input: conversation.ConversationInput, + session: conversation.ChatSession[genai_types.ContentDict], + ) -> conversation.ConversationResult: + """Call the API.""" + + assert user_input.agent_id + options = self.entry.options try: - prompt = await self._async_render_prompt(user_input, llm_api, llm_context) - except TemplateError as err: - LOGGER.error("Error rendering prompt: %s", err) - result.response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem with my template: {err}", + await session.async_update_llm_data( + DOMAIN, + user_input, + options.get(CONF_LLM_HASS_API), + options.get(CONF_PROMPT), ) - return result + except conversation.ConverseError as err: + return err.as_conversation_result() + + tools: list[dict[str, Any]] | None = None + if session.llm_api: + tools = [ + _format_tool(tool, session.llm_api.custom_serializer) + for tool in session.llm_api.tools + ] model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) # Gemini 1.0 doesn't support system_instruction while 1.5 does. @@ -254,6 +249,9 @@ class GoogleGenerativeAIConversationEntity( "gemini-1.0" not in model_name and "gemini-pro" not in model_name ) + prompt, *messages = [ + _chat_message_convert(message) for message in session.async_get_messages() + ] model = genai.GenerativeModel( model_name=model_name, generation_config={ @@ -281,27 +279,15 @@ class GoogleGenerativeAIConversationEntity( ), }, tools=tools or None, - system_instruction=prompt if supports_system_instruction else None, + system_instruction=prompt["parts"] if supports_system_instruction else None, ) - messages = self.history.get(result.conversation_id, []) if not supports_system_instruction: - if not messages: - messages = [{}, {"role": "model", "parts": "Ok"}] - messages[0] = {"role": "user", "parts": prompt} - - LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) - trace.async_conversation_trace_append( - trace.ConversationTraceEventType.AGENT_DETAIL, - { - # Make a copy to attach it to the trace event. - "messages": messages[:] - if supports_system_instruction - else messages[2:], - "prompt": prompt, - "tools": [*llm_api.tools] if llm_api else None, - }, - ) + messages = [ + {"role": "user", "parts": prompt["parts"]}, + {"role": "model", "parts": "Ok"}, + *messages, + ] chat = model.start_chat(history=messages) chat_request = user_input.text @@ -326,24 +312,30 @@ class GoogleGenerativeAIConversationEntity( f"Sorry, I had a problem talking to Google Generative AI: {err}" ) - result.response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - error, - ) - return result + raise HomeAssistantError(error) from err LOGGER.debug("Response: %s", chat_response.parts) if not chat_response.parts: - result.response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - "Sorry, I had a problem getting a response from Google Generative AI.", + raise HomeAssistantError( + "Sorry, I had a problem getting a response from Google Generative AI." 
) - return result - self.history[result.conversation_id] = chat.history + content = " ".join( + [part.text.strip() for part in chat_response.parts if part.text] + ) + if content: + session.async_add_message( + conversation.Content( + role="assistant", + agent_id=user_input.agent_id, + content=content, + ) + ) + function_calls = [ part.function_call for part in chat_response.parts if part.function_call ] - if not function_calls or not llm_api: + + if not function_calls or not session.llm_api: break tool_responses = [] @@ -351,16 +343,8 @@ class GoogleGenerativeAIConversationEntity( tool_call = MessageToDict(function_call._pb) # noqa: SLF001 tool_name = tool_call["name"] tool_args = _escape_decode(tool_call["args"]) - LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args) tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args) - try: - function_response = await llm_api.async_call_tool(tool_input) - except (HomeAssistantError, vol.Invalid) as e: - function_response = {"error": type(e).__name__} - if str(e): - function_response["error_text"] = str(e) - - LOGGER.debug("Tool response: %s", function_response) + function_response = await session.async_call_tool(tool_input) tool_responses.append( protos.Part( function_response=protos.FunctionResponse( @@ -369,47 +353,20 @@ class GoogleGenerativeAIConversationEntity( ) ) chat_request = protos.Content(parts=tool_responses) + session.async_add_message( + conversation.NativeContent( + agent_id=user_input.agent_id, + content=chat_request, + ) + ) - result.response.async_set_speech( + response = intent.IntentResponse(language=user_input.language) + response.async_set_speech( " ".join([part.text.strip() for part in chat_response.parts if part.text]) ) - return result - - async def _async_render_prompt( - self, - user_input: conversation.ConversationInput, - llm_api: llm.APIInstance | None, - llm_context: llm.LLMContext, - ) -> str: - user_name: str | None = None - if ( - user_input.context - and user_input.context.user_id - and ( - user := await self.hass.auth.async_get_user(user_input.context.user_id) - ) - ): - user_name = user.name - - parts = [ - template.Template( - llm.BASE_PROMPT - + self.entry.options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT), - self.hass, - ).async_render( - { - "ha_name": self.hass.config.location_name, - "user_name": user_name, - "llm_context": llm_context, - }, - parse_result=False, - ) - ] - - if llm_api: - parts.append(llm_api.api_prompt) - - return "\n".join(parts) + return conversation.ConversationResult( + response=response, conversation_id=session.conversation_id + ) async def _async_entry_update_listener( self, hass: HomeAssistant, entry: ConfigEntry diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr index 65238c5212a..21458abb7c8 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -42,6 +42,10 @@ 'parts': 'Ok', 'role': 'model', }), + dict({ + 'parts': '1st user request', + 'role': 'user', + }), ]), }), ), @@ -102,6 +106,10 @@ 'parts': '1st model response', 'role': 'model', }), + dict({ + 'parts': '2nd user request', + 'role': 'user', + }), ]), }), ), @@ -150,6 +158,10 @@ ), dict({ 'history': list([ + dict({ + 'parts': '1st user request', + 'role': 'user', + }), ]), }), ), @@ -202,6 +214,10 @@ 'parts': '1st model response', 'role': 
'model', }), + dict({ + 'parts': '2nd user request', + 'role': 'user', + }), ]), }), ), @@ -250,6 +266,10 @@ ), dict({ 'history': list([ + dict({ + 'parts': 'hello', + 'role': 'user', + }), ]), }), ), @@ -298,6 +318,10 @@ ), dict({ 'history': list([ + dict({ + 'parts': 'hello', + 'role': 'user', + }), ]), }), ), @@ -347,6 +371,10 @@ ), dict({ 'history': list([ + dict({ + 'parts': 'hello', + 'role': 'user', + }), ]), }), ), @@ -396,6 +424,10 @@ ), dict({ 'history': list([ + dict({ + 'parts': 'hello', + 'role': 'user', + }), ]), }), ), @@ -482,6 +514,10 @@ ), dict({ 'history': list([ + dict({ + 'parts': 'Please call the test function', + 'role': 'user', + }), ]), }), ), @@ -558,6 +594,10 @@ ), dict({ 'history': list([ + dict({ + 'parts': 'Please call the test function', + 'role': 'user', + }), ]), }), ), diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index df0b11487d8..a87056275dc 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -208,6 +208,7 @@ async def test_function_call( chat_response = MagicMock() mock_chat.send_message_async.return_value = chat_response mock_part = MagicMock() + mock_part.text = "" mock_part.function_call = FunctionCall( name="test_tool", args={ @@ -284,8 +285,12 @@ async def test_function_call( ] # AGENT_DETAIL event contains the raw prompt passed to the model detail_event = trace_events[1] - assert "Answer in plain text" in detail_event["data"]["prompt"] - assert [t.name for t in detail_event["data"]["tools"]] == ["test_tool"] + assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"] + assert [ + p.function_response.name + for p in detail_event["data"]["messages"][2]["content"].parts + if p.function_response + ] == ["test_tool"] @patch( @@ -315,6 +320,7 @@ async def test_function_call_without_parameters( chat_response = MagicMock() mock_chat.send_message_async.return_value = chat_response mock_part = MagicMock() + mock_part.text = "" mock_part.function_call = FunctionCall(name="test_tool", args={}) def tool_call( @@ -403,6 +409,7 @@ async def test_function_exception( chat_response = MagicMock() mock_chat.send_message_async.return_value = chat_response mock_part = MagicMock() + mock_part.text = "" mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1}) def tool_call( @@ -543,7 +550,7 @@ async def test_invalid_llm_api( assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.error_code == "unknown", result assert result.response.as_dict()["speech"]["plain"]["speech"] == ( - "Error preparing LLM API: API invalid_llm_api not found" + "Error preparing LLM API" )
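
Reviewer note: below is a minimal, self-contained sketch of the message flow this patch introduces, not the integration's actual code. `Content` and `NativeContent` here are simplified stand-in dataclasses for the real `homeassistant.components.conversation` classes, and `split_history` is a hypothetical helper name invented for illustration; only the conversion rules and the Gemini 1.0 fallback are taken directly from `_chat_message_convert` and `_async_handle_message` above.

```python
# Sketch only: stand-in types, not homeassistant.components.conversation.
from dataclasses import dataclass
from typing import Any


@dataclass
class Content:
    role: str      # "system", "user" or "assistant"
    content: str   # plain text


@dataclass
class NativeContent:
    role: str = "native"
    content: Any = None  # already in Gemini's ContentDict/protos format


def chat_message_convert(message: Content | NativeContent) -> Any:
    """Mirror of _chat_message_convert: map a session message to Gemini format."""
    if message.role == "native":
        # Native messages (e.g. tool responses) are stored pre-converted.
        return message.content
    # Gemini calls the assistant role "model"; other roles pass through.
    role = "model" if message.role == "assistant" else message.role
    return {"role": role, "parts": message.content}


def split_history(
    messages: list[Content | NativeContent], supports_system_instruction: bool
) -> tuple[str | None, list[Any]]:
    """Mirror the prompt/history split done before model.start_chat()."""
    prompt, *history = [chat_message_convert(m) for m in messages]
    if supports_system_instruction:
        return prompt["parts"], history
    # Gemini 1.0 lacks system_instruction: prime the chat with the prompt as a
    # user turn followed by a canned "Ok" model turn, as the snapshots show.
    return None, [
        {"role": "user", "parts": prompt["parts"]},
        {"role": "model", "parts": "Ok"},
        *history,
    ]


if __name__ == "__main__":
    # Demo data taken from the snapshot fixtures above.
    session = [
        Content(role="system", content="Answer in plain text"),
        Content(role="user", content="1st user request"),
        Content(role="assistant", content="1st model response"),
    ]
    print(split_history(session, supports_system_instruction=False))
```

The design point worth calling out: per-conversation state moves out of the entity's `self.history` dict and into the shared chat session, so the first converted message is always the system prompt, and on models without `system_instruction` support the old `user`/`"Ok"` priming pair is rebuilt from it, which is exactly what the updated snapshot fixtures assert.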