From 15223b36794a6e2399e6434b339e2c9f57a424f3 Mon Sep 17 00:00:00 2001
From: Allen Porter
Date: Sun, 9 Feb 2025 21:05:41 -0800
Subject: [PATCH] Update Ollama to use streaming API (#138177)

* Update ollama to use streaming APIs

* Remove unnecessary logging

* Update ollama to use streaming APIs

* Remove unnecessary logging

* Update homeassistant/components/ollama/conversation.py

Co-authored-by: Paulus Schoutsen

---------

Co-authored-by: Paulus Schoutsen
---
 .../components/ollama/conversation.py        |  87 ++++++----
 tests/components/ollama/test_conversation.py | 154 +++++++++++++-----
 2 files changed, 167 insertions(+), 74 deletions(-)

diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py
index 2c83720f930..8ee275865a7 100644
--- a/homeassistant/components/ollama/conversation.py
+++ b/homeassistant/components/ollama/conversation.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable
 import json
 import logging
 from typing import Any, Literal
@@ -123,7 +123,47 @@ def _convert_content(
             role=MessageRole.SYSTEM.value,
             content=chat_content.content,
         )
-    raise ValueError(f"Unexpected content type: {type(chat_content)}")
+    raise TypeError(f"Unexpected content type: {type(chat_content)}")
+
+
+async def _transform_stream(
+    result: AsyncGenerator[ollama.Message],
+) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+    """Transform the response stream into HA format.
+
+    An Ollama streaming response may come in chunks like this:
+
+    response: message=Message(role="assistant", content="Paris")
+    response: message=Message(role="assistant", content=".")
+    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
+    response: message=Message(role="assistant", tool_calls=[...])
+    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
+
+    This generator conforms to the chatlog delta stream expectations in that it
+    yields deltas, then the role only once the response is done.
+    """
+
+    new_msg = True
+    async for response in result:
+        _LOGGER.debug("Received response: %s", response)
+        response_message = response["message"]
+        chunk: conversation.AssistantContentDeltaDict = {}
+        if new_msg:
+            new_msg = False
+            chunk["role"] = "assistant"
+        if (tool_calls := response_message.get("tool_calls")) is not None:
+            chunk["tool_calls"] = [
+                llm.ToolInput(
+                    tool_name=tool_call["function"]["name"],
+                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
+                )
+                for tool_call in tool_calls
+            ]
+        if (content := response_message.get("content")) is not None:
+            chunk["content"] = content
+        if response.get("done"):
+            new_msg = True
+        yield chunk
 
 
 class OllamaConversationEntity(
@@ -216,12 +256,12 @@ class OllamaConversationEntity(
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
            try:
-                response = await client.chat(
+                response_generator = await client.chat(
                    model=model,
                    # Make a copy of the messages because we mutate the list later
                    messages=list(message_history.messages),
                    tools=tools,
-                    stream=False,
+                    stream=True,
                    # keep_alive requires specifying unit. In this case, seconds
                    keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
                    options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
@@ -232,46 +272,25 @@ class OllamaConversationEntity(
                    f"Sorry, I had a problem talking to the Ollama server: {err}"
                ) from err
 
-            response_message = response["message"]
-            content = response_message.get("content")
-            tool_calls = response_message.get("tool_calls")
-            message_history.messages.append(
-                ollama.Message(
-                    role=response_message["role"],
-                    content=content,
-                    tool_calls=tool_calls,
-                )
-            )
-            tool_inputs = [
-                llm.ToolInput(
-                    tool_name=tool_call["function"]["name"],
-                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
-                )
-                for tool_call in tool_calls or ()
-            ]
             message_history.messages.extend(
                [
-                    ollama.Message(
-                        role=MessageRole.TOOL.value,
-                        content=json.dumps(tool_response.tool_result),
-                    )
-                    async for tool_response in chat_log.async_add_assistant_content(
-                        conversation.AssistantContent(
-                            agent_id=user_input.agent_id,
-                            content=content,
-                            tool_calls=tool_inputs or None,
-                        )
+                    _convert_content(content)
+                    async for content in chat_log.async_add_delta_content_stream(
+                        user_input.agent_id, _transform_stream(response_generator)
                    )
                ]
            )
 
-            if not tool_calls:
+            if not chat_log.unresponded_tool_results:
                break
 
        # Create intent response
        intent_response = intent.IntentResponse(language=user_input.language)
-        intent_response.async_set_speech(response_message["content"])
+        if not isinstance(chat_log.content[-1], conversation.AssistantContent):
+            raise TypeError(
+                f"Unexpected last message type: {type(chat_log.content[-1])}"
+            )
+        intent_response.async_set_speech(chat_log.content[-1].content or "")
        return conversation.ConversationResult(
            response=intent_response, conversation_id=chat_log.conversation_id
        )
diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py
index df7c6beca72..db641ba703b 100644
--- a/tests/components/ollama/test_conversation.py
+++ b/tests/components/ollama/test_conversation.py
@@ -1,5 +1,6 @@
 """Tests for the Ollama integration."""
 
+from collections.abc import AsyncGenerator
 from typing import Any
 from unittest.mock import AsyncMock, Mock, patch
 
@@ -25,6 +26,14 @@ def mock_ulid_tools():
        yield
 
 
+async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict]:
+    """Generate a response from the assistant."""
+    if not isinstance(response, list):
+        response = [response]
+    for msg in response:
+        yield msg
+
+
 @pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
 async def test_chat(
    hass: HomeAssistant,
@@ -42,7 +51,9 @@ async def test_chat(
 
    with patch(
        "ollama.AsyncClient.chat",
-        return_value={"message": {"role": "assistant", "content": "test response"}},
+        return_value=stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        ),
    ) as mock_chat:
        result = await conversation.async_converse(
            hass,
@@ -81,6 +92,53 @@
    assert "Current time is" in detail_event["data"]["messages"][0]["content"]
 
 
+async def test_chat_stream(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test chat messages are assembled across streamed responses."""
+
+    entry = MockConfigEntry()
+    entry.add_to_hass(hass)
+
+    with patch(
+        "ollama.AsyncClient.chat",
+        return_value=stream_generator(
+            [
+                {"message": {"role": "assistant", "content": "test "}},
+                {
+                    "message": {"role": "assistant", "content": "response"},
+                    "done": True,
+                    "done_reason": "stop",
+                },
+            ]
+        ),
+    ) as mock_chat:
+        result = await conversation.async_converse(
+            hass,
+            "test message",
+            None,
+            Context(),
+            agent_id=mock_config_entry.entry_id,
+        )
+
+    assert mock_chat.call_count == 1
+    args = mock_chat.call_args.kwargs
+    prompt = args["messages"][0]["content"]
+
+    assert args["model"] == "test model"
+    assert args["messages"] == [
+        Message(role="system", content=prompt),
+        Message(role="user", content="test message"),
+    ]
+
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
+        result
+    )
+    assert result.response.speech["plain"]["speech"] == "test response"
+
+
 async def test_template_variables(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
 ) -> None:
@@ -103,7 +161,9 @@ async def test_template_variables(
        patch("ollama.AsyncClient.list"),
        patch(
            "ollama.AsyncClient.chat",
-            return_value={"message": {"role": "assistant", "content": "test response"}},
+            return_value=stream_generator(
+                {"message": {"role": "assistant", "content": "test response"}}
+            ),
        ) as mock_chat,
        patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
    ):
@@ -170,26 +230,30 @@ async def test_function_call(
    def completion_result(*args, messages, **kwargs):
        for message in messages:
            if message["role"] == "tool":
-                return {
-                    "message": {
-                        "role": "assistant",
-                        "content": "I have successfully called the function",
-                    }
-                }
-
-        return {
-            "message": {
-                "role": "assistant",
-                "tool_calls": [
-                    {
-                        "function": {
-                            "name": "test_tool",
-                            "arguments": tool_args,
-                        }
-                    }
-                ],
-            }
-        }
+                return stream_generator(
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": "I have successfully called the function",
+                        }
+                    }
+                )
+
+        return stream_generator(
+            {
+                "message": {
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": "test_tool",
+                                "arguments": tool_args,
+                            }
+                        }
+                    ],
+                }
+            }
+        )
 
    with patch(
        "ollama.AsyncClient.chat",
@@ -251,26 +315,30 @@ async def test_function_exception(
    def completion_result(*args, messages, **kwargs):
        for message in messages:
            if message["role"] == "tool":
-                return {
-                    "message": {
-                        "role": "assistant",
-                        "content": "There was an error calling the function",
-                    }
-                }
-
-        return {
-            "message": {
-                "role": "assistant",
-                "tool_calls": [
-                    {
-                        "function": {
-                            "name": "test_tool",
-                            "arguments": {"param1": "test_value"},
-                        }
-                    }
-                ],
-            }
-        }
+                return stream_generator(
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": "There was an error calling the function",
+                        }
+                    }
+                )
+
+        return stream_generator(
+            {
+                "message": {
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": "test_tool",
+                                "arguments": {"param1": "test_value"},
+                            }
+                        }
+                    ],
+                }
+            }
+        )
 
    with patch(
        "ollama.AsyncClient.chat",
@@ -344,7 +412,9 @@ async def test_message_history_trimming(
    def response(*args, **kwargs) -> dict:
        nonlocal response_idx
        response_idx += 1
-        return {"message": {"role": "assistant", "content": f"response {response_idx}"}}
+        return stream_generator(
+            {"message": {"role": "assistant", "content": f"response {response_idx}"}}
+        )
 
    with patch(
        "ollama.AsyncClient.chat",
@@ -438,11 +508,13 @@ async def test_message_history_unlimited(
    """Test that message history is not trimmed when max_history = 0."""
    conversation_id = "1234"
 
+    def stream(*args, **kwargs) -> AsyncGenerator[dict]:
+        return stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        )
+
    with (
-        patch(
-            "ollama.AsyncClient.chat",
-            return_value={"message": {"role": "assistant", "content": "test response"}},
-        ) as mock_chat,
+        patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat,
    ):
        hass.config_entries.async_update_entry(
            mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0}
        )
@@ -559,7 +631,9 @@ async def test_options(
    """Test that options are passed correctly to ollama client."""
    with patch(
        "ollama.AsyncClient.chat",
-        return_value={"message": {"role": "assistant", "content": "test response"}},
+        return_value=stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        ),
    ) as mock_chat:
        await conversation.async_converse(
            hass,