Update Ollama to use streaming API (#138177)

* Update ollama to use streaming APIs

* Remove unnecessary logging

* Update homeassistant/components/ollama/conversation.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
Allen Porter 2025-02-09 21:05:41 -08:00 committed by GitHub
parent ae38f89728
commit 15223b3679
2 changed files with 167 additions and 74 deletions

homeassistant/components/ollama/conversation.py

@@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable
import json
import logging
from typing import Any, Literal
@@ -123,7 +123,47 @@ def _convert_content(
role=MessageRole.SYSTEM.value,
content=chat_content.content,
)
raise ValueError(f"Unexpected content type: {type(chat_content)}")
raise TypeError(f"Unexpected content type: {type(chat_content)}")
async def _transform_stream(
result: AsyncGenerator[ollama.Message],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform the response stream into HA format.
An Ollama streaming response may come in chunks like this:
response: message=Message(role="assistant", content="Paris")
response: message=Message(role="assistant", content=".")
response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
response: message=Message(role="assistant", tool_calls=[...])
response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
This generator conforms to the chatlog delta stream expectations in that it
yields deltas, then the role only once the response is done.
"""
new_msg = True
async for response in result:
_LOGGER.debug("Received response: %s", response)
response_message = response["message"]
chunk: conversation.AssistantContentDeltaDict = {}
if new_msg:
new_msg = False
chunk["role"] = "assistant"
if (tool_calls := response_message.get("tool_calls")) is not None:
chunk["tool_calls"] = [
llm.ToolInput(
tool_name=tool_call["function"]["name"],
tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
)
for tool_call in tool_calls
]
if (content := response_message.get("content")) is not None:
chunk["content"] = content
if response_message.get("done"):
new_msg = True
yield chunk
class OllamaConversationEntity(
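
The new `_transform_stream` helper converts Ollama's chunked responses into the chat log's delta format: the assistant role is emitted once at the start of each message, content and tool-call deltas follow, and a `done` chunk resets the state so the next chunk opens a fresh message. A minimal, self-contained sketch of the same idea, using plain dicts and a fake chunk stream instead of `ollama.Message` and the Home Assistant delta types:

```python
import asyncio
from collections.abc import AsyncGenerator
from typing import Any


async def fake_ollama_stream() -> AsyncGenerator[dict[str, Any], None]:
    """Simulate the chunks a streaming Ollama chat call yields."""
    yield {"message": {"role": "assistant", "content": "Paris"}}
    yield {
        "message": {"role": "assistant", "content": "."},
        "done": True,
        "done_reason": "stop",
    }


async def transform(
    stream: AsyncGenerator[dict[str, Any], None],
) -> AsyncGenerator[dict[str, Any], None]:
    """Yield the assistant role once per message, then only content deltas."""
    new_msg = True
    async for response in stream:
        message = response["message"]
        chunk: dict[str, Any] = {}
        if new_msg:
            new_msg = False
            chunk["role"] = "assistant"
        if (content := message.get("content")) is not None:
            chunk["content"] = content
        if response.get("done"):
            new_msg = True  # the next chunk starts a new assistant message
        yield chunk


async def main() -> None:
    async for delta in transform(fake_ollama_stream()):
        print(delta)
        # -> {'role': 'assistant', 'content': 'Paris'}
        # -> {'content': '.'}


asyncio.run(main())
```
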
@@ -216,12 +256,12 @@ class OllamaConversationEntity(
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
response = await client.chat(
response_generator = await client.chat(
model=model,
# Make a copy of the messages because we mutate the list later
messages=list(message_history.messages),
tools=tools,
stream=False,
stream=True,
# keep_alive requires specifying unit. In this case, seconds
keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
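
With `stream=True`, awaiting `ollama.AsyncClient.chat` returns an async iterator of partial responses rather than a single message, which is what `_transform_stream` consumes. A standalone sketch of that streaming call outside Home Assistant (the host, model name, and prompt are placeholder assumptions):

```python
import asyncio

import ollama


async def main() -> None:
    # Host and model are placeholders; point them at your own Ollama server.
    client = ollama.AsyncClient(host="http://localhost:11434")
    stream = await client.chat(
        model="llama3.2",
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        stream=True,
        # keep_alive requires a unit suffix, e.g. seconds, as in the diff above
        keep_alive="30s",
    )
    async for chunk in stream:
        # Each chunk carries a partial assistant message; print it as it arrives.
        print(chunk["message"].get("content") or "", end="", flush=True)
    print()


asyncio.run(main())
```
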
@@ -232,46 +272,25 @@ class OllamaConversationEntity(
f"Sorry, I had a problem talking to the Ollama server: {err}"
) from err
response_message = response["message"]
content = response_message.get("content")
tool_calls = response_message.get("tool_calls")
message_history.messages.append(
ollama.Message(
role=response_message["role"],
content=content,
tool_calls=tool_calls,
)
)
tool_inputs = [
llm.ToolInput(
tool_name=tool_call["function"]["name"],
tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
)
for tool_call in tool_calls or ()
]
message_history.messages.extend(
[
ollama.Message(
role=MessageRole.TOOL.value,
content=json.dumps(tool_response.tool_result),
)
async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=content,
tool_calls=tool_inputs or None,
)
_convert_content(content)
async for content in chat_log.async_add_delta_content_stream(
user_input.agent_id, _transform_stream(response_generator)
)
]
)
if not tool_calls:
if not chat_log.unresponded_tool_results:
break
# Create intent response
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response_message["content"])
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
raise TypeError(
f"Unexpected last message type: {type(chat_log.content[-1])}"
)
intent_response.async_set_speech(chat_log.content[-1].content or "")
return conversation.ConversationResult(
response=intent_response, conversation_id=chat_log.conversation_id
)
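
On the consuming side, `chat_log.async_add_delta_content_stream` is expected to reassemble the deltas into complete assistant messages (a `role` key starts a new message; later `content` and `tool_calls` keys extend it) and to surface unresponded tool results so the loop knows whether another model call is needed. A rough sketch of just that reassembly contract, using plain dicts rather than the real `conversation.ChatLog` API:

```python
from typing import Any


def reassemble(deltas: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Merge role/content/tool_call deltas back into whole assistant messages."""
    messages: list[dict[str, Any]] = []
    for delta in deltas:
        if "role" in delta:
            # A role key marks the start of a new assistant message.
            messages.append({"role": delta["role"], "content": "", "tool_calls": []})
        if "content" in delta:
            messages[-1]["content"] += delta["content"]
        if "tool_calls" in delta:
            messages[-1]["tool_calls"].extend(delta["tool_calls"])
    return messages


deltas = [{"role": "assistant", "content": "test "}, {"content": "response"}]
assert reassemble(deltas) == [
    {"role": "assistant", "content": "test response", "tool_calls": []}
]
```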

tests/components/ollama/test_conversation.py

@@ -1,5 +1,6 @@
"""Tests for the Ollama integration."""
from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
@@ -25,6 +26,14 @@ def mock_ulid_tools():
yield
async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict]:
"""Generate a response from the assistant."""
if not isinstance(response, list):
response = [response]
for msg in response:
yield msg
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
async def test_chat(
hass: HomeAssistant,
@@ -42,7 +51,9 @@ async def test_chat(
with patch(
"ollama.AsyncClient.chat",
return_value={"message": {"role": "assistant", "content": "test response"}},
return_value=stream_generator(
{"message": {"role": "assistant", "content": "test response"}}
),
) as mock_chat:
result = await conversation.async_converse(
hass,
@@ -81,6 +92,53 @@ async def test_chat(
assert "Current time is" in detail_event["data"]["messages"][0]["content"]
async def test_chat_stream(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test chat messages are assembled across streamed responses."""
entry = MockConfigEntry()
entry.add_to_hass(hass)
with patch(
"ollama.AsyncClient.chat",
return_value=stream_generator(
[
{"message": {"role": "assistant", "content": "test "}},
{
"message": {"role": "assistant", "content": "response"},
"done": True,
"done_reason": "stop",
},
],
),
) as mock_chat:
result = await conversation.async_converse(
hass,
"test message",
None,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert mock_chat.call_count == 1
args = mock_chat.call_args.kwargs
prompt = args["messages"][0]["content"]
assert args["model"] == "test model"
assert args["messages"] == [
Message(role="system", content=prompt),
Message(role="user", content="test message"),
]
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
result
)
assert result.response.speech["plain"]["speech"] == "test response"
async def test_template_variables(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
@@ -103,7 +161,9 @@ async def test_template_variables(
patch("ollama.AsyncClient.list"),
patch(
"ollama.AsyncClient.chat",
return_value={"message": {"role": "assistant", "content": "test response"}},
return_value=stream_generator(
{"message": {"role": "assistant", "content": "test response"}}
),
) as mock_chat,
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
):
@@ -170,26 +230,30 @@ async def test_function_call(
def completion_result(*args, messages, **kwargs):
for message in messages:
if message["role"] == "tool":
return {
"message": {
"role": "assistant",
"content": "I have successfully called the function",
}
}
return {
"message": {
"role": "assistant",
"tool_calls": [
return stream_generator(
{
"function": {
"name": "test_tool",
"arguments": tool_args,
"message": {
"role": "assistant",
"content": "I have successfully called the function",
}
}
],
)
return stream_generator(
{
"message": {
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "test_tool",
"arguments": tool_args,
}
}
],
}
}
}
)
with patch(
"ollama.AsyncClient.chat",
@@ -251,26 +315,30 @@ async def test_function_exception(
def completion_result(*args, messages, **kwargs):
for message in messages:
if message["role"] == "tool":
return {
"message": {
"role": "assistant",
"content": "There was an error calling the function",
}
}
return {
"message": {
"role": "assistant",
"tool_calls": [
return stream_generator(
{
"function": {
"name": "test_tool",
"arguments": {"param1": "test_value"},
"message": {
"role": "assistant",
"content": "There was an error calling the function",
}
}
],
)
return stream_generator(
{
"message": {
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "test_tool",
"arguments": {"param1": "test_value"},
}
}
],
}
}
}
)
with patch(
"ollama.AsyncClient.chat",
@@ -344,7 +412,9 @@ async def test_message_history_trimming(
def response(*args, **kwargs) -> dict:
nonlocal response_idx
response_idx += 1
return {"message": {"role": "assistant", "content": f"response {response_idx}"}}
return stream_generator(
{"message": {"role": "assistant", "content": f"response {response_idx}"}}
)
with patch(
"ollama.AsyncClient.chat",
@@ -438,11 +508,13 @@ async def test_message_history_unlimited(
"""Test that message history is not trimmed when max_history = 0."""
conversation_id = "1234"
def stream(*args, **kwargs) -> AsyncGenerator[dict]:
return stream_generator(
{"message": {"role": "assistant", "content": "test response"}}
)
with (
patch(
"ollama.AsyncClient.chat",
return_value={"message": {"role": "assistant", "content": "test response"}},
) as mock_chat,
patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat,
):
hass.config_entries.async_update_entry(
mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0}
@@ -559,7 +631,9 @@ async def test_options(
"""Test that options are passed correctly to ollama client."""
with patch(
"ollama.AsyncClient.chat",
return_value={"message": {"role": "assistant", "content": "test response"}},
return_value=stream_generator(
{"message": {"role": "assistant", "content": "test response"}}
),
) as mock_chat:
await conversation.async_converse(
hass,